Skip to content

Commit

Permalink
Reused query iterator and added utilities
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandros Filios <[email protected]>
  • Loading branch information
alexandrosfilios committed Feb 14, 2025
1 parent 90e0fbb commit be2854b
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 127 deletions.
14 changes: 6 additions & 8 deletions platform/common/core/generic/vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,18 +349,16 @@ func (db *Vault[V]) Statuses(ctx context.Context, txIDs ...driver.TxID) ([]drive
if err != nil {
return nil, err
}
statuses := make([]driver.TxValidationStatus[V], 0, len(txIDs))
for status, err := it.Next(); status != nil; status, err = it.Next() {
if err != nil {
return nil, err
return collections.ReadAll(collections.Map(it, func(status *driver.TxStatus) (*driver.TxValidationStatus[V], error) {
if status == nil {
return nil, nil
}
statuses = append(statuses, driver.TxValidationStatus[V]{
return &driver.TxValidationStatus[V]{
TxID: status.TxID,
ValidationCode: db.vcProvider.FromInt32(status.Code),
Message: status.Message,
})
}
return statuses, nil
}, nil
}))
}

func (db *Vault[V]) SetStatus(ctx context.Context, txID driver.TxID, code V) error {
Expand Down
9 changes: 5 additions & 4 deletions platform/common/utils/cache/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ package cache
// NewMapCache creates a dummy implementation of the Cache interface.
// It is backed by a map with unlimited capacity.
func NewMapCache[K comparable, V any]() *mapCache[K, V] {
return &mapCache[K, V]{
m: map[K]V{},
l: &noLock{},
}
return NewMapCacheWithLock[K, V](&noLock{})
}

func NewMapCacheWithLock[K comparable, V any](l rwLock) *mapCache[K, V] {
return &mapCache[K, V]{m: map[K]V{}, l: l}
}

type mapCache[K comparable, V any] struct {
Expand Down
38 changes: 36 additions & 2 deletions platform/common/utils/collections/iterators.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"math/rand"

"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils"
"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/functions"
)

type baseIterator[k any] interface {
Expand Down Expand Up @@ -68,6 +69,18 @@ func ReadAll[T any](it Iterator[*T]) ([]T, error) {
return items, nil
}

func ToSlice[T any](it Iterator[*T]) ([]*T, error) {
defer it.Close()
items := make([]*T, 0)
for item, err := it.Next(); item != nil || err != nil; item, err = it.Next() {
if err != nil {
return nil, err
}
items = append(items, item)
}
return items, nil
}

func NewSingleIterator[T any](item T) *sliceIterator[T] {
return NewSliceIterator[T]([]T{item})
}
Expand Down Expand Up @@ -123,13 +136,34 @@ func (it *permutationIterator[T]) Close() {
it.perm = nil
}

func Map[A any, B any](iterator Iterator[A], transformer func(A) (B, error)) Iterator[B] {
func Filter[A any](iterator Iterator[A], filter functions.Filter[A]) Iterator[A] {
return &filteredIterator[A]{Iterator: iterator, filter: filter}
}

type filteredIterator[A any] struct {
Iterator[A]
filter func(A) bool
}

func (it *filteredIterator[A]) Next() (A, error) {
if next, err := it.Iterator.Next(); err != nil {
return next, err
} else if utils.IsNil(next) {
return utils.Zero[A](), nil
} else if !it.filter(next) {
return it.Next()
} else {
return next, nil
}
}

func Map[A any, B any](iterator Iterator[A], transformer functions.Mapper[A, B]) Iterator[B] {
return &mappedIterator[A, B]{Iterator: iterator, transformer: transformer}
}

type mappedIterator[A any, B any] struct {
Iterator[A]
transformer func(A) (B, error)
transformer functions.Mapper[A, B]
}

func (it *mappedIterator[A, B]) Next() (B, error) {
Expand Down
15 changes: 15 additions & 0 deletions platform/common/utils/functions/functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
Copyright IBM Corp. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/

package functions

type Mapper[A any, B any] func(A) (B, error)

type Filter[T any] func(T) bool

func Not[T any](f Filter[T]) Filter[T] {
return func(t T) bool { return !f(t) }
}
20 changes: 20 additions & 0 deletions platform/common/utils/nulls.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ SPDX-License-Identifier: Apache-2.0

package utils

import "reflect"

func Zero[A any]() A {
var a A
return a
Expand Down Expand Up @@ -54,3 +56,21 @@ func DefaultString(a interface{}, b string) string {
}
return b
}

var nullableTypes = map[reflect.Kind]struct{}{
reflect.Ptr: {},
reflect.Map: {},
reflect.Chan: {},
reflect.Func: {},
reflect.Slice: {},
reflect.Interface: {},
}

func isNullable(k reflect.Kind) bool {
_, ok := nullableTypes[k]
return ok
}

func IsNil(a any) bool {
return a == nil || isNullable(reflect.TypeOf(a).Kind()) && reflect.ValueOf(a).IsNil()
}
12 changes: 5 additions & 7 deletions platform/view/services/db/driver/sql/common/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"database/sql"
"fmt"

"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/collections"
"github.com/hyperledger-labs/fabric-smart-client/platform/view/services/db/driver"
"github.com/hyperledger-labs/fabric-smart-client/platform/view/view"
"github.com/pkg/errors"
Expand Down Expand Up @@ -53,13 +54,10 @@ func (db *BindingPersistence) HaveSameBinding(this, that view.Identity) (bool, e
}
defer rows.Close()

longTermIds := make([]view.Identity, 0, 2)
for rows.Next() {
var longTerm view.Identity
if err := rows.Scan(&longTerm); err != nil {
return false, err
}
longTermIds = append(longTermIds, longTerm)
it := QueryIterator(rows, func(r RowScanner, id *view.Identity) error { return r.Scan(id) })
longTermIds, err := collections.ReadAll(it)
if err != nil {
return false, err
}
if len(longTermIds) != 2 {
return false, errors.Errorf("%d entries found instead of 2", len(longTermIds))
Expand Down
46 changes: 16 additions & 30 deletions platform/view/services/db/driver/sql/common/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,13 +500,10 @@ func TTestCompositeKeys(t *testing.T, db driver.UnversionedPersistence) {

itr, err := db.GetStateRangeScanIterator(ns, startKey, endKey)
assert.NoError(t, err)
defer itr.Close()

res := make([]driver.UnversionedRead, 0, 4)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}
res, err := collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 4)
assert.Equal(t, []driver.UnversionedRead{
{Key: "\x00prefix0a0b0", Raw: []uint8{0x0, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x30, 0x61, 0x30, 0x62, 0x30}},
Expand All @@ -522,13 +519,10 @@ func TTestCompositeKeys(t *testing.T, db driver.UnversionedPersistence) {

itr, err = db.GetStateRangeScanIterator(ns, startKey, endKey)
assert.NoError(t, err)
defer itr.Close()

res = make([]driver.UnversionedRead, 0, 2)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}
res, err = collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 3)
assert.Equal(t, []driver.UnversionedRead{
{Key: "\x00prefix0a0b0", Raw: []uint8{0x0, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x30, 0x61, 0x30, 0x62, 0x30}},
Expand Down Expand Up @@ -623,13 +617,10 @@ func TTestUnversionedRange(t *testing.T, db driver.UnversionedPersistence) {

itr, err := db.GetStateRangeScanIterator(ns, "", "")
assert.NoError(t, err)
defer itr.Close()

res := make([]driver.UnversionedRead, 0, 4)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}
res, err := collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 4)
assert.Equal(t, []driver.UnversionedRead{
{Key: "k1", Raw: []byte("k1_value")},
Expand All @@ -640,13 +631,10 @@ func TTestUnversionedRange(t *testing.T, db driver.UnversionedPersistence) {

itr, err = db.GetStateRangeScanIterator(ns, "k1", "k3")
assert.NoError(t, err)
defer itr.Close()

res = make([]driver.UnversionedRead, 0, 3)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}
res, err = collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 3)
assert.Equal(t, []driver.UnversionedRead{
{Key: "k1", Raw: []byte("k1_value")},
Expand All @@ -656,12 +644,10 @@ func TTestUnversionedRange(t *testing.T, db driver.UnversionedPersistence) {

itr, err = db.GetStateSetIterator(ns, "k1", "k2")
assert.NoError(t, err)
defer itr.Close()
res = make([]driver.UnversionedRead, 0, 2)
for n, err := itr.Next(); n != nil; n, err = itr.Next() {
assert.NoError(t, err)
res = append(res, *n)
}

res, err = collections.ReadAll(itr)
assert.NoError(t, err)

assert.Len(t, res, 2)
expected := []driver.UnversionedRead{
{Key: "k1", Raw: []byte("k1_value")},
Expand Down
20 changes: 12 additions & 8 deletions platform/view/services/db/driver/sql/common/signerinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"database/sql"
"fmt"

"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/collections"
"github.com/hyperledger-labs/fabric-smart-client/platform/view/services/db/driver"
"github.com/hyperledger-labs/fabric-smart-client/platform/view/view"
"github.com/pkg/errors"
Expand All @@ -35,11 +36,11 @@ type SignerInfoPersistence struct {

func (db *SignerInfoPersistence) FilterExistingSigners(ids ...view.Identity) ([]view.Identity, error) {
idHashes := make([]string, len(ids))
inverseMap := make(map[string]view.Identity, len(ids))
inverseMap := make(map[string]*view.Identity, len(ids))
for i, id := range ids {
idHash := id.UniqueID()
idHashes[i] = idHash
inverseMap[idHash] = id
inverseMap[idHash] = &id
}
where, params := Where(db.ci.InStrings("id", idHashes))
query := fmt.Sprintf("SELECT id FROM %s %s", db.table, where)
Expand All @@ -51,13 +52,16 @@ func (db *SignerInfoPersistence) FilterExistingSigners(ids ...view.Identity) ([]
}
defer rows.Close()

existingSigners := make([]view.Identity, 0)
for rows.Next() {
var idHash string
if err := rows.Scan(&idHash); err != nil {
return nil, errors.Wrapf(err, "failed scanning row")
idHashItr := QueryIterator(rows, func(r RowScanner, h *string) error { return r.Scan(h) })
idItr := collections.Map(idHashItr, func(h *string) (*view.Identity, error) {
if h == nil {
return nil, nil
}
existingSigners = append(existingSigners, inverseMap[idHash])
return inverseMap[*h], nil
})
existingSigners, err := collections.ReadAll[view.Identity](idItr)
if err != nil {
return nil, err
}
logger.Debugf("Found %d out of %d signers", len(existingSigners), len(ids))
return existingSigners, nil
Expand Down
21 changes: 2 additions & 19 deletions platform/view/services/db/driver/sql/common/unversioned.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (db *UnversionedPersistence) GetStateRangeScanIterator(ns driver2.Namespace
return nil, errors2.Wrapf(err, "query error: %s", query)
}

return &readIterator{txs: rows}, nil
return QueryIterator(rows, func(r RowScanner, read *driver.UnversionedRead) error { return r.Scan(&read.Key, &read.Raw) }), nil
}

func (db *UnversionedPersistence) GetState(namespace driver2.Namespace, key driver2.PKey) (driver.UnversionedValue, error) {
Expand All @@ -87,7 +87,7 @@ func (db *UnversionedPersistence) GetStateSetIterator(ns driver2.Namespace, keys
return nil, errors2.Wrapf(err, "query error: %s", query)
}

return &readIterator{txs: rows}, nil
return QueryIterator(rows, func(r RowScanner, read *driver.UnversionedRead) error { return r.Scan(&read.Key, &read.Raw) }), nil
}

func (db *UnversionedPersistence) Close() error {
Expand Down Expand Up @@ -231,23 +231,6 @@ func (db *UnversionedPersistence) Stats() any {
return nil
}

type readIterator struct {
txs *sql.Rows
}

func (t *readIterator) Close() {
t.txs.Close()
}

func (t *readIterator) Next() (*driver.UnversionedRead, error) {
if !t.txs.Next() {
return nil, nil
}
var r driver.UnversionedRead
err := t.txs.Scan(&r.Key, &r.Raw)
return &r, err
}

func (db *UnversionedPersistence) CreateSchema() error {
return InitSchema(db.writeDB, fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
Expand Down
Loading

0 comments on commit be2854b

Please sign in to comment.