From 70739393de8ca11bbc8b3465a72c935f3bf26ffc Mon Sep 17 00:00:00 2001 From: ilija Date: Wed, 22 Jan 2025 11:48:44 +0100 Subject: [PATCH 1/8] Fix LP codec usage and associated codec cleanup --- pkg/solana/codec/codec_entry.go | 45 ++++--- pkg/solana/codec/codec_test.go | 20 +++- pkg/solana/codec/discriminator.go | 16 ++- pkg/solana/codec/discriminator_test.go | 18 +-- pkg/solana/codec/solana.go | 61 +--------- pkg/solana/codec/testutils/types.go | 150 +++++++++++++++++++++++- pkg/solana/logpoller/discriminator.go | 14 --- pkg/solana/logpoller/filters.go | 49 ++++---- pkg/solana/logpoller/log_poller.go | 15 +-- pkg/solana/logpoller/log_poller_test.go | 14 ++- pkg/solana/logpoller/models.go | 10 +- pkg/solana/logpoller/test_helpers.go | 4 +- pkg/solana/logpoller/types.go | 9 +- pkg/solana/logpoller/types_test.go | 4 +- 14 files changed, 271 insertions(+), 158 deletions(-) delete mode 100644 pkg/solana/logpoller/discriminator.go diff --git a/pkg/solana/codec/codec_entry.go b/pkg/solana/codec/codec_entry.go index d3b459d57..cbcdab45f 100644 --- a/pkg/solana/codec/codec_entry.go +++ b/pkg/solana/codec/codec_entry.go @@ -44,11 +44,16 @@ func NewAccountEntry(offchainName string, idlTypes AccountIDLTypes, includeDiscr return nil, err } + var discriminator *Discriminator + if includeDiscriminator { + discriminator = NewDiscriminator(idlTypes.Account.Name, true) + } + return newEntry( offchainName, idlTypes.Account.Name, accCodec, - includeDiscriminator, + discriminator, mod, ), nil } @@ -64,7 +69,7 @@ func NewPDAEntry(offchainName string, pdaTypeDef PDATypeDef, mod codec.Modifier, offchainName, offchainName, // PDA seeds do not correlate to anything on-chain so reusing offchain name accCodec, - false, + nil, mod, ), nil } @@ -85,7 +90,7 @@ func NewInstructionArgsEntry(offChainName string, idlTypes InstructionArgsIDLTyp idlTypes.Instruction.Name, instructionCodecArgs, // Instruction arguments don't need a discriminator by default - false, + nil, mod, ), nil } @@ -101,11 +106,16 @@ func NewEventArgsEntry(offChainName string, idlTypes EventIDLTypes, includeDiscr return nil, err } + var discriminator *Discriminator + if includeDiscriminator { + discriminator = NewDiscriminator(idlTypes.Event.Name, false) + } + return newEntry( offChainName, idlTypes.Event.Name, eventCodec, - includeDiscriminator, + discriminator, mod, ), nil } @@ -113,18 +123,23 @@ func NewEventArgsEntry(offChainName string, idlTypes EventIDLTypes, includeDiscr func newEntry( genericName, chainSpecificName string, typeCodec commonencodings.TypeCodec, - includeDiscriminator bool, + discriminator *Discriminator, mod codec.Modifier, ) Entry { - return &entry{ - genericName: genericName, - chainSpecificName: chainSpecificName, - reflectType: typeCodec.GetType(), - typeCodec: typeCodec, - mod: ensureModifier(mod), - includeDiscriminator: includeDiscriminator, - discriminator: *NewDiscriminator(chainSpecificName), + e := &entry{ + genericName: genericName, + chainSpecificName: chainSpecificName, + reflectType: typeCodec.GetType(), + typeCodec: typeCodec, + mod: ensureModifier(mod), } + + if discriminator != nil { + e.discriminator = *discriminator + e.includeDiscriminator = true + } + + return e } func createRefs(idlTypes IdlTypeDefSlice, builder commonencodings.Builder) *codecRefs { @@ -175,8 +190,8 @@ func (e *entry) Decode(encoded []byte) (any, []byte, error) { } if !bytes.Equal(e.discriminator.hashPrefix, encoded[:discriminatorLength]) { - return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v for genericName: %q, chainSpecificName: %q", - commontypes.ErrInvalidType, encoded[:discriminatorLength], e.genericName, e.chainSpecificName) + return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v, expected %v, for genericName: %q, chainSpecificName: %q", + commontypes.ErrInvalidType, encoded[:discriminatorLength], e.discriminator.hashPrefix, e.genericName, e.chainSpecificName) } encoded = encoded[discriminatorLength:] diff --git a/pkg/solana/codec/codec_test.go b/pkg/solana/codec/codec_test.go index 50b0366ac..39690e57a 100644 --- a/pkg/solana/codec/codec_test.go +++ b/pkg/solana/codec/codec_test.go @@ -76,18 +76,26 @@ func (it *codecInterfaceTester) GetAccountString(i int) string { } func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeRequest) []byte { - if request.TestOn == TestItemType || request.TestOn == testutils.TestEventItem { - return encodeFieldsOnItem(t, request) + if request.TestOn == TestItemType { + return encodeFieldsOnItem(t, request, true) + } else if request.TestOn == testutils.TestEventItem { + return encodeFieldsOnItem(t, request, false) } return encodeFieldsOnSliceOrArray(t, request) } -func encodeFieldsOnItem(t *testing.T, request *EncodeRequest) ocr2types.Report { +func encodeFieldsOnItem(t *testing.T, request *EncodeRequest, isAccount bool) ocr2types.Report { buf := new(bytes.Buffer) - // The underlying TestItemAsAccount adds a discriminator by default while being Borsh encoded. - if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil { - require.NoError(t, err) + // The underlying TestItem adds a discriminator by default while being Borsh encoded. + if isAccount { + if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil { + require.NoError(t, err) + } + } else { + if err := testutils.EncodeRequestToTestItemAsEvent(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil { + require.NoError(t, err) + } } return buf.Bytes() } diff --git a/pkg/solana/codec/discriminator.go b/pkg/solana/codec/discriminator.go index 9bc363ae7..f63a61739 100644 --- a/pkg/solana/codec/discriminator.go +++ b/pkg/solana/codec/discriminator.go @@ -12,9 +12,19 @@ import ( const discriminatorLength = 8 -func NewDiscriminator(name string) *Discriminator { - sum := sha256.Sum256([]byte("account:" + name)) - return &Discriminator{hashPrefix: sum[:discriminatorLength]} +func NewDiscriminator(name string, isAccount bool) *Discriminator { + return &Discriminator{hashPrefix: NewDiscriminatorHashPrefix(name, isAccount)} +} + +func NewDiscriminatorHashPrefix(name string, isAccount bool) []byte { + var sum [32]byte + if isAccount { + sum = sha256.Sum256([]byte("account:" + name)) + } else { + sum = sha256.Sum256([]byte("event:" + name)) + } + + return sum[:discriminatorLength] } type Discriminator struct { diff --git a/pkg/solana/codec/discriminator_test.go b/pkg/solana/codec/discriminator_test.go index 8a3ba95b9..7985ca0bd 100644 --- a/pkg/solana/codec/discriminator_test.go +++ b/pkg/solana/codec/discriminator_test.go @@ -17,7 +17,7 @@ func TestDiscriminator(t *testing.T) { t.Run("encode and decode return the discriminator", func(t *testing.T) { tmp := sha256.Sum256([]byte("account:Foo")) expected := tmp[:8] - c := codec.NewDiscriminator("Foo") + c := codec.NewDiscriminator("Foo", true) encoded, err := c.Encode(&expected, nil) require.NoError(t, err) require.Equal(t, expected, encoded) @@ -28,7 +28,7 @@ func TestDiscriminator(t *testing.T) { }) t.Run("encode returns an error if the discriminator is invalid", func(t *testing.T) { - c := codec.NewDiscriminator("Foo") + c := codec.NewDiscriminator("Foo", true) _, err := c.Encode(&[]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil) require.True(t, errors.Is(err, types.ErrInvalidType)) }) @@ -36,7 +36,7 @@ func TestDiscriminator(t *testing.T) { t.Run("encode injects the discriminator if it's not provided", func(t *testing.T) { tmp := sha256.Sum256([]byte("account:Foo")) expected := tmp[:8] - c := codec.NewDiscriminator("Foo") + c := codec.NewDiscriminator("Foo", true) encoded, err := c.Encode(nil, nil) require.NoError(t, err) require.Equal(t, expected, encoded) @@ -46,37 +46,37 @@ func TestDiscriminator(t *testing.T) { }) t.Run("decode returns an error if the encoded value is too short", func(t *testing.T) { - c := codec.NewDiscriminator("Foo") + c := codec.NewDiscriminator("Foo", true) _, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) require.True(t, errors.Is(err, types.ErrInvalidEncoding)) }) t.Run("decode returns an error if the discriminator is invalid", func(t *testing.T) { - c := codec.NewDiscriminator("Foo") + c := codec.NewDiscriminator("Foo", true) _, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}) require.True(t, errors.Is(err, types.ErrInvalidEncoding)) }) t.Run("encode returns an error if the value is not a byte slice", func(t *testing.T) { - c := codec.NewDiscriminator("Foo") + c := codec.NewDiscriminator("Foo", true) _, err := c.Encode(42, nil) require.True(t, errors.Is(err, types.ErrInvalidType)) }) t.Run("GetType returns the type of the discriminator", func(t *testing.T) { - c := codec.NewDiscriminator("Foo") + c := codec.NewDiscriminator("Foo", true) require.Equal(t, reflect.TypeOf(&[]byte{}), c.GetType()) }) t.Run("Size returns the length of the discriminator", func(t *testing.T) { - c := codec.NewDiscriminator("Foo") + c := codec.NewDiscriminator("Foo", true) size, err := c.Size(0) require.NoError(t, err) require.Equal(t, 8, size) }) t.Run("FixedSize returns the length of the discriminator", func(t *testing.T) { - c := codec.NewDiscriminator("Foo") + c := codec.NewDiscriminator("Foo", true) size, err := c.FixedSize() require.NoError(t, err) require.Equal(t, 8, size) diff --git a/pkg/solana/codec/solana.go b/pkg/solana/codec/solana.go index 3323ced10..ce9e277db 100644 --- a/pkg/solana/codec/solana.go +++ b/pkg/solana/codec/solana.go @@ -148,32 +148,11 @@ func WrapItemType(forEncoding bool, contractName, itemType string, readType Chai return fmt.Sprintf("output.%s.%s.%s", readType, contractName, itemType) } -// NewIDLAccountCodec is for Anchor custom types +// TODO Deprecate and remove this. func NewIDLAccountCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) { return newIDLCoded(idl, builder, idl.Accounts, true) } -func NewIDLInstructionsCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) { - typeCodecs := make(commonencodings.LenientCodecFromTypeCodec) - refs := &codecRefs{ - builder: builder, - codecs: make(map[string]commonencodings.TypeCodec), - typeDefs: idl.Types, - dependencies: make(map[string][]string), - } - - for _, instruction := range idl.Instructions { - name, instCodec, err := asStruct(instruction.Args, refs, instruction.Name, false, false) - if err != nil { - return nil, err - } - - typeCodecs[name] = instCodec - } - - return typeCodecs, nil -} - func NewNamedModifierCodec(original commontypes.RemoteCodec, itemType string, modifier commoncodec.Modifier) (commontypes.RemoteCodec, error) { mod, err := commoncodec.NewByItemTypeModifier(map[string]commoncodec.Modifier{itemType: modifier}) if err != nil { @@ -190,6 +169,7 @@ func NewNamedModifierCodec(original commontypes.RemoteCodec, itemType string, mo return modCodec, err } +// TODO Deprecate and remove this. func NewIDLDefinedTypesCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) { return newIDLCoded(idl, builder, idl.Types, false) } @@ -254,6 +234,7 @@ type codecRefs struct { func createCodecType( def IdlTypeDef, refs *codecRefs, + // TODO Deprecated includeDiscriminator is not needed here after NewIDLAccountCodec gets cleaned up includeDiscriminator bool, ) (string, commonencodings.TypeCodec, error) { name := def.Name @@ -275,6 +256,7 @@ func asStruct( fields []IdlField, refs *codecRefs, name string, // name is the struct name and can be used in dependency checks + // TODO Deprecated includeDiscriminator is not needed here after NewIDLAccountCodec gets cleaned up includeDiscriminator bool, isInstructionArgs bool, ) (string, commonencodings.TypeCodec, error) { @@ -286,7 +268,7 @@ func asStruct( named := make([]commonencodings.NamedTypeCodec, len(fields)+desLen) if includeDiscriminator { - named[0] = commonencodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name)} + named[0] = commonencodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name, true)} } for idx, field := range fields { @@ -491,36 +473,3 @@ func saveDependency(refs *codecRefs, parent, child string) { refs.dependencies[parent] = append(deps, child) } -func NewIDLEventCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) { - typeCodecs := make(commonencodings.LenientCodecFromTypeCodec) - refs := &codecRefs{ - builder: builder, - codecs: make(map[string]commonencodings.TypeCodec), - typeDefs: idl.Types, - dependencies: make(map[string][]string), - } - - for _, event := range idl.Events { - name, instCodec, err := asStruct(eventFieldsAsStandardFields(event.Fields), refs, event.Name, false, false) - if err != nil { - return nil, err - } - - typeCodecs[name] = instCodec - } - - return typeCodecs, nil -} - -func eventFieldsAsStandardFields(event []IdlEventField) []IdlField { - output := make([]IdlField, len(event)) - - for idx := range output { - output[idx] = IdlField{ - Name: event[idx].Name, - Type: event[idx].Type, - } - } - - return output -} diff --git a/pkg/solana/codec/testutils/types.go b/pkg/solana/codec/testutils/types.go index 3c52adb0f..4a61b24ca 100644 --- a/pkg/solana/codec/testutils/types.go +++ b/pkg/solana/codec/testutils/types.go @@ -170,11 +170,11 @@ type TestItemAsAccount struct { NestedStaticStruct NestedStatic } -var TestItemDiscriminator = [8]byte{148, 105, 105, 155, 26, 167, 212, 149} +var TestItemAsAccountDiscriminator = [8]byte{148, 105, 105, 155, 26, 167, 212, 149} func (obj TestItemAsAccount) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { // Write account discriminator: - err = encoder.WriteBytes(TestItemDiscriminator[:], false) + err = encoder.WriteBytes(TestItemAsAccountDiscriminator[:], false) if err != nil { return err } @@ -233,7 +233,7 @@ func (obj *TestItemAsAccount) UnmarshalWithDecoder(decoder *agbinary.Decoder) er if err != nil { return err } - if !discriminator.Equal(TestItemDiscriminator[:]) { + if !discriminator.Equal(TestItemAsAccountDiscriminator[:]) { return fmt.Errorf( "wrong discriminator: wanted %s, got %s", "[148 105 105 155 26 167 212 149]", @@ -288,6 +288,136 @@ func (obj *TestItemAsAccount) UnmarshalWithDecoder(decoder *agbinary.Decoder) er return nil } +type TestItemAsEvent struct { + Field int32 + OracleID uint8 + OracleIDs [32]uint8 + AccountStruct AccountStruct + Accounts []solana.PublicKey + DifferentField string + BigField agbinary.Int128 + NestedDynamicStruct NestedDynamic + NestedStaticStruct NestedStatic +} + +var TestItemAsEventDiscriminator = [8]byte{119, 183, 160, 247, 84, 104, 222, 251} + +func (obj TestItemAsEvent) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { + // Write event discriminator: + err = encoder.WriteBytes(TestItemAsEventDiscriminator[:], false) + if err != nil { + return err + } + // Serialize `Field` param: + err = encoder.Encode(obj.Field) + if err != nil { + return err + } + // Serialize `OracleID` param: + err = encoder.Encode(obj.OracleID) + if err != nil { + return err + } + // Serialize `OracleIDs` param: + err = encoder.Encode(obj.OracleIDs) + if err != nil { + return err + } + // Serialize `AccountStruct` param: + err = encoder.Encode(obj.AccountStruct) + if err != nil { + return err + } + // Serialize `Accounts` param: + err = encoder.Encode(obj.Accounts) + if err != nil { + return err + } + // Serialize `DifferentField` param: + err = encoder.Encode(obj.DifferentField) + if err != nil { + return err + } + // Serialize `BigField` param: + err = encoder.Encode(obj.BigField) + if err != nil { + return err + } + // Serialize `NestedDynamicStruct` param: + err = encoder.Encode(obj.NestedDynamicStruct) + if err != nil { + return err + } + // Serialize `NestedStaticStruct` param: + err = encoder.Encode(obj.NestedStaticStruct) + if err != nil { + return err + } + return nil +} + +func (obj *TestItemAsEvent) UnmarshalWithDecoder(decoder *agbinary.Decoder) error { + // Read and check account discriminator: + { + discriminator, err := decoder.ReadTypeID() + if err != nil { + return err + } + if !discriminator.Equal(TestItemAsEventDiscriminator[:]) { + return fmt.Errorf( + "wrong discriminator: wanted %s, got %s", + "[119, 183, 160, 247, 84, 104, 222, 251]", + fmt.Sprint(discriminator[:])) + } + } + // Deserialize `Field`: + err := decoder.Decode(&obj.Field) + if err != nil { + return err + } + // Deserialize `OracleID`: + err = decoder.Decode(&obj.OracleID) + if err != nil { + return err + } + // Deserialize `OracleIDs`: + err = decoder.Decode(&obj.OracleIDs) + if err != nil { + return err + } + // Deserialize `AccountStruct`: + err = decoder.Decode(&obj.AccountStruct) + if err != nil { + return err + } + // Deserialize `Accounts`: + err = decoder.Decode(&obj.Accounts) + if err != nil { + return err + } + // Deserialize `DifferentField`: + err = decoder.Decode(&obj.DifferentField) + if err != nil { + return err + } + // Deserialize `BigField`: + err = decoder.Decode(&obj.BigField) + if err != nil { + return err + } + // Deserialize `NestedDynamicStruct`: + err = decoder.Decode(&obj.NestedDynamicStruct) + if err != nil { + return err + } + // Deserialize `NestedStaticStruct`: + err = decoder.Decode(&obj.NestedStaticStruct) + if err != nil { + return err + } + return nil +} + type TestItemAsArgs struct { Field int32 OracleID uint8 @@ -577,6 +707,20 @@ func EncodeRequestToTestItemAsAccount(testStruct interfacetests.TestStruct) Test } } +func EncodeRequestToTestItemAsEvent(testStruct interfacetests.TestStruct) TestItemAsEvent { + return TestItemAsEvent{ + Field: *testStruct.Field, + OracleID: uint8(testStruct.OracleID), + OracleIDs: getOracleIDs(testStruct), + AccountStruct: getAccountStruct(testStruct), + Accounts: getAccounts(testStruct), + DifferentField: testStruct.DifferentField, + BigField: bigIntToBinInt128(testStruct.BigField), + NestedDynamicStruct: getNestedDynamic(testStruct), + NestedStaticStruct: getNestedStatic(testStruct), + } +} + func EncodeRequestToTestItemAsArgs(testStruct interfacetests.TestStruct) TestItemAsArgs { return TestItemAsArgs{ Field: *testStruct.Field, diff --git a/pkg/solana/logpoller/discriminator.go b/pkg/solana/logpoller/discriminator.go deleted file mode 100644 index 812057a1c..000000000 --- a/pkg/solana/logpoller/discriminator.go +++ /dev/null @@ -1,14 +0,0 @@ -package logpoller - -import ( - "crypto/sha256" - "fmt" -) - -const DiscriminatorLength = 8 - -func Discriminator(namespace, name string) [DiscriminatorLength]byte { - h := sha256.New() - h.Write([]byte(fmt.Sprintf("%s:%s", namespace, name))) - return [DiscriminatorLength]byte(h.Sum(nil)[:DiscriminatorLength]) -} diff --git a/pkg/solana/logpoller/filters.go b/pkg/solana/logpoller/filters.go index dc1d52252..499adceee 100644 --- a/pkg/solana/logpoller/filters.go +++ b/pkg/solana/logpoller/filters.go @@ -114,22 +114,25 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error { fl.removeFilterFromIndexes(*existingFilter) } - filterID, err := fl.orm.InsertFilter(ctx, filter) + cEntry, err := codec.NewEventArgsEntry(filter.EventName, filter.EventIdl.EventIDLTypes, true, nil, binary.LittleEndian()) if err != nil { - return fmt.Errorf("failed to insert filter: %w", err) + return err } - filter.ID = filterID - - idl := codec.IDL{ - Events: []codec.IdlEvent{filter.EventIdl.IdlEvent}, - Types: filter.EventIdl.IdlTypeDefSlice, - } - fl.decoders[filter.ID], err = codec.NewIDLEventCodec(idl, binary.LittleEndian()) + decoderTypes := codec.ParsedTypes{DecoderDefs: map[string]codec.Entry{filter.EventName: cEntry}} + decoder, err := decoderTypes.ToCodec() if err != nil { return fmt.Errorf("failed to create event decoder: %w", err) } + filterID, err := fl.orm.InsertFilter(ctx, filter) + if err != nil { + return fmt.Errorf("failed to insert filter: %w", err) + } + + filter.ID = filterID + + fl.decoders[filter.ID] = decoder fl.filtersByName[filter.Name] = filter.ID fl.filtersByID[filter.ID] = &filter filtersForAddress, ok := fl.filtersByAddress[filter.Address] @@ -151,9 +154,7 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error { programID := filter.Address.ToSolana().String() fl.knownPrograms[programID]++ - - discriminatorHead := filter.Discriminator()[:10] - fl.knownDiscriminators[discriminatorHead]++ + fl.knownDiscriminators[filter.Discriminator()]++ return nil } @@ -229,7 +230,7 @@ func (fl *filters) removeFilterFromIndexes(filter Filter) { } } - discriminatorHead := filter.Discriminator()[:10] + discriminatorHead := filter.Discriminator() if refcount, ok := fl.knownDiscriminators[discriminatorHead]; ok { refcount-- if refcount > 0 { @@ -302,6 +303,15 @@ func (fl *filters) MatchingFiltersForEncodedEvent(event ProgramEvent) iter.Seq[F if len(event.Data) < 12 { return nil } + + discriminator, err := base64.StdEncoding.DecodeString(event.Data[:12]) + if err != nil { + fl.lggr.Errorw("failed to decode event discriminator", "event", event, "err", err) + return nil + } + + discriminator = discriminator[:8] + isKnown := func() (ok bool) { fl.filtersMutex.RLock() defer fl.filtersMutex.RUnlock() @@ -315,7 +325,7 @@ func (fl *filters) MatchingFiltersForEncodedEvent(event ProgramEvent) iter.Seq[F // discriminators if the first 10 characters don't match. If it passes that initial test, we base64-decode the // first 12 characters, and use the first 8 bytes of that as the event sig to call MatchingFilters. The address // also needs to be base58-decoded to pass to MatchingFilters - _, ok = fl.knownDiscriminators[event.Data[:10]] + _, ok = fl.knownDiscriminators[base64.StdEncoding.EncodeToString(discriminator)] return ok } @@ -329,16 +339,7 @@ func (fl *filters) MatchingFiltersForEncodedEvent(event ProgramEvent) iter.Seq[F return nil } - // Decoding first 12 characters will give us the first 9 bytes of binary data - // The first 8 of those is the discriminator - decoded, err := base64.StdEncoding.DecodeString(event.Data[:12]) - if err != nil || len(decoded) < 8 { - fl.lggr.Errorw("failed to decode event data", "EventProgram", event) - return nil - } - eventSig := EventSignature(decoded[:8]) - - return fl.matchingFilters(PublicKey(addr), eventSig) + return fl.matchingFilters(PublicKey(addr), EventSignature(discriminator)) } // GetFiltersToBackfill - returns copy of backfill queue diff --git a/pkg/solana/logpoller/log_poller.go b/pkg/solana/logpoller/log_poller.go index 2ab672eb9..96893ff56 100644 --- a/pkg/solana/logpoller/log_poller.go +++ b/pkg/solana/logpoller/log_poller.go @@ -152,16 +152,17 @@ func (lp *Service) Process(ctx context.Context, programEvent ProgramEvent) (err TxHash: Signature(blockData.TransactionHash), } - eventData, decodeErr := base64.StdEncoding.DecodeString(programEvent.Data) - if decodeErr != nil { - return decodeErr + log.Data, err = base64.StdEncoding.DecodeString(programEvent.Data) + if err != nil { + return err } - if len(eventData) < 8 { - err = fmt.Errorf("Assumption violation: %w, log.Data=%s", ErrMissingDiscriminator, log.Data) + + // TODO isn't discrimintaor already checked above in MatchingFiltersForEncodedEvent? + if len(log.Data) < 8 { + err = fmt.Errorf("assumption violation: %w, log.Data=%s", ErrMissingDiscriminator, log.Data) lp.lggr.Criticalw(err.Error()) return err } - log.Data = eventData[8:] log.SubkeyValues = make([]IndexedValue, 0, len(filter.SubkeyPaths)) for _, path := range filter.SubkeyPaths { @@ -169,7 +170,7 @@ func (lp *Service) Process(ctx context.Context, programEvent ProgramEvent) (err if decodeSubKeyErr != nil { return decodeSubKeyErr } - indexedVal, newIndexedValErr := NewIndexedValue(subKeyVal) + indexedVal, newIndexedValErr := newIndexedValue(subKeyVal) if newIndexedValErr != nil { return newIndexedValErr } diff --git a/pkg/solana/logpoller/log_poller_test.go b/pkg/solana/logpoller/log_poller_test.go index 033f05b19..d87403266 100644 --- a/pkg/solana/logpoller/log_poller_test.go +++ b/pkg/solana/logpoller/log_poller_test.go @@ -250,14 +250,14 @@ func TestProcess(t *testing.T) { addr := newRandomPublicKey(t) eventName := "myEvent" - eventSig := Discriminator("event", eventName) + eventSig := EventSignature(codec.NewDiscriminatorHashPrefix(eventName, false)) event := struct { A int64 B string }{55, "hello"} - subkeyValA, err := NewIndexedValue(event.A) + subkeyValA, err := newIndexedValue(event.A) require.NoError(t, err) - subkeyValB, err := NewIndexedValue(event.B) + subkeyValB, err := newIndexedValue(event.B) require.NoError(t, err) filterID := rand.Int63() @@ -276,6 +276,7 @@ func TestProcess(t *testing.T) { expectedLog.Data, err = bin.MarshalBorsh(&event) require.NoError(t, err) + expectedLog.Data = append(eventSig[:], expectedLog.Data...) ev := ProgramEvent{ Program: addr.ToSolana().String(), BlockData: BlockData{ @@ -287,7 +288,7 @@ func TestProcess(t *testing.T) { TransactionIndex: txIndex, TransactionLogIndex: txLogIndex, }, - Data: base64.StdEncoding.EncodeToString(append(eventSig[:], expectedLog.Data...)), + Data: base64.StdEncoding.EncodeToString(expectedLog.Data), } orm := NewMockORM(t) @@ -304,7 +305,7 @@ func TestProcess(t *testing.T) { require.NoError(t, err) idl := EventIdl{ - codec.IdlEvent{ + EventIDLTypes: codec.EventIDLTypes{Event: codec.IdlEvent{ Name: "myEvent", Fields: []codec.IdlEventField{{ Name: "A", @@ -314,7 +315,8 @@ func TestProcess(t *testing.T) { Type: idlTypeString, }}, }, - []codec.IdlTypeDef{}, + Types: []codec.IdlTypeDef{}, + }, } filter := Filter{ diff --git a/pkg/solana/logpoller/models.go b/pkg/solana/logpoller/models.go index 2fe406d74..820312d67 100644 --- a/pkg/solana/logpoller/models.go +++ b/pkg/solana/logpoller/models.go @@ -2,8 +2,9 @@ package logpoller import ( "encoding/base64" - "fmt" "time" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" ) type Filter struct { @@ -30,12 +31,7 @@ func (f Filter) MatchSameLogs(other Filter) bool { // // This is the base64 encoding of the [8]byte discriminator returned by utils.Discriminator func (f Filter) Discriminator() string { - d := Discriminator("event", f.EventName) - b64encoded := base64.StdEncoding.EncodeToString(d[:]) - if len(b64encoded) != 12 { - panic(fmt.Sprintf("Assumption Violation: expected encoding/base64 to return 12 character base64-encoding, got %d characters", len(b64encoded))) - } - return b64encoded + return base64.StdEncoding.EncodeToString(codec.NewDiscriminatorHashPrefix(f.EventName, false)) } type Log struct { diff --git a/pkg/solana/logpoller/test_helpers.go b/pkg/solana/logpoller/test_helpers.go index 8511ae1ac..8f133bab2 100644 --- a/pkg/solana/logpoller/test_helpers.go +++ b/pkg/solana/logpoller/test_helpers.go @@ -7,6 +7,8 @@ import ( "github.com/gagliardetto/solana-go" "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" ) func newRandomPublicKey(t *testing.T) PublicKey { @@ -36,7 +38,7 @@ func newRandomLog(t *testing.T, filterID int64, chainID string, eventName string BlockNumber: rand.Int63n(1000000), BlockTimestamp: time.Unix(1731590113, 0).UTC(), Address: PublicKey(pubKey), - EventSig: Discriminator("event", eventName), + EventSig: EventSignature(codec.NewDiscriminatorHashPrefix(eventName, false)), SubkeyValues: []IndexedValue{{3, 2, 1}, {1}, {1, 2}, pubKey.Bytes()}, TxHash: Signature(signature), Data: data, diff --git a/pkg/solana/logpoller/types.go b/pkg/solana/logpoller/types.go index 5b75ff9a7..fcd38a312 100644 --- a/pkg/solana/logpoller/types.go +++ b/pkg/solana/logpoller/types.go @@ -127,8 +127,7 @@ type Decoder interface { } type EventIdl struct { - codec.IdlEvent - codec.IdlTypeDefSlice + codec.EventIDLTypes } func (e *EventIdl) Scan(src interface{}) error { @@ -137,8 +136,8 @@ func (e *EventIdl) Scan(src interface{}) error { func (e EventIdl) Value() (driver.Value, error) { return json.Marshal(map[string]any{ - "IdlEvent": e.IdlEvent, - "IdlTypeDefSlice": e.IdlTypeDefSlice, + "IdlEvent": e.EventIDLTypes.Event, + "IdlTypeDefSlice": e.EventIDLTypes.Types, }) } @@ -192,7 +191,7 @@ func (v *IndexedValue) FromFloat64(f float64) { v.FromUint64(math.MaxInt64 + 1 - math.Float64bits(f)) } -func NewIndexedValue(typedVal any) (iVal IndexedValue, err error) { +func newIndexedValue(typedVal any) (iVal IndexedValue, err error) { // handle 2 simplest cases first switch t := typedVal.(type) { case []byte: diff --git a/pkg/solana/logpoller/types_test.go b/pkg/solana/logpoller/types_test.go index b7ed36c6d..499ce53e5 100644 --- a/pkg/solana/logpoller/types_test.go +++ b/pkg/solana/logpoller/types_test.go @@ -35,9 +35,9 @@ func TestIndexedValue(t *testing.T) { } for _, c := range cases { t.Run(c.typeName, func(t *testing.T) { - iVal1, err := NewIndexedValue(c.lower) + iVal1, err := newIndexedValue(c.lower) require.NoError(t, err) - iVal2, err := NewIndexedValue(c.higher) + iVal2, err := newIndexedValue(c.higher) require.NoError(t, err) assert.Less(t, iVal1, iVal2) }) From 0544d8a7afed3b6ea336a6dfd54f791a53cd0f5b Mon Sep 17 00:00:00 2001 From: ilija Date: Thu, 23 Jan 2025 17:08:41 +0100 Subject: [PATCH 2/8] Optimise log poller discriminator comparison --- pkg/solana/codec/discriminator.go | 43 +++++++ .../codec/discriminator_extractor_test.go | 105 ++++++++++++++++++ pkg/solana/logpoller/filters.go | 46 ++++---- pkg/solana/logpoller/models.go | 9 +- 4 files changed, 171 insertions(+), 32 deletions(-) create mode 100644 pkg/solana/codec/discriminator_extractor_test.go diff --git a/pkg/solana/codec/discriminator.go b/pkg/solana/codec/discriminator.go index f63a61739..0e366c257 100644 --- a/pkg/solana/codec/discriminator.go +++ b/pkg/solana/codec/discriminator.go @@ -79,3 +79,46 @@ func (d Discriminator) Size(_ int) (int, error) { func (d Discriminator) FixedSize() (int, error) { return discriminatorLength, nil } + +type DiscriminatorExtractor struct { + b64Index [256]byte +} + +// NewDiscriminatorExtractor is optimised to extract discriminators from base64 encoded strings faster than the base64 lib. +func NewDiscriminatorExtractor() DiscriminatorExtractor { + instance := DiscriminatorExtractor{} + const base64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + for i := 0; i < len(base64Chars); i++ { + instance.b64Index[base64Chars[i]] = byte(i) + } + return instance +} + +// Extract most optimally (around 40% faster than std) decodes the first 8 bytes of a base64 encoded string, which corresponds to a Solana discriminator. +// Extractor expects input of > 12 characters which 8 bytes are extracted from, if the input string is less than 12 characters, this will panic. +// If string contains non-Base64 characters (e.g., !, @, space) map to index 0 (ASCII 'A'), and won't be accurate. +func (e *DiscriminatorExtractor) Extract(data string) []byte { + var decodeBuffer [9]byte + d := decodeBuffer[:9] + s := data[:12] + + // base64 decode + for i := 0; i < 3; i++ { + // decode base64 chars into associated byte + c1 := e.b64Index[s[0]] + c2 := e.b64Index[s[1]] + c3 := e.b64Index[s[2]] + c4 := e.b64Index[s[3]] + + // reconstruct raw bytes + d[0] = (c1 << 2) | (c2 >> 4) + d[1] = (c2 << 4) | (c3 >> 2) + d[2] = (c3 << 6) | c4 + + // next 3 bytes and next 4 characters + d = d[3:] + s = s[4:] + } + + return decodeBuffer[:discriminatorLength] +} diff --git a/pkg/solana/codec/discriminator_extractor_test.go b/pkg/solana/codec/discriminator_extractor_test.go new file mode 100644 index 000000000..e67fa80d2 --- /dev/null +++ b/pkg/solana/codec/discriminator_extractor_test.go @@ -0,0 +1,105 @@ +package codec + +import ( + "encoding/base64" + "fmt" + mathrand "math/rand" + "testing" + + "github.com/stretchr/testify/require" +) + +func FuzzExtractorHappyPath(f *testing.F) { + // Seed with valid base64 discriminators + seeds := []struct { + Data string + }{ + {"SGVbG8gd29ybGQ"}, // "Hello world!" + {"AAAAAAAAAAAA"}, // Zero bytes + {"////////////"}, // Max value bytes + {"QUJDREVGR0hJSktM"}, // "ABCDEFGHIJKL" + } + + for _, seed := range seeds { + f.Add(seed.Data) + } + + extractor := NewDiscriminatorExtractor() + f.Fuzz(func(t *testing.T, testString string) { + if len(testString) < 12 { + t.Fatal(fmt.Sprintf("test string is shorter than 12 %s", testString)) + } + + data := testString[:12] + std, err := base64.StdEncoding.DecodeString(data) + if err != nil { + t.Fatal(fmt.Sprintf("failed to decode test string %s with stdlib", data)) + } + + if len(std) < 8 { + t.Fatal("stdlib decoded < 8 bytes") + } + + require.Equal(t, std[:8], extractor.Extract(data)) + }) +} + +func TestDiscriminatorExtractorBase64Indexes(t *testing.T) { + extractor := NewDiscriminatorExtractor() + const base64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + for i, c := range base64Chars { + if extractor.b64Index[c] != byte(i) { + t.Errorf("incorrect index for character %q: expected %d, got %d", c, i, extractor.b64Index[c]) + } + } +} + +func TestExtractor_Extract_ShortInput(t *testing.T) { + extractor := NewDiscriminatorExtractor() + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for short input, but none occurred") + } + }() + + // Attempt with 11-character string (needs at least 12) + extractor.Extract("short_input") +} + +// Custom extractor is around 40% faster than using stdlib +func BenchmarkDiscriminatorExtraction(b *testing.B) { + generateDiscriminatorDecodeTestData := func(numTestEntries int) []string { + // corresponds to a 12 character base64 encoded string + entrySize := int64(8) + var testData []string + // Create seeded random source + r := mathrand.New(mathrand.NewSource(entrySize)) + for range numTestEntries { + data := make([]byte, entrySize) + _, _ = r.Read(data) + + testData = append(testData, base64.StdEncoding.EncodeToString(data)) + } + + return testData + } + + b.Run("Standard lib Base64", func(b *testing.B) { + testData := generateDiscriminatorDecodeTestData(b.N) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = base64.StdEncoding.DecodeString(testData[i]) + } + }) + + b.Run("CustomExtractor", func(b *testing.B) { + testData := generateDiscriminatorDecodeTestData(b.N) + extractor := NewDiscriminatorExtractor() + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + extractor.Extract(testData[i]) + } + }) +} diff --git a/pkg/solana/logpoller/filters.go b/pkg/solana/logpoller/filters.go index 499adceee..5e98a3612 100644 --- a/pkg/solana/logpoller/filters.go +++ b/pkg/solana/logpoller/filters.go @@ -2,7 +2,6 @@ package logpoller import ( "context" - "encoding/base64" "errors" "fmt" "iter" @@ -24,24 +23,26 @@ type filters struct { orm ORM lggr logger.SugaredLogger - filtersByID map[int64]*Filter - filtersByName map[string]int64 - filtersByAddress map[PublicKey]map[EventSignature]map[int64]struct{} - filtersToBackfill map[int64]struct{} - filtersToDelete map[int64]Filter - filtersMutex sync.RWMutex - loadedFilters atomic.Bool - knownPrograms map[string]uint // fast lookup to see if a base58-encoded ProgramID matches any registered filters - knownDiscriminators map[string]uint // fast lookup by first 10 characters (60-bits) of a base64-encoded discriminator - seqNums map[int64]int64 - decoders map[int64]Decoder + filtersByID map[int64]*Filter + filtersByName map[string]int64 + filtersByAddress map[PublicKey]map[EventSignature]map[int64]struct{} + filtersToBackfill map[int64]struct{} + filtersToDelete map[int64]Filter + filtersMutex sync.RWMutex + loadedFilters atomic.Bool + knownPrograms map[string]uint // fast lookup to see if a base58-encoded ProgramID matches any registered filters + knownDiscriminators map[string]uint // fast lookup by first 10 characters (60-bits) of a base64-encoded discriminator + seqNums map[int64]int64 + decoders map[int64]Decoder + discriminatorExtractor codec.DiscriminatorExtractor } func newFilters(lggr logger.SugaredLogger, orm ORM) *filters { return &filters{ - orm: orm, - lggr: lggr, - decoders: make(map[int64]Decoder), + orm: orm, + lggr: lggr, + decoders: make(map[int64]Decoder), + discriminatorExtractor: codec.NewDiscriminatorExtractor(), } } @@ -154,7 +155,7 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error { programID := filter.Address.ToSolana().String() fl.knownPrograms[programID]++ - fl.knownDiscriminators[filter.Discriminator()]++ + fl.knownDiscriminators[filter.DiscriminatorRawBytes()]++ return nil } @@ -230,7 +231,7 @@ func (fl *filters) removeFilterFromIndexes(filter Filter) { } } - discriminatorHead := filter.Discriminator() + discriminatorHead := filter.DiscriminatorRawBytes() if refcount, ok := fl.knownDiscriminators[discriminatorHead]; ok { refcount-- if refcount > 0 { @@ -304,14 +305,7 @@ func (fl *filters) MatchingFiltersForEncodedEvent(event ProgramEvent) iter.Seq[F return nil } - discriminator, err := base64.StdEncoding.DecodeString(event.Data[:12]) - if err != nil { - fl.lggr.Errorw("failed to decode event discriminator", "event", event, "err", err) - return nil - } - - discriminator = discriminator[:8] - + discriminator := fl.discriminatorExtractor.Extract(event.Data) isKnown := func() (ok bool) { fl.filtersMutex.RLock() defer fl.filtersMutex.RUnlock() @@ -325,7 +319,7 @@ func (fl *filters) MatchingFiltersForEncodedEvent(event ProgramEvent) iter.Seq[F // discriminators if the first 10 characters don't match. If it passes that initial test, we base64-decode the // first 12 characters, and use the first 8 bytes of that as the event sig to call MatchingFilters. The address // also needs to be base58-decoded to pass to MatchingFilters - _, ok = fl.knownDiscriminators[base64.StdEncoding.EncodeToString(discriminator)] + _, ok = fl.knownDiscriminators[string(discriminator)] return ok } diff --git a/pkg/solana/logpoller/models.go b/pkg/solana/logpoller/models.go index 820312d67..e0b470348 100644 --- a/pkg/solana/logpoller/models.go +++ b/pkg/solana/logpoller/models.go @@ -1,7 +1,6 @@ package logpoller import ( - "encoding/base64" "time" "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" @@ -27,11 +26,9 @@ func (f Filter) MatchSameLogs(other Filter) bool { f.EventIdl.Equal(other.EventIdl) && f.SubkeyPaths.Equal(other.SubkeyPaths) } -// Discriminator returns a 12 character base64-encoded string -// -// This is the base64 encoding of the [8]byte discriminator returned by utils.Discriminator -func (f Filter) Discriminator() string { - return base64.StdEncoding.EncodeToString(codec.NewDiscriminatorHashPrefix(f.EventName, false)) +// DiscriminatorRawBytes returns raw discriminator bytes as a string, this string is not base64 encoded. +func (f Filter) DiscriminatorRawBytes() string { + return string(codec.NewDiscriminatorHashPrefix(f.EventName, false)) } type Log struct { From 8b0e6fa3056234c232d457f136e636ad6b26ae63 Mon Sep 17 00:00:00 2001 From: ilija Date: Thu, 23 Jan 2025 17:10:51 +0100 Subject: [PATCH 3/8] Remove unnecessary comment in MatchingFiltersForEncodedEvent --- pkg/solana/logpoller/filters.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pkg/solana/logpoller/filters.go b/pkg/solana/logpoller/filters.go index 5e98a3612..1648863be 100644 --- a/pkg/solana/logpoller/filters.go +++ b/pkg/solana/logpoller/filters.go @@ -314,11 +314,6 @@ func (fl *filters) MatchingFiltersForEncodedEvent(event ProgramEvent) iter.Seq[F return ok } - // The first 64-bits of the event data is the event sig. Because it's base64 encoded, this corresponds to - // the first 10 characters plus 4 bits of the 11th character. We can quickly rule it out as not matching any known - // discriminators if the first 10 characters don't match. If it passes that initial test, we base64-decode the - // first 12 characters, and use the first 8 bytes of that as the event sig to call MatchingFilters. The address - // also needs to be base58-decoded to pass to MatchingFilters _, ok = fl.knownDiscriminators[string(discriminator)] return ok } From a8e09838dd8b25e396889f3fb8ef73c99aabefa3 Mon Sep 17 00:00:00 2001 From: ilija Date: Thu, 23 Jan 2025 17:22:39 +0100 Subject: [PATCH 4/8] lint --- pkg/solana/codec/discriminator_extractor_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pkg/solana/codec/discriminator_extractor_test.go b/pkg/solana/codec/discriminator_extractor_test.go index e67fa80d2..66459d791 100644 --- a/pkg/solana/codec/discriminator_extractor_test.go +++ b/pkg/solana/codec/discriminator_extractor_test.go @@ -2,7 +2,6 @@ package codec import ( "encoding/base64" - "fmt" mathrand "math/rand" "testing" @@ -27,13 +26,13 @@ func FuzzExtractorHappyPath(f *testing.F) { extractor := NewDiscriminatorExtractor() f.Fuzz(func(t *testing.T, testString string) { if len(testString) < 12 { - t.Fatal(fmt.Sprintf("test string is shorter than 12 %s", testString)) + t.Fatalf("test string is shorter than 12 %s", testString) } data := testString[:12] std, err := base64.StdEncoding.DecodeString(data) if err != nil { - t.Fatal(fmt.Sprintf("failed to decode test string %s with stdlib", data)) + t.Fatalf("failed to decode test string %s with stdlib", data) } if len(std) < 8 { From d578e2faf8fe4ab969cedf5c3a6bcea759287d92 Mon Sep 17 00:00:00 2001 From: ilija Date: Thu, 23 Jan 2025 17:34:34 +0100 Subject: [PATCH 5/8] lint discriminator extractor tests --- pkg/solana/codec/discriminator.go | 2 +- .../codec/discriminator_extractor_test.go | 19 +++++-------------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/pkg/solana/codec/discriminator.go b/pkg/solana/codec/discriminator.go index 0e366c257..65b850869 100644 --- a/pkg/solana/codec/discriminator.go +++ b/pkg/solana/codec/discriminator.go @@ -81,7 +81,7 @@ func (d Discriminator) FixedSize() (int, error) { } type DiscriminatorExtractor struct { - b64Index [256]byte + b64Index [128]byte } // NewDiscriminatorExtractor is optimised to extract discriminators from base64 encoded strings faster than the base64 lib. diff --git a/pkg/solana/codec/discriminator_extractor_test.go b/pkg/solana/codec/discriminator_extractor_test.go index 66459d791..f04788fed 100644 --- a/pkg/solana/codec/discriminator_extractor_test.go +++ b/pkg/solana/codec/discriminator_extractor_test.go @@ -13,10 +13,10 @@ func FuzzExtractorHappyPath(f *testing.F) { seeds := []struct { Data string }{ - {"SGVbG8gd29ybGQ"}, // "Hello world!" + {"SGVsbG8gV29ybGQh"}, // Hello world! {"AAAAAAAAAAAA"}, // Zero bytes {"////////////"}, // Max value bytes - {"QUJDREVGR0hJSktM"}, // "ABCDEFGHIJKL" + {"QUJDREVGR0hJSktM"}, // ABCDEFGHIJKL } for _, seed := range seeds { @@ -25,21 +25,12 @@ func FuzzExtractorHappyPath(f *testing.F) { extractor := NewDiscriminatorExtractor() f.Fuzz(func(t *testing.T, testString string) { - if len(testString) < 12 { - t.Fatalf("test string is shorter than 12 %s", testString) - } - - data := testString[:12] - std, err := base64.StdEncoding.DecodeString(data) + stdDecoded, err := base64.StdEncoding.DecodeString(testString) if err != nil { - t.Fatalf("failed to decode test string %s with stdlib", data) - } - - if len(std) < 8 { - t.Fatal("stdlib decoded < 8 bytes") + t.Fatalf("failed to decode test string %s with stdlib, err: %s", testString, err) } - require.Equal(t, std[:8], extractor.Extract(data)) + require.Equal(t, stdDecoded[:8], extractor.Extract(testString)) }) } From 4b3da30fa0d5a4aca5c5cf5273bbed324fdf8d4b Mon Sep 17 00:00:00 2001 From: ilija Date: Thu, 23 Jan 2025 18:20:33 +0100 Subject: [PATCH 6/8] Fix fuzz test for discriminator extractor --- pkg/solana/codec/discriminator.go | 3 ++- .../codec/discriminator_extractor_test.go | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pkg/solana/codec/discriminator.go b/pkg/solana/codec/discriminator.go index 65b850869..0046b1e21 100644 --- a/pkg/solana/codec/discriminator.go +++ b/pkg/solana/codec/discriminator.go @@ -95,7 +95,8 @@ func NewDiscriminatorExtractor() DiscriminatorExtractor { } // Extract most optimally (around 40% faster than std) decodes the first 8 bytes of a base64 encoded string, which corresponds to a Solana discriminator. -// Extractor expects input of > 12 characters which 8 bytes are extracted from, if the input string is less than 12 characters, this will panic. +// Extract expects input of > 12 characters which 8 bytes are extracted from, if the input string is less than 12 characters, this will panic. +// Extract doesn't handle base64 padding because discriminators shouldn't have padding. // If string contains non-Base64 characters (e.g., !, @, space) map to index 0 (ASCII 'A'), and won't be accurate. func (e *DiscriminatorExtractor) Extract(data string) []byte { var decodeBuffer [9]byte diff --git a/pkg/solana/codec/discriminator_extractor_test.go b/pkg/solana/codec/discriminator_extractor_test.go index f04788fed..32426fbee 100644 --- a/pkg/solana/codec/discriminator_extractor_test.go +++ b/pkg/solana/codec/discriminator_extractor_test.go @@ -3,6 +3,7 @@ package codec import ( "encoding/base64" mathrand "math/rand" + "strings" "testing" "github.com/stretchr/testify/require" @@ -25,12 +26,20 @@ func FuzzExtractorHappyPath(f *testing.F) { extractor := NewDiscriminatorExtractor() f.Fuzz(func(t *testing.T, testString string) { - stdDecoded, err := base64.StdEncoding.DecodeString(testString) - if err != nil { - t.Fatalf("failed to decode test string %s with stdlib, err: %s", testString, err) + // Extractor doesn't validate padding, newlines, or tabs + if len(testString) < 12 || + strings.Contains(testString, "\n") || + strings.Contains(testString, "\r") || + strings.Contains(testString, "\t") || + strings.HasSuffix(testString, "=") || + strings.HasSuffix(testString, "==") { + return } - require.Equal(t, stdDecoded[:8], extractor.Extract(testString)) + stdDecoded, err := base64.StdEncoding.DecodeString(testString) + if err == nil { + require.Equal(t, stdDecoded[:8], extractor.Extract(testString)) + } }) } From 98e23a0672ac121afeed5c288e55be7a6dcc3468 Mon Sep 17 00:00:00 2001 From: ilija Date: Fri, 24 Jan 2025 15:26:51 +0100 Subject: [PATCH 7/8] lint and logging --- pkg/solana/logpoller/filters.go | 17 +++++++++-------- pkg/solana/logpoller/log_poller.go | 9 +-------- pkg/solana/logpoller/models.go | 2 +- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/pkg/solana/logpoller/filters.go b/pkg/solana/logpoller/filters.go index 1648863be..975f6853b 100644 --- a/pkg/solana/logpoller/filters.go +++ b/pkg/solana/logpoller/filters.go @@ -31,7 +31,7 @@ type filters struct { filtersMutex sync.RWMutex loadedFilters atomic.Bool knownPrograms map[string]uint // fast lookup to see if a base58-encoded ProgramID matches any registered filters - knownDiscriminators map[string]uint // fast lookup by first 10 characters (60-bits) of a base64-encoded discriminator + knownDiscriminators map[string]uint // fast lookup based on raw discriminator bytes as string seqNums map[int64]int64 decoders map[int64]Decoder discriminatorExtractor codec.DiscriminatorExtractor @@ -231,13 +231,13 @@ func (fl *filters) removeFilterFromIndexes(filter Filter) { } } - discriminatorHead := filter.DiscriminatorRawBytes() - if refcount, ok := fl.knownDiscriminators[discriminatorHead]; ok { + discriminator := filter.DiscriminatorRawBytes() + if refcount, ok := fl.knownDiscriminators[discriminator]; ok { refcount-- if refcount > 0 { - fl.knownDiscriminators[discriminatorHead] = refcount + fl.knownDiscriminators[discriminator] = refcount } else { - delete(fl.knownDiscriminators, discriminatorHead) + delete(fl.knownDiscriminators, discriminator) } } } @@ -449,7 +449,7 @@ func (fl *filters) LoadFilters(ctx context.Context) error { // DecodeSubKey accepts raw Borsh-encoded event data, a filter ID and a subkeyPath. It uses the decoder // associated with that filter to decode the event and extract the subkey value from the specified subKeyPath. // WARNING: not thread safe, should only be called while fl.filtersMutex is held and after filters have been loaded. -func (fl *filters) DecodeSubKey(ctx context.Context, raw []byte, ID int64, subKeyPath []string) (any, error) { +func (fl *filters) DecodeSubKey(ctx context.Context, lggr logger.SugaredLogger, raw []byte, ID int64, subKeyPath []string) (any, error) { filter, ok := fl.filtersByID[ID] if !ok { return nil, fmt.Errorf("filter %d not found", ID) @@ -462,8 +462,9 @@ func (fl *filters) DecodeSubKey(ctx context.Context, raw []byte, ID int64, subKe if err != nil || decodedEvent == nil { return nil, err } - err = decoder.Decode(ctx, raw, decodedEvent, filter.EventName) - if err != nil { + if err = decoder.Decode(ctx, raw, decodedEvent, filter.EventName); err != nil { + err = fmt.Errorf("failed to decode sub key raw data: %v, for filter: %s, for subKeyPath: %v, err: %w", raw, subKeyPath, filter.Name, err) + lggr.Criticalw(err.Error()) return nil, err } return ExtractField(decodedEvent, subKeyPath) diff --git a/pkg/solana/logpoller/log_poller.go b/pkg/solana/logpoller/log_poller.go index 96893ff56..de294ea8b 100644 --- a/pkg/solana/logpoller/log_poller.go +++ b/pkg/solana/logpoller/log_poller.go @@ -157,16 +157,9 @@ func (lp *Service) Process(ctx context.Context, programEvent ProgramEvent) (err return err } - // TODO isn't discrimintaor already checked above in MatchingFiltersForEncodedEvent? - if len(log.Data) < 8 { - err = fmt.Errorf("assumption violation: %w, log.Data=%s", ErrMissingDiscriminator, log.Data) - lp.lggr.Criticalw(err.Error()) - return err - } - log.SubkeyValues = make([]IndexedValue, 0, len(filter.SubkeyPaths)) for _, path := range filter.SubkeyPaths { - subKeyVal, decodeSubKeyErr := lp.filters.DecodeSubKey(ctx, log.Data, filter.ID, path) + subKeyVal, decodeSubKeyErr := lp.filters.DecodeSubKey(ctx, lp.lggr, log.Data, filter.ID, path) if decodeSubKeyErr != nil { return decodeSubKeyErr } diff --git a/pkg/solana/logpoller/models.go b/pkg/solana/logpoller/models.go index e0b470348..0e7768d0f 100644 --- a/pkg/solana/logpoller/models.go +++ b/pkg/solana/logpoller/models.go @@ -26,7 +26,7 @@ func (f Filter) MatchSameLogs(other Filter) bool { f.EventIdl.Equal(other.EventIdl) && f.SubkeyPaths.Equal(other.SubkeyPaths) } -// DiscriminatorRawBytes returns raw discriminator bytes as a string, this string is not base64 encoded. +// DiscriminatorRawBytes returns raw discriminator bytes as a string, this string is not base64 encoded and is always len of discriminator which is 8. func (f Filter) DiscriminatorRawBytes() string { return string(codec.NewDiscriminatorHashPrefix(f.EventName, false)) } From 53174c0ddc74b6323ccf8f1e8802265f78136474 Mon Sep 17 00:00:00 2001 From: ilija Date: Fri, 24 Jan 2025 15:38:15 +0100 Subject: [PATCH 8/8] Change filtersI DecodeSubKey to include logger and run make generate --- pkg/solana/logpoller/log_poller.go | 2 +- pkg/solana/logpoller/mock_filters.go | 30 +++++++++++++++------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/pkg/solana/logpoller/log_poller.go b/pkg/solana/logpoller/log_poller.go index de294ea8b..f7f058eb7 100644 --- a/pkg/solana/logpoller/log_poller.go +++ b/pkg/solana/logpoller/log_poller.go @@ -46,7 +46,7 @@ type filtersI interface { GetFiltersToBackfill() []Filter MarkFilterBackfilled(ctx context.Context, filterID int64) error MatchingFiltersForEncodedEvent(event ProgramEvent) iter.Seq[Filter] - DecodeSubKey(ctx context.Context, raw []byte, ID int64, subKeyPath []string) (any, error) + DecodeSubKey(ctx context.Context, lggr logger.SugaredLogger, raw []byte, ID int64, subKeyPath []string) (any, error) IncrementSeqNum(filterID int64) int64 } diff --git a/pkg/solana/logpoller/mock_filters.go b/pkg/solana/logpoller/mock_filters.go index 29ad41632..96b14446c 100644 --- a/pkg/solana/logpoller/mock_filters.go +++ b/pkg/solana/logpoller/mock_filters.go @@ -6,6 +6,7 @@ import ( context "context" iter "iter" + logger "github.com/smartcontractkit/chainlink-common/pkg/logger" mock "github.com/stretchr/testify/mock" ) @@ -22,9 +23,9 @@ func (_m *mockFilters) EXPECT() *mockFilters_Expecter { return &mockFilters_Expecter{mock: &_m.Mock} } -// DecodeSubKey provides a mock function with given fields: ctx, raw, ID, subKeyPath -func (_m *mockFilters) DecodeSubKey(ctx context.Context, raw []byte, ID int64, subKeyPath []string) (interface{}, error) { - ret := _m.Called(ctx, raw, ID, subKeyPath) +// DecodeSubKey provides a mock function with given fields: ctx, lggr, raw, ID, subKeyPath +func (_m *mockFilters) DecodeSubKey(ctx context.Context, lggr logger.SugaredLogger, raw []byte, ID int64, subKeyPath []string) (interface{}, error) { + ret := _m.Called(ctx, lggr, raw, ID, subKeyPath) if len(ret) == 0 { panic("no return value specified for DecodeSubKey") @@ -32,19 +33,19 @@ func (_m *mockFilters) DecodeSubKey(ctx context.Context, raw []byte, ID int64, s var r0 interface{} var r1 error - if rf, ok := ret.Get(0).(func(context.Context, []byte, int64, []string) (interface{}, error)); ok { - return rf(ctx, raw, ID, subKeyPath) + if rf, ok := ret.Get(0).(func(context.Context, logger.SugaredLogger, []byte, int64, []string) (interface{}, error)); ok { + return rf(ctx, lggr, raw, ID, subKeyPath) } - if rf, ok := ret.Get(0).(func(context.Context, []byte, int64, []string) interface{}); ok { - r0 = rf(ctx, raw, ID, subKeyPath) + if rf, ok := ret.Get(0).(func(context.Context, logger.SugaredLogger, []byte, int64, []string) interface{}); ok { + r0 = rf(ctx, lggr, raw, ID, subKeyPath) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(interface{}) } } - if rf, ok := ret.Get(1).(func(context.Context, []byte, int64, []string) error); ok { - r1 = rf(ctx, raw, ID, subKeyPath) + if rf, ok := ret.Get(1).(func(context.Context, logger.SugaredLogger, []byte, int64, []string) error); ok { + r1 = rf(ctx, lggr, raw, ID, subKeyPath) } else { r1 = ret.Error(1) } @@ -59,16 +60,17 @@ type mockFilters_DecodeSubKey_Call struct { // DecodeSubKey is a helper method to define mock.On call // - ctx context.Context +// - lggr logger.SugaredLogger // - raw []byte // - ID int64 // - subKeyPath []string -func (_e *mockFilters_Expecter) DecodeSubKey(ctx interface{}, raw interface{}, ID interface{}, subKeyPath interface{}) *mockFilters_DecodeSubKey_Call { - return &mockFilters_DecodeSubKey_Call{Call: _e.mock.On("DecodeSubKey", ctx, raw, ID, subKeyPath)} +func (_e *mockFilters_Expecter) DecodeSubKey(ctx interface{}, lggr interface{}, raw interface{}, ID interface{}, subKeyPath interface{}) *mockFilters_DecodeSubKey_Call { + return &mockFilters_DecodeSubKey_Call{Call: _e.mock.On("DecodeSubKey", ctx, lggr, raw, ID, subKeyPath)} } -func (_c *mockFilters_DecodeSubKey_Call) Run(run func(ctx context.Context, raw []byte, ID int64, subKeyPath []string)) *mockFilters_DecodeSubKey_Call { +func (_c *mockFilters_DecodeSubKey_Call) Run(run func(ctx context.Context, lggr logger.SugaredLogger, raw []byte, ID int64, subKeyPath []string)) *mockFilters_DecodeSubKey_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].([]byte), args[2].(int64), args[3].([]string)) + run(args[0].(context.Context), args[1].(logger.SugaredLogger), args[2].([]byte), args[3].(int64), args[4].([]string)) }) return _c } @@ -78,7 +80,7 @@ func (_c *mockFilters_DecodeSubKey_Call) Return(_a0 interface{}, _a1 error) *moc return _c } -func (_c *mockFilters_DecodeSubKey_Call) RunAndReturn(run func(context.Context, []byte, int64, []string) (interface{}, error)) *mockFilters_DecodeSubKey_Call { +func (_c *mockFilters_DecodeSubKey_Call) RunAndReturn(run func(context.Context, logger.SugaredLogger, []byte, int64, []string) (interface{}, error)) *mockFilters_DecodeSubKey_Call { _c.Call.Return(run) return _c }