Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F refactor iterators #755

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions platform/common/core/generic/vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,18 +349,16 @@ func (db *Vault[V]) Statuses(ctx context.Context, txIDs ...driver.TxID) ([]drive
if err != nil {
return nil, err
}
statuses := make([]driver.TxValidationStatus[V], 0, len(txIDs))
for status, err := it.Next(); status != nil; status, err = it.Next() {
if err != nil {
return nil, err
return collections.ReadAll(collections.Map(it, func(status *driver.TxStatus) (*driver.TxValidationStatus[V], error) {
if status == nil {
return nil, nil
}
statuses = append(statuses, driver.TxValidationStatus[V]{
return &driver.TxValidationStatus[V]{
TxID: status.TxID,
ValidationCode: db.vcProvider.FromInt32(status.Code),
Message: status.Message,
})
}
return statuses, nil
}, nil
}))
}

func (db *Vault[V]) SetStatus(ctx context.Context, txID driver.TxID, code V) error {
Expand Down
9 changes: 5 additions & 4 deletions platform/common/utils/cache/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ package cache
// NewMapCache creates a dummy implementation of the Cache interface.
// It is backed by a map with unlimited capacity.
func NewMapCache[K comparable, V any]() *mapCache[K, V] {
return &mapCache[K, V]{
m: map[K]V{},
l: &noLock{},
}
return NewMapCacheWithLock[K, V](&noLock{})
}

func NewMapCacheWithLock[K comparable, V any](l rwLock) *mapCache[K, V] {
return &mapCache[K, V]{m: map[K]V{}, l: l}
}

type mapCache[K comparable, V any] struct {
Expand Down
38 changes: 36 additions & 2 deletions platform/common/utils/collections/iterators.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"math/rand"

"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils"
"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/functions"
)

type baseIterator[k any] interface {
Expand Down Expand Up @@ -68,6 +69,18 @@ func ReadAll[T any](it Iterator[*T]) ([]T, error) {
return items, nil
}

func ToSlice[T any](it Iterator[*T]) ([]*T, error) {
defer it.Close()
items := make([]*T, 0)
for item, err := it.Next(); item != nil || err != nil; item, err = it.Next() {
if err != nil {
return nil, err
}
items = append(items, item)
}
return items, nil
}

func NewSingleIterator[T any](item T) *sliceIterator[T] {
return NewSliceIterator[T]([]T{item})
}
Expand Down Expand Up @@ -123,13 +136,34 @@ func (it *permutationIterator[T]) Close() {
it.perm = nil
}

func Map[A any, B any](iterator Iterator[A], transformer func(A) (B, error)) Iterator[B] {
func Filter[A any](iterator Iterator[A], filter functions.Filter[A]) Iterator[A] {
return &filteredIterator[A]{Iterator: iterator, filter: filter}
}

type filteredIterator[A any] struct {
Iterator[A]
filter func(A) bool
}

func (it *filteredIterator[A]) Next() (A, error) {
if next, err := it.Iterator.Next(); err != nil {
return next, err
} else if utils.IsNil(next) {
return utils.Zero[A](), nil
} else if !it.filter(next) {
return it.Next()
} else {
return next, nil
}
}

func Map[A any, B any](iterator Iterator[A], transformer functions.Mapper[A, B]) Iterator[B] {
return &mappedIterator[A, B]{Iterator: iterator, transformer: transformer}
}

type mappedIterator[A any, B any] struct {
Iterator[A]
transformer func(A) (B, error)
transformer functions.Mapper[A, B]
}

func (it *mappedIterator[A, B]) Next() (B, error) {
Expand Down
15 changes: 15 additions & 0 deletions platform/common/utils/functions/functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
Copyright IBM Corp. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package functions

type Mapper[A any, B any] func(A) (B, error)

type Filter[T any] func(T) bool

func Not[T any](f Filter[T]) Filter[T] {
return func(t T) bool { return !f(t) }
}
20 changes: 20 additions & 0 deletions platform/common/utils/nulls.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ SPDX-License-Identifier: Apache-2.0

package utils

import "reflect"

func Zero[A any]() A {
var a A
return a
Expand Down Expand Up @@ -54,3 +56,21 @@ func DefaultString(a interface{}, b string) string {
}
return b
}

var nullableTypes = map[reflect.Kind]struct{}{
reflect.Ptr: {},
reflect.Map: {},
reflect.Chan: {},
reflect.Func: {},
reflect.Slice: {},
reflect.Interface: {},
}

func isNullable(k reflect.Kind) bool {
_, ok := nullableTypes[k]
return ok
}

func IsNil(a any) bool {
return a == nil || isNullable(reflect.TypeOf(a).Kind()) && reflect.ValueOf(a).IsNil()
}
12 changes: 5 additions & 7 deletions platform/view/services/db/driver/sql/common/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"database/sql"
"fmt"

"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/collections"
"github.com/hyperledger-labs/fabric-smart-client/platform/view/services/db/driver"
"github.com/hyperledger-labs/fabric-smart-client/platform/view/view"
"github.com/pkg/errors"
Expand Down Expand Up @@ -53,13 +54,10 @@ func (db *BindingPersistence) HaveSameBinding(this, that view.Identity) (bool, e
}
defer rows.Close()

longTermIds := make([]view.Identity, 0, 2)
for rows.Next() {
var longTerm view.Identity
if err := rows.Scan(&longTerm); err != nil {
return false, err
}
longTermIds = append(longTermIds, longTerm)
it := QueryIterator(rows, func(r RowScanner, id *view.Identity) error { return r.Scan(id) })
longTermIds, err := collections.ReadAll(it)
if err != nil {
return false, err
}
if len(longTermIds) != 2 {
return false, errors.Errorf("%d entries found instead of 2", len(longTermIds))
Expand Down
46 changes: 16 additions & 30 deletions platform/view/services/db/driver/sql/common/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,13 +500,10 @@ func TTestCompositeKeys(t *testing.T, db driver.UnversionedPersistence) {

itr, err := db.GetStateRangeScanIterator(ns, startKey, endKey)
assert.NoError(t, err)
defer itr.Close()

res := make([]driver.UnversionedRead, 0, 4)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}
res, err := collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 4)
assert.Equal(t, []driver.UnversionedRead{
{Key: "\x00prefix0a0b0", Raw: []uint8{0x0, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x30, 0x61, 0x30, 0x62, 0x30}},
Expand All @@ -522,13 +519,10 @@ func TTestCompositeKeys(t *testing.T, db driver.UnversionedPersistence) {

itr, err = db.GetStateRangeScanIterator(ns, startKey, endKey)
assert.NoError(t, err)
defer itr.Close()

res = make([]driver.UnversionedRead, 0, 2)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}
res, err = collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 3)
assert.Equal(t, []driver.UnversionedRead{
{Key: "\x00prefix0a0b0", Raw: []uint8{0x0, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x30, 0x61, 0x30, 0x62, 0x30}},
Expand Down Expand Up @@ -623,13 +617,10 @@ func TTestUnversionedRange(t *testing.T, db driver.UnversionedPersistence) {

itr, err := db.GetStateRangeScanIterator(ns, "", "")
assert.NoError(t, err)
defer itr.Close()

res := make([]driver.UnversionedRead, 0, 4)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}
res, err := collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 4)
assert.Equal(t, []driver.UnversionedRead{
{Key: "k1", Raw: []byte("k1_value")},
Expand All @@ -640,13 +631,10 @@ func TTestUnversionedRange(t *testing.T, db driver.UnversionedPersistence) {

itr, err = db.GetStateRangeScanIterator(ns, "k1", "k3")
assert.NoError(t, err)
defer itr.Close()

res = make([]driver.UnversionedRead, 0, 3)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}
res, err = collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 3)
assert.Equal(t, []driver.UnversionedRead{
{Key: "k1", Raw: []byte("k1_value")},
Expand All @@ -656,12 +644,10 @@ func TTestUnversionedRange(t *testing.T, db driver.UnversionedPersistence) {

itr, err = db.GetStateSetIterator(ns, "k1", "k2")
assert.NoError(t, err)
defer itr.Close()
res = make([]driver.UnversionedRead, 0, 2)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}

res, err = collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 2)
expected := []driver.UnversionedRead{
{Key: "k1", Raw: []byte("k1_value")},
Expand Down
20 changes: 12 additions & 8 deletions platform/view/services/db/driver/sql/common/signerinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"database/sql"
"fmt"

"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/collections"
"github.com/hyperledger-labs/fabric-smart-client/platform/view/services/db/driver"
"github.com/hyperledger-labs/fabric-smart-client/platform/view/view"
"github.com/pkg/errors"
Expand All @@ -35,11 +36,11 @@ type SignerInfoPersistence struct {

func (db *SignerInfoPersistence) FilterExistingSigners(ids ...view.Identity) ([]view.Identity, error) {
idHashes := make([]string, len(ids))
inverseMap := make(map[string]view.Identity, len(ids))
inverseMap := make(map[string]*view.Identity, len(ids))
for i, id := range ids {
idHash := id.UniqueID()
idHashes[i] = idHash
inverseMap[idHash] = id
inverseMap[idHash] = &id
}
where, params := Where(db.ci.InStrings("id", idHashes))
query := fmt.Sprintf("SELECT id FROM %s %s", db.table, where)
Expand All @@ -51,13 +52,16 @@ func (db *SignerInfoPersistence) FilterExistingSigners(ids ...view.Identity) ([]
}
defer rows.Close()

existingSigners := make([]view.Identity, 0)
for rows.Next() {
var idHash string
if err := rows.Scan(&idHash); err != nil {
return nil, errors.Wrapf(err, "failed scanning row")
idHashItr := QueryIterator(rows, func(r RowScanner, h *string) error { return r.Scan(h) })
idItr := collections.Map(idHashItr, func(h *string) (*view.Identity, error) {
if h == nil {
return nil, nil
}
existingSigners = append(existingSigners, inverseMap[idHash])
return inverseMap[*h], nil
})
existingSigners, err := collections.ReadAll[view.Identity](idItr)
if err != nil {
return nil, err
}
logger.Debugf("Found %d out of %d signers", len(existingSigners), len(ids))
return existingSigners, nil
Expand Down
21 changes: 2 additions & 19 deletions platform/view/services/db/driver/sql/common/unversioned.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (db *UnversionedPersistence) GetStateRangeScanIterator(ns driver2.Namespace
return nil, errors2.Wrapf(err, "query error: %s", query)
}

return &readIterator{txs: rows}, nil
return QueryIterator(rows, func(r RowScanner, read *driver.UnversionedRead) error { return r.Scan(&read.Key, &read.Raw) }), nil
}

func (db *UnversionedPersistence) GetState(namespace driver2.Namespace, key driver2.PKey) (driver.UnversionedValue, error) {
Expand All @@ -87,7 +87,7 @@ func (db *UnversionedPersistence) GetStateSetIterator(ns driver2.Namespace, keys
return nil, errors2.Wrapf(err, "query error: %s", query)
}

return &readIterator{txs: rows}, nil
return QueryIterator(rows, func(r RowScanner, read *driver.UnversionedRead) error { return r.Scan(&read.Key, &read.Raw) }), nil
}

func (db *UnversionedPersistence) Close() error {
Expand Down Expand Up @@ -231,23 +231,6 @@ func (db *UnversionedPersistence) Stats() any {
return nil
}

type readIterator struct {
txs *sql.Rows
}

func (t *readIterator) Close() {
t.txs.Close()
}

func (t *readIterator) Next() (*driver.UnversionedRead, error) {
if !t.txs.Next() {
return nil, nil
}
var r driver.UnversionedRead
err := t.txs.Scan(&r.Key, &r.Raw)
return &r, err
}

func (db *UnversionedPersistence) CreateSchema() error {
return InitSchema(db.writeDB, fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
Expand Down
Loading
Loading