Skip to content

Commit

Permalink
add Seq and SeqLEK (go 1.23 iterators)
Browse files Browse the repository at this point in the history
  • Loading branch information
guregu committed Aug 22, 2024
1 parent eed9493 commit 46a3891
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 0 deletions.
6 changes: 6 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,12 @@ func (itr *queryIter) Err() error {
return itr.err
}

func (itr *queryIter) SetError(err error) {
if itr.err == nil {
itr.err = err
}
}

func (itr *queryIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) {
if itr.output != nil {
// if we've hit the end of our results, we can use the real LEK
Expand Down
6 changes: 6 additions & 0 deletions scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,12 @@ func (itr *scanIter) Err() error {
return itr.err
}

func (itr *scanIter) SetError(err error) {
if itr.err == nil {
itr.err = err
}
}

// LastEvaluatedKey returns a key that can be used to continue this scan.
// Use with SearchLimit for best results.
func (itr *scanIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) {
Expand Down
40 changes: 40 additions & 0 deletions seq_go123.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//go:build go1.23

package dynamo

import (
"context"
"iter"
)

// Seq returns an item iterator compatible with Go 1.23 `for ... range` loops.
func Seq[V any](ctx context.Context, iter Iter) iter.Seq[V] {
return func(yield func(V) bool) {
item := new(V)
for iter.Next(ctx, item) {
if !yield(*item) {
break
}
item = new(V)
}
}
}

// SeqLEK returns a LastEvaluatedKey and item iterator compatible with Go 1.23 `for ... range` loops.
func SeqLEK[V any](ctx context.Context, iter PagingIter) iter.Seq2[PagingKey, V] {
return func(yield func(PagingKey, V) bool) {
item := new(V)
for iter.Next(ctx, item) {
lek, err := iter.LastEvaluatedKey(ctx)
if err != nil {
if setter, ok := iter.(interface{ SetError(error) }); ok {
setter.SetError(err)
}
}
if !yield(lek, *item) {
break
}
item = new(V)
}
}
}
75 changes: 75 additions & 0 deletions seq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//go:build go1.23

package dynamo

import (
"context"
"testing"
"time"
)

func TestSeq(t *testing.T) {
if testDB == nil {
t.Skip(offlineSkipMsg)
}
ctx := context.Background()
table := testDB.Table(testTableWidgets)

widgets := []any{
widget{
UserID: 1971,
Time: time.Date(1971, 4, 00, 0, 0, 0, 0, time.UTC),
Msg: "Seq1",
},
widget{
UserID: 1971,
Time: time.Date(1971, 4, 10, 0, 0, 0, 0, time.UTC),
Msg: "Seq1",
},
widget{
UserID: 1971,
Time: time.Date(1971, 4, 20, 0, 0, 0, 0, time.UTC),
Msg: "Seq1",
},
}

t.Run("prepare data", func(t *testing.T) {
if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil {
t.Fatal(err)
}
})

iter := testDB.Table(testTableWidgets).Get("UserID", 1971).Iter()
var got []*widget
var count int
for item := range Seq[*widget](ctx, iter) {
t.Log(item)
item.Count = count
got = append(got, item)
count++
}

if iter.Err() != nil {
t.Fatal(iter.Err())
}

t.Run("results match", func(t *testing.T) {
for i, item := range got {
want := widgets[i].(widget)
if !item.Time.Equal(want.Time) {
t.Error("bad result. want:", want.Time, "got:", item.Time)
}
}
})

t.Run("result item isolation", func(t *testing.T) {
// make sure that when mutating the result in the `for ... range` loop
// it only affects one item
t.Log("got", got)
for i, item := range got {
if item.Count != i {
t.Error("unexpected count. got:", item.Count, "want:", i)
}
}
})
}

0 comments on commit 46a3891

Please sign in to comment.