From 46a389198ae8fe65af12743508ef3d12654a8120 Mon Sep 17 00:00:00 2001 From: guregu Date: Thu, 15 Aug 2024 02:39:49 +0900 Subject: [PATCH 1/2] add Seq and SeqLEK (go 1.23 iterators) --- query.go | 6 +++++ scan.go | 6 +++++ seq_go123.go | 40 ++++++++++++++++++++++++++++ seq_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 127 insertions(+) create mode 100644 seq_go123.go create mode 100644 seq_test.go diff --git a/query.go b/query.go index ed82375..f5f7a73 100644 --- a/query.go +++ b/query.go @@ -428,6 +428,12 @@ func (itr *queryIter) Err() error { return itr.err } +func (itr *queryIter) SetError(err error) { + if itr.err == nil { + itr.err = err + } +} + func (itr *queryIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) { if itr.output != nil { // if we've hit the end of our results, we can use the real LEK diff --git a/scan.go b/scan.go index 26792cf..2e8222c 100644 --- a/scan.go +++ b/scan.go @@ -434,6 +434,12 @@ func (itr *scanIter) Err() error { return itr.err } +func (itr *scanIter) SetError(err error) { + if itr.err == nil { + itr.err = err + } +} + // LastEvaluatedKey returns a key that can be used to continue this scan. // Use with SearchLimit for best results. func (itr *scanIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) { diff --git a/seq_go123.go b/seq_go123.go new file mode 100644 index 0000000..9f6c897 --- /dev/null +++ b/seq_go123.go @@ -0,0 +1,40 @@ +//go:build go1.23 + +package dynamo + +import ( + "context" + "iter" +) + +// Seq returns an item iterator compatible with Go 1.23 `for ... range` loops. +func Seq[V any](ctx context.Context, iter Iter) iter.Seq[V] { + return func(yield func(V) bool) { + item := new(V) + for iter.Next(ctx, item) { + if !yield(*item) { + break + } + item = new(V) + } + } +} + +// SeqLEK returns a LastEvaluatedKey and item iterator compatible with Go 1.23 `for ... range` loops. +func SeqLEK[V any](ctx context.Context, iter PagingIter) iter.Seq2[PagingKey, V] { + return func(yield func(PagingKey, V) bool) { + item := new(V) + for iter.Next(ctx, item) { + lek, err := iter.LastEvaluatedKey(ctx) + if err != nil { + if setter, ok := iter.(interface{ SetError(error) }); ok { + setter.SetError(err) + } + } + if !yield(lek, *item) { + break + } + item = new(V) + } + } +} diff --git a/seq_test.go b/seq_test.go new file mode 100644 index 0000000..49d0006 --- /dev/null +++ b/seq_test.go @@ -0,0 +1,75 @@ +//go:build go1.23 + +package dynamo + +import ( + "context" + "testing" + "time" +) + +func TestSeq(t *testing.T) { + if testDB == nil { + t.Skip(offlineSkipMsg) + } + ctx := context.Background() + table := testDB.Table(testTableWidgets) + + widgets := []any{ + widget{ + UserID: 1971, + Time: time.Date(1971, 4, 00, 0, 0, 0, 0, time.UTC), + Msg: "Seq1", + }, + widget{ + UserID: 1971, + Time: time.Date(1971, 4, 10, 0, 0, 0, 0, time.UTC), + Msg: "Seq1", + }, + widget{ + UserID: 1971, + Time: time.Date(1971, 4, 20, 0, 0, 0, 0, time.UTC), + Msg: "Seq1", + }, + } + + t.Run("prepare data", func(t *testing.T) { + if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil { + t.Fatal(err) + } + }) + + iter := testDB.Table(testTableWidgets).Get("UserID", 1971).Iter() + var got []*widget + var count int + for item := range Seq[*widget](ctx, iter) { + t.Log(item) + item.Count = count + got = append(got, item) + count++ + } + + if iter.Err() != nil { + t.Fatal(iter.Err()) + } + + t.Run("results match", func(t *testing.T) { + for i, item := range got { + want := widgets[i].(widget) + if !item.Time.Equal(want.Time) { + t.Error("bad result. want:", want.Time, "got:", item.Time) + } + } + }) + + t.Run("result item isolation", func(t *testing.T) { + // make sure that when mutating the result in the `for ... range` loop + // it only affects one item + t.Log("got", got) + for i, item := range got { + if item.Count != i { + t.Error("unexpected count. got:", item.Count, "want:", i) + } + } + }) +} From 38c18671e27dd71ef9acbfe24550df779b2e8890 Mon Sep 17 00:00:00 2001 From: guregu Date: Fri, 23 Aug 2024 03:36:38 +0900 Subject: [PATCH 2/2] add ItemTableIter --- batchget.go | 1 + batchget_go123.go | 83 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 batchget_go123.go diff --git a/batchget.go b/batchget.go index cfb0c3e..698d335 100644 --- a/batchget.go +++ b/batchget.go @@ -193,6 +193,7 @@ func (bg *BatchGet) Iter() Iter { // IterWithTable is like [BatchGet.Iter], but will update the value pointed by tablePtr after each iteration. // This can be useful when getting from multiple tables to determine which table the latest item came from. +// See: [BatchGet.ItemTableIter] for a nicer way to do this. // // For example, you can utilize this iterator to read the results into different structs. // diff --git a/batchget_go123.go b/batchget_go123.go new file mode 100644 index 0000000..835a059 --- /dev/null +++ b/batchget_go123.go @@ -0,0 +1,83 @@ +//go:build go1.23 + +package dynamo + +import ( + "context" + "iter" +) + +type ItemTableIter[V any] interface { + // Items is a sequence of item and table names. + // This is a single use iterator. + // Be sure to check for errors with Err afterwards. + Items(context.Context) iter.Seq2[V, string] + // Err must be checked after iterating. + Err() error +} + +// ItemTableIter returns an iterator of (raw item, table name). +// To specify a type, use [BatchGetIter] instead. +// +// For example, you can utilize this iterator to read the results into different structs. +// +// widgetBatch := widgetsTable.Batch("UserID").Get(dynamo.Keys{userID}) +// sprocketBatch := sprocketsTable.Batch("UserID").Get(dynamo.Keys{userID}) +// +// iter := widgetBatch.Merge(sprocketBatch).ItemTableIter(&table) +// +// // now we will use the table iterator to unmarshal the values into their respective types +// var s sprocket +// var w widget +// for raw, table := range iter.Items { +// if table == "Widgets" { +// err := dynamo.UnmarshalItem(raw, &w) +// if err != nil { +// fmt.Println(err) +// } +// } else if table == "Sprockets" { +// err := dynamo.UnmarshalItem(raw, &s) +// if err != nil { +// fmt.Println(err) +// } +// } else { +// fmt.Printf("Unexpected Table: %s\n", table) +// } +// } +// +// if iter.Err() != nil { +// fmt.Println(iter.Err()) +// } +func (bg *BatchGet) ItemTableIter() ItemTableIter[Item] { + return newBgIter2[Item](bg) +} + +type bgIter2[V any] struct { + Iter + table string +} + +func newBgIter2[V any](bg *BatchGet) *bgIter2[V] { + iter := new(bgIter2[V]) + iter.Iter = bg.IterWithTable(&iter.table) + return iter +} + +// Items is a sequence of item and table names. +// This is a single use iterator. +// Be sure to check for errors with Err afterwards. +func (iter *bgIter2[V]) Items(ctx context.Context) iter.Seq2[V, string] { + return func(yield func(V, string) bool) { + item := new(V) + for iter.Next(ctx, item) { + if !yield(*item, iter.table) { + break + } + item = new(V) + } + } +} + +func BatchGetIter[V any](bg *BatchGet) ItemTableIter[V] { + return newBgIter2[V](bg) +}