Skip to content

Commit

Permalink
Support cache skipping for Load() calls that throw SkipCacheError
Browse files Browse the repository at this point in the history
  • Loading branch information
goncalvesnelson committed May 27, 2024
1 parent ab736ad commit bf5efb6
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
21 changes: 20 additions & 1 deletion dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,24 @@ func (p *PanicErrorWrapper) Error() string {
return p.panicError.Error()
}

// SkipCacheError wraps the error interface.
// The cache should not store SkipCacheErrors.
type SkipCacheError struct {
err error
}

func (s *SkipCacheError) Error() string {
return s.err.Error()
}

func (s *SkipCacheError) Unwrap() error {
return s.err
}

func NewSkipCacheError(err error) *SkipCacheError {
return &SkipCacheError{err: err}
}

// Loader implements the dataloader.Interface.
type Loader[K comparable, V any] struct {
// the batch function to be used by this loader
Expand Down Expand Up @@ -232,7 +250,8 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
result.mu.RLock()
defer result.mu.RUnlock()
var ev *PanicErrorWrapper
if result.value.Error != nil && errors.As(result.value.Error, &ev) {
var es *SkipCacheError
if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)){
l.Clear(ctx, key)
}
return result.value.Data, result.value.Error
Expand Down
63 changes: 63 additions & 0 deletions dataloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,45 @@ func TestLoader(t *testing.T) {
}
})

t.Run("test Load Method not caching results with errors of type SkipCacheError", func(t *testing.T) {
t.Parallel()
skipCacheLoader, loadCalls := SkipCacheErrorLoader(3, "1")
ctx := context.Background()
futures1 := skipCacheLoader.LoadMany(ctx, []string{"1", "2", "3"})
_, errs1 := futures1()
var errCount int = 0
var nilCount int = 0
for _, err := range errs1 {
if err == nil {
nilCount++
} else {
errCount++
}
}
if errCount != 1 {
t.Error("Expected an error on only key \"1\"")
}

if nilCount != 2 {
t.Error("Expected the other errors to be nil")
}

futures2 := skipCacheLoader.LoadMany(ctx, []string{"2", "3", "1"})
_, errs2 := futures2()
// There should be no errors in the second batch, as the only key that was not cached
// this time around will not throw an error
if errs2 != nil {
t.Error("Expected LoadMany() to return nil error slice when no errors occurred")
}

calls := (*loadCalls)[1]
expected := []string{"1"}

if !reflect.DeepEqual(calls, expected) {
t.Errorf("Expected load calls %#v, got %#v", expected, calls)
}
})

t.Run("test Load Method Panic Safety in multiple keys", func(t *testing.T) {
t.Parallel()
defer func() {
Expand Down Expand Up @@ -622,6 +661,30 @@ func ErrorCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
return errorCacheLoader, &loadCalls
}

func SkipCacheErrorLoader[K comparable](max int, onceErrorKey K) (*Loader[K, K], *[][]K) {
var mu sync.Mutex
var loadCalls [][]K
errorThrown := false
skipCacheErrorLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
var results []*Result[K]
mu.Lock()
loadCalls = append(loadCalls, keys)
mu.Unlock()
// return a non cacheable error for the first occurence of onceErrorKey
for _, k := range keys {
if !errorThrown && k == onceErrorKey {
results = append(results, &Result[K]{k, NewSkipCacheError(fmt.Errorf("non cacheable error"))})
errorThrown = true
} else {
results = append(results, &Result[K]{k, nil})
}
}

return results
}, WithBatchCapacity[K, K](max))
return skipCacheErrorLoader, &loadCalls
}

func BadLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
var mu sync.Mutex
var loadCalls [][]K
Expand Down

0 comments on commit bf5efb6

Please sign in to comment.