Skip to content

Commit

Permalink
Pass correct results to TraceBatchFinishFunc
Browse files Browse the repository at this point in the history
The slice of *Result objects passed to TraceBatchFinishFunc was always
empty. By moving the `defer finish(items)` down below where items is
assigned the results of the batch function, tracers are passed the
actual results.

Fixes #63.
  • Loading branch information
mjq authored and Matt Quinn committed Mar 30, 2021
1 parent c87fdce commit d659c93
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
3 changes: 2 additions & 1 deletion dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,6 @@ func (b *batcher) batch(originalContext context.Context) {
}

ctx, finish := b.tracer.TraceBatch(originalContext, keys)
defer finish(items)

func() {
defer func() {
Expand All @@ -432,6 +431,8 @@ func (b *batcher) batch(originalContext context.Context) {
items = b.batchFn(ctx, keys)
}()

defer finish(items)

if panicErr != nil {
for _, req := range reqs {
req.channel <- &Result{Error: fmt.Errorf("Panic received in batch function: %v", panicErr)}
Expand Down
38 changes: 38 additions & 0 deletions dataloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,26 @@ func TestLoader(t *testing.T) {
}
})

t.Run("tracer's TraceBatch finish func is passed the Result slice", func(t *testing.T) {
t.Parallel()
identityLoader, _ := IDLoader(0)
tracer := new(RecordingTracer)
identityLoader.tracer = tracer
ctx := context.Background()
future := identityLoader.Load(ctx, StringKey("1"))
_, err := future()
if err != nil {
t.Error(err.Error())
}

calls := tracer.traceBatchFinishCalls
inner := []*Result{{Data: "1"}}
expected := [][]*Result{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("tracer did not receive expected results. Expected %#v, got %#v", expected, calls)
}
})

}

// test helpers
Expand Down Expand Up @@ -586,6 +606,24 @@ func FaultyLoader() (*Loader, *[][]string) {
return loader, &loadCalls
}

type RecordingTracer struct {
traceBatchFinishCalls [][]*Result
}

func (t *RecordingTracer) TraceLoad(ctx context.Context, key Key) (context.Context, TraceLoadFinishFunc) {
return ctx, func(Thunk) {}
}

func (t *RecordingTracer) TraceLoadMany(ctx context.Context, keys Keys) (context.Context, TraceLoadManyFinishFunc) {
return ctx, func(ThunkMany) {}
}

func (t *RecordingTracer) TraceBatch(ctx context.Context, keys Keys) (context.Context, TraceBatchFinishFunc) {
return ctx, func(results []*Result) {
t.traceBatchFinishCalls = append(t.traceBatchFinishCalls, results)
}
}

///////////////////////////////////////////////////
// Benchmarks
///////////////////////////////////////////////////
Expand Down

0 comments on commit d659c93

Please sign in to comment.