From 74b7a72c2706857920cc1e01b85fd4722559c798 Mon Sep 17 00:00:00 2001 From: Silas Lenihan <32529249+silaslenihan@users.noreply.github.com> Date: Mon, 13 Jan 2025 20:05:32 -0500 Subject: [PATCH] Updated getSeedBytes to handle array seeds (#1007) --- .../relayinterface/lookups_test.go | 83 +++++++++++- pkg/solana/chainwriter/helpers.go | 19 ++- pkg/solana/chainwriter/lookups.go | 100 +++++++++----- pkg/solana/chainwriter/testContractIDL.json | 124 +++++++++++++++++- 4 files changed, 288 insertions(+), 38 deletions(-) diff --git a/integration-tests/relayinterface/lookups_test.go b/integration-tests/relayinterface/lookups_test.go index fd148abff..0154b683a 100644 --- a/integration-tests/relayinterface/lookups_test.go +++ b/integration-tests/relayinterface/lookups_test.go @@ -127,6 +127,7 @@ func TestAccountLookups(t *testing.T) { func TestPDALookups(t *testing.T) { programID := chainwriter.GetRandomPubKey(t) + ctx := tests.Context(t) t.Run("PDALookup resolves valid PDA with constant address seeds", func(t *testing.T) { seed := chainwriter.GetRandomPubKey(t) @@ -152,7 +153,6 @@ func TestPDALookups(t *testing.T) { IsWritable: true, } - ctx := tests.Context(t) result, err := pdaLookup.Resolve(ctx, nil, nil, nil) require.NoError(t, err) require.Equal(t, expectedMeta, result) @@ -183,7 +183,6 @@ func TestPDALookups(t *testing.T) { IsWritable: true, } - ctx := tests.Context(t) args := map[string]interface{}{ "test_seed": seed1, "another_seed": seed2, @@ -205,7 +204,6 @@ func TestPDALookups(t *testing.T) { IsWritable: true, } - ctx := tests.Context(t) args := map[string]interface{}{ "test_seed": []byte("data"), } @@ -241,7 +239,6 @@ func TestPDALookups(t *testing.T) { IsWritable: true, } - ctx := tests.Context(t) args := map[string]interface{}{ "test_seed": seed1, "another_seed": seed2, @@ -251,6 +248,84 @@ func TestPDALookups(t *testing.T) { require.NoError(t, err) require.Equal(t, expectedMeta, result) }) + + t.Run("PDALookups resolves list of PDAs when a seed is an array", func(t *testing.T) { + singleSeed := []byte("test_seed") + arraySeed := []solana.PublicKey{chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)} + + expectedMeta := []*solana.AccountMeta{} + + for _, seed := range arraySeed { + pda, _, err := solana.FindProgramAddress([][]byte{singleSeed, seed.Bytes()}, programID) + require.NoError(t, err) + meta := &solana.AccountMeta{ + PublicKey: pda, + IsSigner: false, + IsWritable: false, + } + expectedMeta = append(expectedMeta, meta) + } + + pdaLookup := chainwriter.PDALookups{ + Name: "TestPDA", + PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()}, + Seeds: []chainwriter.Seed{ + {Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "single_seed"}}, + {Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "array_seed"}}, + }, + IsSigner: false, + IsWritable: false, + } + + args := map[string]interface{}{ + "single_seed": singleSeed, + "array_seed": arraySeed, + } + + result, err := pdaLookup.Resolve(ctx, args, nil, nil) + require.NoError(t, err) + require.Equal(t, expectedMeta, result) + }) + + t.Run("PDALookups resolves list of PDAs when multiple seeds are arrays", func(t *testing.T) { + arraySeed1 := [][]byte{[]byte("test_seed1"), []byte("test_seed2")} + arraySeed2 := []solana.PublicKey{chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)} + + expectedMeta := []*solana.AccountMeta{} + + for _, seed1 := range arraySeed1 { + for _, seed2 := range arraySeed2 { + pda, _, err := solana.FindProgramAddress([][]byte{seed1, seed2.Bytes()}, programID) + require.NoError(t, err) + meta := &solana.AccountMeta{ + PublicKey: pda, + IsSigner: false, + IsWritable: false, + } + expectedMeta = append(expectedMeta, meta) + } + } + + pdaLookup := chainwriter.PDALookups{ + Name: "TestPDA", + PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()}, + Seeds: []chainwriter.Seed{ + {Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}}, + {Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}}, + }, + IsSigner: false, + IsWritable: false, + } + + args := map[string]interface{}{ + "seed1": arraySeed1, + "seed2": arraySeed2, + } + + result, err := pdaLookup.Resolve(ctx, args, nil, nil) + require.NoError(t, err) + require.Equal(t, expectedMeta, result) + }) } func TestLookupTables(t *testing.T) { diff --git a/pkg/solana/chainwriter/helpers.go b/pkg/solana/chainwriter/helpers.go index a4b18e4d5..7d146a25a 100644 --- a/pkg/solana/chainwriter/helpers.go +++ b/pkg/solana/chainwriter/helpers.go @@ -84,6 +84,22 @@ func errorWithDebugID(err error, debugID string) error { // traversePath recursively traverses the given structure based on the provided path. func traversePath(data any, path []string) ([]any, error) { if len(path) == 0 { + val := reflect.ValueOf(data) + + // If the final data is a slice or array, flatten it into multiple items, + if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { + // don't flatten []byte + if val.Type().Elem().Kind() == reflect.Uint8 { + return []any{val.Interface()}, nil + } + + var results []any + for i := 0; i < val.Len(); i++ { + results = append(results, val.Index(i).Interface()) + } + return results, nil + } + // Otherwise, return just this one item return []any{data}, nil } @@ -124,9 +140,6 @@ func traversePath(data any, path []string) ([]any, error) { } return traversePath(value.Interface(), path[1:]) default: - if len(path) == 1 && val.Kind() == reflect.Slice && val.Type().Elem().Kind() == reflect.Uint8 { - return []any{val.Interface()}, nil - } return nil, errors.New("unexpected type encountered at path: " + path[0]) } } diff --git a/pkg/solana/chainwriter/lookups.go b/pkg/solana/chainwriter/lookups.go index f875ade3a..349ad74bc 100644 --- a/pkg/solana/chainwriter/lookups.go +++ b/pkg/solana/chainwriter/lookups.go @@ -142,7 +142,7 @@ func (pda PDALookups) Resolve(ctx context.Context, args any, derivedTableMap map return nil, fmt.Errorf("error getting public key for PDALookups: %w", err) } - seeds, err := getSeedBytes(ctx, pda, args, derivedTableMap, reader) + seeds, err := getSeedBytesCombinations(ctx, pda, args, derivedTableMap, reader) if err != nil { return nil, fmt.Errorf("error getting seeds for PDALookups: %w", err) } @@ -209,29 +209,43 @@ func decodeBorshIntoType(data []byte, typ reflect.Type) (interface{}, error) { return reflect.ValueOf(instance).Elem().Interface(), nil } -// getSeedBytes extracts the seeds for the PDALookups. -// It handles both AddressSeeds (which are public keys) and ValueSeeds (which are byte arrays from input args). -func getSeedBytes(ctx context.Context, lookup PDALookups, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([][]byte, error) { - var seedBytes [][]byte +// getSeedBytesCombinations extracts the seeds for the PDALookups. +// The return type is [][][]byte, where each element of the outer slice is +// one combination of seeds. This handles the case where one seed can resolve +// to multiple addresses, multiplying the combinations accordingly. +func getSeedBytesCombinations( + ctx context.Context, + lookup PDALookups, + args any, + derivedTableMap map[string]map[string][]*solana.AccountMeta, + reader client.Reader, +) ([][][]byte, error) { + allCombinations := [][][]byte{ + {}, + } + // For each seed in the definition, expand the current list of combinations + // by all possible values for this seed. for _, seed := range lookup.Seeds { + expansions := make([][]byte, 0) if seed.Static != nil { - seedBytes = append(seedBytes, seed.Static) - } - if seed.Dynamic != nil { + expansions = append(expansions, seed.Static) + // Static and Dynamic seeds are mutually exclusive + } else if seed.Dynamic != nil { dynamicSeed := seed.Dynamic if lookupSeed, ok := dynamicSeed.(AccountLookup); ok { // Get value from a location (This doens't have to be an address, it can be any value) bytes, err := GetValuesAtLocation(args, lookupSeed.Location) if err != nil { - return nil, fmt.Errorf("error getting address seed: %w", err) + return nil, fmt.Errorf("error getting address seed for location %q: %w", lookupSeed.Location, err) } - // validate seed length + // append each byte array to the expansions for _, b := range bytes { + // validate seed length if len(b) > solana.MaxSeedLength { return nil, fmt.Errorf("seed byte array exceeds maximum length of %d: got %d bytes", solana.MaxSeedLength, len(b)) } - seedBytes = append(seedBytes, b) + expansions = append(expansions, b) } } else { // Get address seeds from the lookup @@ -239,34 +253,60 @@ func getSeedBytes(ctx context.Context, lookup PDALookups, args any, derivedTable if err != nil { return nil, fmt.Errorf("error getting address seed: %w", err) } - - // Add each address seed as bytes - for _, address := range seedAddresses { - seedBytes = append(seedBytes, address.PublicKey.Bytes()) + // Add each address seed to the expansions + for _, addrMeta := range seedAddresses { + b := addrMeta.PublicKey.Bytes() + if len(b) > solana.MaxSeedLength { + return nil, fmt.Errorf("seed byte array exceeds maximum length of %d: got %d bytes", solana.MaxSeedLength, len(b)) + } + expansions = append(expansions, b) } } } + + // expansions is the list of possible seed bytes for this single seed lookup. + // Multiply the existing combinations in allCombinations by each item in expansions. + newCombinations := make([][][]byte, 0, len(allCombinations)*len(expansions)) + for _, existingCombo := range allCombinations { + for _, expandedSeed := range expansions { + comboCopy := make([][]byte, len(existingCombo)+1) + copy(comboCopy, existingCombo) + comboCopy[len(existingCombo)] = expandedSeed + newCombinations = append(newCombinations, comboCopy) + } + } + + allCombinations = newCombinations } - return seedBytes, nil + return allCombinations, nil } // generatePDAs generates program-derived addresses (PDAs) from public keys and seeds. -func generatePDAs(publicKeys []*solana.AccountMeta, seeds [][]byte, lookup PDALookups) ([]*solana.AccountMeta, error) { - if len(seeds) > solana.MaxSeeds { - return nil, fmt.Errorf("seed maximum exceeded: %d", len(seeds)) - } - var addresses []*solana.AccountMeta +// it will result in a list of PDAs whose length is the product of the number of public keys +// and the number of seed combinations. +func generatePDAs( + publicKeys []*solana.AccountMeta, + seedCombos [][][]byte, + lookup PDALookups, +) ([]*solana.AccountMeta, error) { + var results []*solana.AccountMeta for _, publicKeyMeta := range publicKeys { - address, _, err := solana.FindProgramAddress(seeds, publicKeyMeta.PublicKey) - if err != nil { - return nil, fmt.Errorf("error finding program address: %w", err) + for _, combo := range seedCombos { + if len(combo) > solana.MaxSeeds { + return nil, fmt.Errorf("seed maximum exceeded: %d", len(combo)) + } + address, _, err := solana.FindProgramAddress(combo, publicKeyMeta.PublicKey) + if err != nil { + return nil, fmt.Errorf("error finding program address: %w", err) + } + results = append(results, &solana.AccountMeta{ + PublicKey: address, + IsSigner: lookup.IsSigner, + IsWritable: lookup.IsWritable, + }) } - addresses = append(addresses, &solana.AccountMeta{ - PublicKey: address, - IsSigner: lookup.IsSigner, - IsWritable: lookup.IsWritable, - }) } - return addresses, nil + + return results, nil } diff --git a/pkg/solana/chainwriter/testContractIDL.json b/pkg/solana/chainwriter/testContractIDL.json index 9631e4acf..531a8c172 100644 --- a/pkg/solana/chainwriter/testContractIDL.json +++ b/pkg/solana/chainwriter/testContractIDL.json @@ -1 +1,123 @@ -{"version":"0.1.0","name":"contractReaderInterface","instructions":[{"name":"initialize","accounts":[{"name":"data","isMut":true,"isSigner":false},{"name":"signer","isMut":true,"isSigner":true},{"name":"systemProgram","isMut":false,"isSigner":false}],"args":[{"name":"testIdx","type":"u64"},{"name":"value","type":"u64"}]},{"name":"initializeLookupTable","accounts":[{"name":"writeDataAccount","isMut":true,"isSigner":false,"docs":["PDA for LookupTableDataAccount, derived from seeds and created by the System Program"]},{"name":"admin","isMut":true,"isSigner":true,"docs":["Admin account that pays for PDA creation and signs the transaction"]},{"name":"systemProgram","isMut":false,"isSigner":false,"docs":["System Program required for PDA creation"]}],"args":[{"name":"lookupTable","type":"publicKey"}]}],"accounts":[{"name":"LookupTableDataAccount","type":{"kind":"struct","fields":[{"name":"version","type":"u8"},{"name":"administrator","type":"publicKey"},{"name":"pendingAdministrator","type":"publicKey"},{"name":"lookupTable","type":"publicKey"}]}},{"name":"DataAccount","type":{"kind":"struct","fields":[{"name":"idx","type":"u64"},{"name":"bump","type":"u8"},{"name":"u64Value","type":"u64"},{"name":"u64Slice","type":{"vec":"u64"}}]}}]} \ No newline at end of file +{ + "version": "0.1.0", + "name": "contractReaderInterface", + "instructions": [ + { + "name": "initialize", + "accounts": [ + { + "name": "data", + "isMut": true, + "isSigner": false + }, + { + "name": "signer", + "isMut": true, + "isSigner": true + }, + { + "name": "systemProgram", + "isMut": false, + "isSigner": false + } + ], + "args": [ + { + "name": "testIdx", + "type": "u64" + }, + { + "name": "value", + "type": "u64" + } + ] + }, + { + "name": "initializeLookupTable", + "accounts": [ + { + "name": "writeDataAccount", + "isMut": true, + "isSigner": false, + "docs": [ + "PDA for LookupTableDataAccount, derived from seeds and created by the System Program" + ] + }, + { + "name": "admin", + "isMut": true, + "isSigner": true, + "docs": [ + "Admin account that pays for PDA creation and signs the transaction" + ] + }, + { + "name": "systemProgram", + "isMut": false, + "isSigner": false, + "docs": [ + "System Program required for PDA creation" + ] + } + ], + "args": [ + { + "name": "lookupTable", + "type": "publicKey" + } + ] + } + ], + "accounts": [ + { + "name": "LookupTableDataAccount", + "type": { + "kind": "struct", + "fields": [ + { + "name": "version", + "type": "u8" + }, + { + "name": "administrator", + "type": "publicKey" + }, + { + "name": "pendingAdministrator", + "type": "publicKey" + }, + { + "name": "lookupTable", + "type": "publicKey" + } + ] + } + }, + { + "name": "DataAccount", + "type": { + "kind": "struct", + "fields": [ + { + "name": "idx", + "type": "u64" + }, + { + "name": "bump", + "type": "u8" + }, + { + "name": "u64Value", + "type": "u64" + }, + { + "name": "u64Slice", + "type": { + "vec": "u64" + } + } + ] + } + } + ] +} \ No newline at end of file