Skip to content

Commit

Permalink
Updated getSeedBytes to handle array seeds (#1007)
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan authored Jan 14, 2025
1 parent e2a9566 commit 74b7a72
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 38 deletions.
83 changes: 79 additions & 4 deletions integration-tests/relayinterface/lookups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ func TestAccountLookups(t *testing.T) {

func TestPDALookups(t *testing.T) {
programID := chainwriter.GetRandomPubKey(t)
ctx := tests.Context(t)

t.Run("PDALookup resolves valid PDA with constant address seeds", func(t *testing.T) {
seed := chainwriter.GetRandomPubKey(t)
Expand All @@ -152,7 +153,6 @@ func TestPDALookups(t *testing.T) {
IsWritable: true,
}

ctx := tests.Context(t)
result, err := pdaLookup.Resolve(ctx, nil, nil, nil)
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
Expand Down Expand Up @@ -183,7 +183,6 @@ func TestPDALookups(t *testing.T) {
IsWritable: true,
}

ctx := tests.Context(t)
args := map[string]interface{}{
"test_seed": seed1,
"another_seed": seed2,
Expand All @@ -205,7 +204,6 @@ func TestPDALookups(t *testing.T) {
IsWritable: true,
}

ctx := tests.Context(t)
args := map[string]interface{}{
"test_seed": []byte("data"),
}
Expand Down Expand Up @@ -241,7 +239,6 @@ func TestPDALookups(t *testing.T) {
IsWritable: true,
}

ctx := tests.Context(t)
args := map[string]interface{}{
"test_seed": seed1,
"another_seed": seed2,
Expand All @@ -251,6 +248,84 @@ func TestPDALookups(t *testing.T) {
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})

t.Run("PDALookups resolves list of PDAs when a seed is an array", func(t *testing.T) {
singleSeed := []byte("test_seed")
arraySeed := []solana.PublicKey{chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)}

expectedMeta := []*solana.AccountMeta{}

for _, seed := range arraySeed {
pda, _, err := solana.FindProgramAddress([][]byte{singleSeed, seed.Bytes()}, programID)
require.NoError(t, err)
meta := &solana.AccountMeta{
PublicKey: pda,
IsSigner: false,
IsWritable: false,
}
expectedMeta = append(expectedMeta, meta)
}

pdaLookup := chainwriter.PDALookups{
Name: "TestPDA",
PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()},
Seeds: []chainwriter.Seed{
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "single_seed"}},
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "array_seed"}},
},
IsSigner: false,
IsWritable: false,
}

args := map[string]interface{}{
"single_seed": singleSeed,
"array_seed": arraySeed,
}

result, err := pdaLookup.Resolve(ctx, args, nil, nil)
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})

t.Run("PDALookups resolves list of PDAs when multiple seeds are arrays", func(t *testing.T) {
arraySeed1 := [][]byte{[]byte("test_seed1"), []byte("test_seed2")}
arraySeed2 := []solana.PublicKey{chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)}

expectedMeta := []*solana.AccountMeta{}

for _, seed1 := range arraySeed1 {
for _, seed2 := range arraySeed2 {
pda, _, err := solana.FindProgramAddress([][]byte{seed1, seed2.Bytes()}, programID)
require.NoError(t, err)
meta := &solana.AccountMeta{
PublicKey: pda,
IsSigner: false,
IsWritable: false,
}
expectedMeta = append(expectedMeta, meta)
}
}

pdaLookup := chainwriter.PDALookups{
Name: "TestPDA",
PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()},
Seeds: []chainwriter.Seed{
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}},
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}},
},
IsSigner: false,
IsWritable: false,
}

args := map[string]interface{}{
"seed1": arraySeed1,
"seed2": arraySeed2,
}

result, err := pdaLookup.Resolve(ctx, args, nil, nil)
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
}

func TestLookupTables(t *testing.T) {
Expand Down
19 changes: 16 additions & 3 deletions pkg/solana/chainwriter/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ func errorWithDebugID(err error, debugID string) error {
// traversePath recursively traverses the given structure based on the provided path.
func traversePath(data any, path []string) ([]any, error) {
if len(path) == 0 {
val := reflect.ValueOf(data)

// If the final data is a slice or array, flatten it into multiple items,
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
// don't flatten []byte
if val.Type().Elem().Kind() == reflect.Uint8 {
return []any{val.Interface()}, nil
}

var results []any
for i := 0; i < val.Len(); i++ {
results = append(results, val.Index(i).Interface())
}
return results, nil
}
// Otherwise, return just this one item
return []any{data}, nil
}

Expand Down Expand Up @@ -124,9 +140,6 @@ func traversePath(data any, path []string) ([]any, error) {
}
return traversePath(value.Interface(), path[1:])
default:
if len(path) == 1 && val.Kind() == reflect.Slice && val.Type().Elem().Kind() == reflect.Uint8 {
return []any{val.Interface()}, nil
}
return nil, errors.New("unexpected type encountered at path: " + path[0])
}
}
Expand Down
100 changes: 70 additions & 30 deletions pkg/solana/chainwriter/lookups.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (pda PDALookups) Resolve(ctx context.Context, args any, derivedTableMap map
return nil, fmt.Errorf("error getting public key for PDALookups: %w", err)
}

seeds, err := getSeedBytes(ctx, pda, args, derivedTableMap, reader)
seeds, err := getSeedBytesCombinations(ctx, pda, args, derivedTableMap, reader)
if err != nil {
return nil, fmt.Errorf("error getting seeds for PDALookups: %w", err)
}
Expand Down Expand Up @@ -209,64 +209,104 @@ func decodeBorshIntoType(data []byte, typ reflect.Type) (interface{}, error) {
return reflect.ValueOf(instance).Elem().Interface(), nil
}

// getSeedBytes extracts the seeds for the PDALookups.
// It handles both AddressSeeds (which are public keys) and ValueSeeds (which are byte arrays from input args).
func getSeedBytes(ctx context.Context, lookup PDALookups, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([][]byte, error) {
var seedBytes [][]byte
// getSeedBytesCombinations extracts the seeds for the PDALookups.
// The return type is [][][]byte, where each element of the outer slice is
// one combination of seeds. This handles the case where one seed can resolve
// to multiple addresses, multiplying the combinations accordingly.
func getSeedBytesCombinations(
ctx context.Context,
lookup PDALookups,
args any,
derivedTableMap map[string]map[string][]*solana.AccountMeta,
reader client.Reader,
) ([][][]byte, error) {
allCombinations := [][][]byte{
{},
}

// For each seed in the definition, expand the current list of combinations
// by all possible values for this seed.
for _, seed := range lookup.Seeds {
expansions := make([][]byte, 0)
if seed.Static != nil {
seedBytes = append(seedBytes, seed.Static)
}
if seed.Dynamic != nil {
expansions = append(expansions, seed.Static)
// Static and Dynamic seeds are mutually exclusive
} else if seed.Dynamic != nil {
dynamicSeed := seed.Dynamic
if lookupSeed, ok := dynamicSeed.(AccountLookup); ok {
// Get value from a location (This doens't have to be an address, it can be any value)
bytes, err := GetValuesAtLocation(args, lookupSeed.Location)
if err != nil {
return nil, fmt.Errorf("error getting address seed: %w", err)
return nil, fmt.Errorf("error getting address seed for location %q: %w", lookupSeed.Location, err)
}
// validate seed length
// append each byte array to the expansions
for _, b := range bytes {
// validate seed length
if len(b) > solana.MaxSeedLength {
return nil, fmt.Errorf("seed byte array exceeds maximum length of %d: got %d bytes", solana.MaxSeedLength, len(b))
}
seedBytes = append(seedBytes, b)
expansions = append(expansions, b)
}
} else {
// Get address seeds from the lookup
seedAddresses, err := GetAddresses(ctx, args, []Lookup{dynamicSeed}, derivedTableMap, reader)
if err != nil {
return nil, fmt.Errorf("error getting address seed: %w", err)
}

// Add each address seed as bytes
for _, address := range seedAddresses {
seedBytes = append(seedBytes, address.PublicKey.Bytes())
// Add each address seed to the expansions
for _, addrMeta := range seedAddresses {
b := addrMeta.PublicKey.Bytes()
if len(b) > solana.MaxSeedLength {
return nil, fmt.Errorf("seed byte array exceeds maximum length of %d: got %d bytes", solana.MaxSeedLength, len(b))
}
expansions = append(expansions, b)
}
}
}

// expansions is the list of possible seed bytes for this single seed lookup.
// Multiply the existing combinations in allCombinations by each item in expansions.
newCombinations := make([][][]byte, 0, len(allCombinations)*len(expansions))
for _, existingCombo := range allCombinations {
for _, expandedSeed := range expansions {
comboCopy := make([][]byte, len(existingCombo)+1)
copy(comboCopy, existingCombo)
comboCopy[len(existingCombo)] = expandedSeed
newCombinations = append(newCombinations, comboCopy)
}
}

allCombinations = newCombinations
}

return seedBytes, nil
return allCombinations, nil
}

// generatePDAs generates program-derived addresses (PDAs) from public keys and seeds.
func generatePDAs(publicKeys []*solana.AccountMeta, seeds [][]byte, lookup PDALookups) ([]*solana.AccountMeta, error) {
if len(seeds) > solana.MaxSeeds {
return nil, fmt.Errorf("seed maximum exceeded: %d", len(seeds))
}
var addresses []*solana.AccountMeta
// it will result in a list of PDAs whose length is the product of the number of public keys
// and the number of seed combinations.
func generatePDAs(
publicKeys []*solana.AccountMeta,
seedCombos [][][]byte,
lookup PDALookups,
) ([]*solana.AccountMeta, error) {
var results []*solana.AccountMeta
for _, publicKeyMeta := range publicKeys {
address, _, err := solana.FindProgramAddress(seeds, publicKeyMeta.PublicKey)
if err != nil {
return nil, fmt.Errorf("error finding program address: %w", err)
for _, combo := range seedCombos {
if len(combo) > solana.MaxSeeds {
return nil, fmt.Errorf("seed maximum exceeded: %d", len(combo))
}
address, _, err := solana.FindProgramAddress(combo, publicKeyMeta.PublicKey)
if err != nil {
return nil, fmt.Errorf("error finding program address: %w", err)
}
results = append(results, &solana.AccountMeta{
PublicKey: address,
IsSigner: lookup.IsSigner,
IsWritable: lookup.IsWritable,
})
}
addresses = append(addresses, &solana.AccountMeta{
PublicKey: address,
IsSigner: lookup.IsSigner,
IsWritable: lookup.IsWritable,
})
}
return addresses, nil

return results, nil
}
Loading

0 comments on commit 74b7a72

Please sign in to comment.