Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve log poller codec usage #1014

Open
wants to merge 4 commits into
base: NONEVM-916-logpoller-process-decode
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions pkg/solana/codec/codec_entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@ func NewAccountEntry(offchainName string, idlTypes AccountIDLTypes, includeDiscr
return nil, err
}

var discriminator *Discriminator
if includeDiscriminator {
discriminator = NewDiscriminator(idlTypes.Account.Name, true)
}

return newEntry(
offchainName,
idlTypes.Account.Name,
accCodec,
includeDiscriminator,
discriminator,
mod,
), nil
}
Expand All @@ -69,7 +74,7 @@ func NewInstructionArgsEntry(offChainName string, idlTypes InstructionArgsIDLTyp
idlTypes.Instruction.Name,
instructionCodecArgs,
// Instruction arguments don't need a discriminator by default
false,
nil,
mod,
), nil
}
Expand All @@ -85,30 +90,40 @@ func NewEventArgsEntry(offChainName string, idlTypes EventIDLTypes, includeDiscr
return nil, err
}

var discriminator *Discriminator
if includeDiscriminator {
discriminator = NewDiscriminator(idlTypes.Event.Name, false)
}

return newEntry(
offChainName,
idlTypes.Event.Name,
eventCodec,
includeDiscriminator,
discriminator,
mod,
), nil
}

func newEntry(
genericName, chainSpecificName string,
typeCodec commonencodings.TypeCodec,
includeDiscriminator bool,
discriminator *Discriminator,
mod codec.Modifier,
) Entry {
return &entry{
genericName: genericName,
chainSpecificName: chainSpecificName,
reflectType: typeCodec.GetType(),
typeCodec: typeCodec,
mod: ensureModifier(mod),
includeDiscriminator: includeDiscriminator,
discriminator: *NewDiscriminator(chainSpecificName),
e := &entry{
genericName: genericName,
chainSpecificName: chainSpecificName,
reflectType: typeCodec.GetType(),
typeCodec: typeCodec,
mod: ensureModifier(mod),
}

if discriminator != nil {
e.discriminator = *discriminator
e.includeDiscriminator = true
}

return e
}

func createRefs(idlTypes IdlTypeDefSlice, builder commonencodings.Builder) *codecRefs {
Expand Down Expand Up @@ -159,8 +174,8 @@ func (e *entry) Decode(encoded []byte) (any, []byte, error) {
}

if !bytes.Equal(e.discriminator.hashPrefix, encoded[:discriminatorLength]) {
return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v for genericName: %q, chainSpecificName: %q",
commontypes.ErrInvalidType, encoded[:discriminatorLength], e.genericName, e.chainSpecificName)
return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v, expected %v, for genericName: %q, chainSpecificName: %q",
commontypes.ErrInvalidType, encoded[:discriminatorLength], e.discriminator.hashPrefix, e.genericName, e.chainSpecificName)
}

encoded = encoded[discriminatorLength:]
Expand Down
20 changes: 14 additions & 6 deletions pkg/solana/codec/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,26 @@ func (it *codecInterfaceTester) GetAccountString(i int) string {
}

func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeRequest) []byte {
if request.TestOn == TestItemType || request.TestOn == testutils.TestEventItem {
return encodeFieldsOnItem(t, request)
if request.TestOn == TestItemType {
return encodeFieldsOnItem(t, request, true)
} else if request.TestOn == testutils.TestEventItem {
return encodeFieldsOnItem(t, request, false)
}

return encodeFieldsOnSliceOrArray(t, request)
}

func encodeFieldsOnItem(t *testing.T, request *EncodeRequest) ocr2types.Report {
func encodeFieldsOnItem(t *testing.T, request *EncodeRequest, isAccount bool) ocr2types.Report {
buf := new(bytes.Buffer)
// The underlying TestItemAsAccount adds a discriminator by default while being Borsh encoded.
if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
// The underlying TestItem adds a discriminator by default while being Borsh encoded.
if isAccount {
if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
}
} else {
if err := testutils.EncodeRequestToTestItemAsEvent(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
}
}
return buf.Bytes()
}
Expand Down
16 changes: 13 additions & 3 deletions pkg/solana/codec/discriminator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,19 @@ import (

const discriminatorLength = 8

func NewDiscriminator(name string) *Discriminator {
sum := sha256.Sum256([]byte("account:" + name))
return &Discriminator{hashPrefix: sum[:discriminatorLength]}
func NewDiscriminator(name string, isAccount bool) *Discriminator {
return &Discriminator{hashPrefix: NewDiscriminatorHashPrefix(name, isAccount)}
}

func NewDiscriminatorHashPrefix(name string, isAccount bool) []byte {
var sum [32]byte
if isAccount {
sum = sha256.Sum256([]byte("account:" + name))
} else {
sum = sha256.Sum256([]byte("event:" + name))
}

return sum[:discriminatorLength]
}

type Discriminator struct {
Expand Down
18 changes: 9 additions & 9 deletions pkg/solana/codec/discriminator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestDiscriminator(t *testing.T) {
t.Run("encode and decode return the discriminator", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
encoded, err := c.Encode(&expected, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
Expand All @@ -28,15 +28,15 @@ func TestDiscriminator(t *testing.T) {
})

t.Run("encode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, err := c.Encode(&[]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("encode injects the discriminator if it's not provided", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
encoded, err := c.Encode(nil, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
Expand All @@ -46,37 +46,37 @@ func TestDiscriminator(t *testing.T) {
})

t.Run("decode returns an error if the encoded value is too short", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("decode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("encode returns an error if the value is not a byte slice", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
_, err := c.Encode(42, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("GetType returns the type of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
require.Equal(t, reflect.TypeOf(&[]byte{}), c.GetType())
})

t.Run("Size returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
size, err := c.Size(0)
require.NoError(t, err)
require.Equal(t, 8, size)
})

t.Run("FixedSize returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
c := codec.NewDiscriminator("Foo", true)
size, err := c.FixedSize()
require.NoError(t, err)
require.Equal(t, 8, size)
Expand Down
61 changes: 5 additions & 56 deletions pkg/solana/codec/solana.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,32 +146,11 @@ func WrapItemType(forEncoding bool, contractName, itemType string, readType Chai
return fmt.Sprintf("output.%s.%s.%s", readType, contractName, itemType)
}

// NewIDLAccountCodec is for Anchor custom types
// TODO Deprecate and remove this.
func NewIDLAccountCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) {
return newIDLCoded(idl, builder, idl.Accounts, true)
}

func NewIDLInstructionsCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) {
typeCodecs := make(commonencodings.LenientCodecFromTypeCodec)
refs := &codecRefs{
builder: builder,
codecs: make(map[string]commonencodings.TypeCodec),
typeDefs: idl.Types,
dependencies: make(map[string][]string),
}

for _, instruction := range idl.Instructions {
name, instCodec, err := asStruct(instruction.Args, refs, instruction.Name, false, false)
if err != nil {
return nil, err
}

typeCodecs[name] = instCodec
}

return typeCodecs, nil
}

func NewNamedModifierCodec(original commontypes.RemoteCodec, itemType string, modifier commoncodec.Modifier) (commontypes.RemoteCodec, error) {
mod, err := commoncodec.NewByItemTypeModifier(map[string]commoncodec.Modifier{itemType: modifier})
if err != nil {
Expand All @@ -188,6 +167,7 @@ func NewNamedModifierCodec(original commontypes.RemoteCodec, itemType string, mo
return modCodec, err
}

// TODO Deprecate and remove this.
func NewIDLDefinedTypesCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) {
return newIDLCoded(idl, builder, idl.Types, false)
}
Expand Down Expand Up @@ -252,6 +232,7 @@ type codecRefs struct {
func createCodecType(
def IdlTypeDef,
refs *codecRefs,
// TODO Deprecated includeDiscriminator is not needed here after NewIDLAccountCodec gets cleaned up
includeDiscriminator bool,
) (string, commonencodings.TypeCodec, error) {
name := def.Name
Expand All @@ -273,6 +254,7 @@ func asStruct(
fields []IdlField,
refs *codecRefs,
name string, // name is the struct name and can be used in dependency checks
// TODO Deprecated includeDiscriminator is not needed here after NewIDLAccountCodec gets cleaned up
includeDiscriminator bool,
isInstructionArgs bool,
) (string, commonencodings.TypeCodec, error) {
Expand All @@ -284,7 +266,7 @@ func asStruct(
named := make([]commonencodings.NamedTypeCodec, len(fields)+desLen)

if includeDiscriminator {
named[0] = commonencodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name)}
named[0] = commonencodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name, true)}
}

for idx, field := range fields {
Expand Down Expand Up @@ -489,36 +471,3 @@ func saveDependency(refs *codecRefs, parent, child string) {

refs.dependencies[parent] = append(deps, child)
}
func NewIDLEventCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) {
typeCodecs := make(commonencodings.LenientCodecFromTypeCodec)
refs := &codecRefs{
builder: builder,
codecs: make(map[string]commonencodings.TypeCodec),
typeDefs: idl.Types,
dependencies: make(map[string][]string),
}

for _, event := range idl.Events {
name, instCodec, err := asStruct(eventFieldsAsStandardFields(event.Fields), refs, event.Name, false, false)
if err != nil {
return nil, err
}

typeCodecs[name] = instCodec
}

return typeCodecs, nil
}

func eventFieldsAsStandardFields(event []IdlEventField) []IdlField {
output := make([]IdlField, len(event))

for idx := range output {
output[idx] = IdlField{
Name: event[idx].Name,
Type: event[idx].Type,
}
}

return output
}
Loading
Loading