Skip to content

Commit

Permalink
Merge pull request #1213 from flyinfish/discussions/1206-order
Browse files Browse the repository at this point in the history
Provide Evaluations in same order as they where submitted
  • Loading branch information
geoand authored Jan 10, 2025
2 parents b980304 + 4c93baf commit e0f1362
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
/**
* Report of the evaluation of a set of samples.
*/
public class EvaluationReport {
public class EvaluationReport<T> {

private final List<Scorer.EvaluationResult<?>> evaluations;
private final List<Scorer.EvaluationResult<T>> evaluations;
private final double score;

/**
* Create a new evaluation report and computes the global score.
*
* @param evaluations the evaluations, must not be {@code null}, must not be empty.
*/
public EvaluationReport(List<Scorer.EvaluationResult<?>> evaluations) {
public EvaluationReport(List<Scorer.EvaluationResult<T>> evaluations) {
this.evaluations = evaluations;
this.score = 100.0 * evaluations.stream().filter(Scorer.EvaluationResult::passed).count() / evaluations.size();
}
Expand All @@ -33,7 +33,7 @@ public double score() {
/**
* @return the evaluations
*/
public List<Scorer.EvaluationResult<?>> evaluations() {
public List<Scorer.EvaluationResult<T>> evaluations() {
return evaluations;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkiverse.langchain4j.testing.scorer;

import java.io.Closeable;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
Expand Down Expand Up @@ -28,50 +29,64 @@ public Scorer() {
}

@SuppressWarnings({ "unchecked" })
public <T> EvaluationReport evaluate(Samples<T> samples, Function<Parameters, T> function,
EvaluationStrategy<T>... strategies) {
List<EvaluationResult<?>> evaluations = new CopyOnWriteArrayList<>();
public <T> EvaluationReport<T> evaluate(
Samples<T> samples, Function<Parameters, T> function, EvaluationStrategy<T>... strategies) {
List<OrderedEvaluationResult<T>> evaluations = new CopyOnWriteArrayList<>();
CountDownLatch latch = new CountDownLatch(samples.size());
var index = 0;
for (EvaluationSample<T> sample : samples) {
// TODO Should we handle the context somehow.
executor.submit(() -> {
try {
var response = execute(sample, function);
LOG.infof("Evaluating sample `%s`", sample.name());
for (EvaluationStrategy<T> strategy : strategies) {
EvaluationResult<T> evaluation = EvaluationResult.fromCompletedEvaluation(sample,
response, strategy.evaluate(sample, response));
LOG.infof("Evaluation of sample `%s` with strategy `%s`: %s", sample.name(),
strategy.getClass().getSimpleName(),
evaluation.passed() ? "OK" : "KO");
evaluations.add(evaluation);
}
} catch (Throwable e) {
LOG.errorf(e, "Failed to evaluate sample `%s`", sample.name());
evaluations.add(EvaluationResult.fromEvaluationThrowable(sample, e));
} finally {
latch.countDown();
}
});
var currentIndex = index++;
executor.submit(
() -> {
try {
var response = execute(sample, function);
LOG.infof("Evaluating sample `%s`", sample.name());
for (EvaluationStrategy<T> strategy : strategies) {
EvaluationResult<T> evaluation = EvaluationResult.fromCompletedEvaluation(
sample, response, strategy.evaluate(sample, response));
LOG.infof(
"Evaluation of sample `%s` with strategy `%s`: %s",
sample.name(),
strategy.getClass().getSimpleName(),
evaluation.passed() ? "OK" : "KO");
evaluations.add(new OrderedEvaluationResult(currentIndex, evaluation));
}
} catch (Throwable e) {
LOG.errorf(e, "Failed to evaluate sample `%s`", sample.name());
evaluations.add(
new OrderedEvaluationResult(
currentIndex, EvaluationResult.fromEvaluationThrowable(sample, e)));
} finally {
latch.countDown();
}
});
}
try {
latch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
return new EvaluationReport(evaluations);
var orderedEvalutions = evaluations.stream()
.sorted(Comparator.comparing(OrderedEvaluationResult::index))
.map(OrderedEvaluationResult::evaluation)
.toList();
return new EvaluationReport<>(orderedEvalutions);
}

public void close() {
executor.shutdown();
}

public record EvaluationResult<T>(EvaluationSample<T> sample, T result, Throwable thrown, boolean passed) {
public static <T> EvaluationResult<T> fromCompletedEvaluation(EvaluationSample<T> sample, T result, boolean passed) {
public record EvaluationResult<T>(
EvaluationSample<T> sample, T result, Throwable thrown, boolean passed) {
public static <T> EvaluationResult<T> fromCompletedEvaluation(
EvaluationSample<T> sample, T result, boolean passed) {
return new EvaluationResult<>(sample, result, null, passed);
}

public static <T> EvaluationResult<T> fromEvaluationThrowable(EvaluationSample<T> sample, Throwable thrown) {
public static <T> EvaluationResult<T> fromEvaluationThrowable(
EvaluationSample<T> sample, Throwable thrown) {
return new EvaluationResult<>(sample, null, thrown, false);
}
}
Expand All @@ -84,4 +99,6 @@ private <T> T execute(EvaluationSample<T> sample, Function<Parameters, T> functi
}
}

private record OrderedEvaluationResult<T>(int index, EvaluationResult<T> evaluation) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import java.util.List;
import java.util.function.Function;
import java.util.stream.Stream;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -40,20 +41,55 @@ void evaluateShouldReturnCorrectReport() {
EvaluationStrategy<String> strategy = (sample, actual) -> actual.equals(sample.expectedOutput());

Samples<String> samples = new Samples<>(sample1, sample2);
EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy);
EvaluationReport<String> report = scorer.evaluate(samples, mockFunction, strategy);

assertThat(report).isNotNull();
assertThat(report.score()).isEqualTo(50.0); // Only one sample should pass.
assertThat(report.evaluations()).hasSize(2);

var actualEvaluations = report.evaluations().stream()
.map(e -> "%s[%s;%s=%s]".formatted(e.sample().name(), e.sample().expectedOutput(), e.result(), e.passed()))
.map(
e -> "%s[%s;%s=%s]"
.formatted(
e.sample().name(), e.sample().expectedOutput(), e.result(), e.passed()))
.toList();
assertThat(actualEvaluations).containsExactlyInAnyOrder(
"Sample1[expected1:param1;expected1:param1=true]",
"Sample2[expected2;expected1:param1=false]");
assertThat(actualEvaluations)
.containsExactly(
"Sample1[expected1:param1;expected1:param1=true]",
"Sample2[expected2;expected1:param1=false]");
}

@SuppressWarnings("unchecked")
@Test
void evaluateShouldReturnCorrectlyOrderedReport() {
scorer = new Scorer(2);
var sleeps = Stream.of(25l, 0l);
var samples = new Samples<>(
sleeps
.map(
sleep -> new EvaluationSample<>(
"%s".formatted(sleep),
new Parameters().add(new Parameter.UnnamedParameter(sleep)),
"irrelevant-for-this-test",
List.of()))
.toList());

var actual = scorer.evaluate(samples, this::sleep, (sample, actualOutput) -> true);

var actualOrder = actual.evaluations().stream().map(e -> e.sample().name()).toList();
assertThat(actualOrder).containsExactly("25", "0");
}

private String sleep(Parameters params) {
long ms = params.get(0);
try {
Thread.sleep(ms);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return "sleeped %s".formatted(ms);
};

@Test
@SuppressWarnings("unchecked")
void evaluateShouldHandleExceptionsInFunction() {
Expand All @@ -71,7 +107,7 @@ void evaluateShouldHandleExceptionsInFunction() {
EvaluationStrategy<String> strategy = (s, actual) -> false;

Samples<String> samples = new Samples<>(sample);
EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy);
EvaluationReport<String> report = scorer.evaluate(samples, mockFunction, strategy);

assertThat(report).isNotNull();
assertThat(report.score()).isEqualTo(0.0); // All evaluations should fail.
Expand All @@ -96,7 +132,7 @@ void evaluateShouldHandleMultipleStrategies() {
EvaluationStrategy<String> strategy2 = (s, actual) -> actual.length() > 3;

Samples<String> samples = new Samples<>(sample);
EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy1, strategy2);
EvaluationReport<String> report = scorer.evaluate(samples, mockFunction, strategy1, strategy2);

assertThat(report).isNotNull();
assertThat(report.score()).isEqualTo(100.0); // Both strategies should pass for the sample.
Expand Down

0 comments on commit e0f1362

Please sign in to comment.