Skip to content

Commit

Permalink
Enabled automatic ATA creation in CW
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan committed Jan 31, 2025
1 parent aa71d84 commit ed183d4
Show file tree
Hide file tree
Showing 6 changed files with 353 additions and 263 deletions.
177 changes: 177 additions & 0 deletions integration-tests/relayinterface/lookups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/gagliardetto/solana-go/rpc"
"github.com/stretchr/testify/require"

"github.com/smartcontractkit/chainlink-ccip/chains/solana/utils/tokens"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
commonutils "github.com/smartcontractkit/chainlink-common/pkg/utils"
"github.com/smartcontractkit/chainlink-common/pkg/utils/tests"
Expand Down Expand Up @@ -482,6 +483,7 @@ func TestLookupTables(t *testing.T) {
txm := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr)

cw, err := chainwriter.NewSolanaChainWriterService(nil, solanaClient, txm, nil, chainwriter.ChainWriterConfig{})
require.NoError(t, err)

t.Run("StaticLookup table resolves properly", func(t *testing.T) {
pubKeys := chainwriter.CreateTestPubKeys(t, 8)
Expand Down Expand Up @@ -635,3 +637,178 @@ func TestLookupTables(t *testing.T) {
}
})
}

func TestCreateATAs(t *testing.T) {
ctx := tests.Context(t)

sender, err := solana.NewRandomPrivateKey()
require.NoError(t, err)

feePayer := sender.PublicKey()

url, _ := utils.SetupTestValidatorWithAnchorPrograms(t, sender.PublicKey().String(), []string{"contract-reader-interface"})
rpcClient := rpc.New(url)

utils.FundAccounts(t, []solana.PrivateKey{sender}, rpcClient)

cfg := config.NewDefault()
solanaClient, err := client.NewClient(url, cfg, 5*time.Second, nil)
require.NoError(t, err)

t.Run("returns no instructions when no ATA location is found", func(t *testing.T) {
lookups := []chainwriter.ATALookup{
{
Location: "Invalid.Address",
WalletAddress: chainwriter.AccountConstant{
Address: feePayer.String(),
},
TokenProgram: chainwriter.AccountConstant{
Address: solana.Token2022ProgramID.String(),
},
MintAddress: chainwriter.AccountLookup{
Location: "Invalid.Address",
},
},
}

args := chainwriter.TestArgs{
Inner: []chainwriter.InnerArgs{
{Address: chainwriter.GetRandomPubKey(t).Bytes()},
},
}

ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
require.NoError(t, err)
require.Empty(t, ataInstructions)
})

t.Run("fails with multiple wallet addresses", func(t *testing.T) {
lookups := []chainwriter.ATALookup{
{
Location: "",
WalletAddress: chainwriter.AccountLookup{
Location: "Addresses",
},
TokenProgram: chainwriter.AccountConstant{
Address: solana.Token2022ProgramID.String(),
},
MintAddress: chainwriter.AccountConstant{
Address: chainwriter.GetRandomPubKey(t).String(),
},
},
}

args := map[string][]solana.PublicKey{
"Addresses": {chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)},
}

_, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
require.Contains(t, err.Error(), "expected exactly one wallet address, got 2")
})

t.Run("fails with mismatched mint and token programs", func(t *testing.T) {
lookups := []chainwriter.ATALookup{
{
Location: "",
WalletAddress: chainwriter.AccountConstant{
Address: feePayer.String(),
},
TokenProgram: chainwriter.AccountConstant{
Address: solana.Token2022ProgramID.String(),
},
MintAddress: chainwriter.AccountLookup{
Location: "Addresses",
},
},
}

args := map[string][]solana.PublicKey{
"Addresses": {chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)},
}

_, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
require.Contains(t, err.Error(), "expected equal number of token programs and mints, got 1 tokenPrograms and 2 mints")
})

t.Run("fails when mint is not a token address", func(t *testing.T) {
tokenProgram := solana.Token2022ProgramID
mint := chainwriter.GetRandomPubKey(t)

ataAddress, _, err := tokens.FindAssociatedTokenAddress(tokenProgram, mint, feePayer)
require.NoError(t, err)
require.False(t, checkIfATAExists(t, rpcClient, ataAddress))
lookups := []chainwriter.ATALookup{
{
Location: "Inner.Address",
WalletAddress: chainwriter.AccountConstant{
Address: feePayer.String(),
},
TokenProgram: chainwriter.AccountConstant{
Address: tokenProgram.String(),
},
MintAddress: chainwriter.AccountLookup{
Location: "Inner.Address",
},
},
}

args := chainwriter.TestArgs{
Inner: []chainwriter.InnerArgs{
{Address: mint.Bytes()},
},
}

ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
require.NoError(t, err)

tx := utils.CreateTx(ctx, t, rpcClient, ataInstructions, sender, rpc.CommitmentFinalized)

_, err = rpcClient.SendTransactionWithOpts(ctx, tx, rpc.TransactionOpts{SkipPreflight: false, PreflightCommitment: rpc.CommitmentProcessed})
require.Contains(t, err.Error(), "Program log: Error: Invalid Mint")
})

t.Run("successfully creates ATAs only when necessary", func(t *testing.T) {
tokenProgram := solana.Token2022ProgramID
mint := utils.CreateRandomToken(t, sender, rpcClient)

ataAddress, _, err := tokens.FindAssociatedTokenAddress(tokenProgram, mint, feePayer)
require.NoError(t, err)
require.False(t, checkIfATAExists(t, rpcClient, ataAddress))
lookups := []chainwriter.ATALookup{
{
Location: "Inner.Address",
WalletAddress: chainwriter.AccountConstant{
Address: feePayer.String(),
},
TokenProgram: chainwriter.AccountConstant{
Address: tokenProgram.String(),
},
MintAddress: chainwriter.AccountLookup{
Location: "Inner.Address",
},
},
}

args := chainwriter.TestArgs{
Inner: []chainwriter.InnerArgs{
{Address: mint.Bytes()},
},
}

ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
require.NoError(t, err)

utils.SendAndConfirm(ctx, t, rpcClient, ataInstructions, sender, rpc.CommitmentFinalized)
require.True(t, checkIfATAExists(t, rpcClient, ataAddress))

// now, if we try to create the same ATA again, it should return no instructions
ataInstructions, err = chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
require.NoError(t, err)
require.Empty(t, ataInstructions)
})
}

func checkIfATAExists(t *testing.T, rpcClient *rpc.Client, ataAddress solana.PublicKey) bool {
_, err := rpcClient.GetAccountInfo(tests.Context(t), ataAddress)
return err == nil
}
90 changes: 87 additions & 3 deletions pkg/solana/chainwriter/chain_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package chainwriter
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/big"
"strings"

"github.com/gagliardetto/solana-go"
addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table"
"github.com/gagliardetto/solana-go/rpc"

"github.com/smartcontractkit/chainlink-ccip/chains/solana/utils/tokens"
commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/services"
Expand Down Expand Up @@ -55,6 +58,7 @@ type MethodConfig struct {
FromAddress string
InputModifications commoncodec.ModifiersConfig
ChainSpecificName string
ATAs []ATALookup
LookupTables LookupTables
Accounts []Lookup
// Location in the args where the debug ID is stored
Expand Down Expand Up @@ -214,6 +218,78 @@ func (s *SolanaChainWriterService) FilterLookupTableAddresses(
return filteredLookupTables
}

// CreateATAs first checks if a specified location exists, then checks if the accounts derived from the
// ATALookups in the ChainWriter's configuration exist on-chain and creates them if they do not.
func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader, idl string, feePayer solana.PublicKey) ([]solana.Instruction, error) {
createATAInstructions := []solana.Instruction{}
for _, lookup := range lookups {
// Check if location exists
if lookup.Location != "" {
// TODO refactor GetValuesAtLocation to not return an error if the field doesn't exist
_, err := GetValuesAtLocation(args, lookup.Location)
if err != nil {
// field doesn't exist, so ignore ATA creation
if errors.Is(err, errFieldNotFound) {
continue
}
return nil, fmt.Errorf("error getting values at location: %w", err)
}
}
walletAddresses, err := GetAddresses(ctx, args, []Lookup{lookup.WalletAddress}, derivedTableMap, reader, idl)
if err != nil {
return nil, fmt.Errorf("error resolving wallet address: %w", err)
}
if len(walletAddresses) != 1 {
return nil, fmt.Errorf("expected exactly one wallet address, got %d", len(walletAddresses))
}
wallet := walletAddresses[0].PublicKey

tokenPrograms, err := GetAddresses(ctx, args, []Lookup{lookup.TokenProgram}, derivedTableMap, reader, idl)
if err != nil {
return nil, fmt.Errorf("error resolving token program address: %w", err)
}

mints, err := GetAddresses(ctx, args, []Lookup{lookup.MintAddress}, derivedTableMap, reader, idl)
if err != nil {
return nil, fmt.Errorf("error resolving mint address: %w", err)
}

// Not sure if
if len(tokenPrograms) != len(mints) {
return nil, fmt.Errorf("expected equal number of token programs and mints, got %d tokenPrograms and %d mints", len(tokenPrograms), len(mints))
}

for i := range tokenPrograms {
tokenProgram := tokenPrograms[i].PublicKey
mint := mints[i].PublicKey

ataAddress, _, err := tokens.FindAssociatedTokenAddress(tokenProgram, mint, wallet)
if err != nil {
return nil, fmt.Errorf("error deriving ATA: %w", err)
}

_, err = reader.GetAccountInfoWithOpts(ctx, ataAddress, &rpc.GetAccountInfoOpts{
Encoding: "base64",
Commitment: rpc.CommitmentFinalized,
})
if err == nil {
continue
}
if !strings.Contains(err.Error(), "not found") {
return nil, fmt.Errorf("error reading account info for ATA: %w", err)
}

ins, _, err := tokens.CreateAssociatedTokenAccount(tokenProgram, mint, wallet, feePayer)
if err != nil {
return nil, fmt.Errorf("error creating associated token account: %w", err)
}
createATAInstructions = append(createATAInstructions, ins)
}
}

return createATAInstructions, nil
}

// SubmitTransaction builds, encodes, and enqueues a transaction using the provided program
// configuration and method details. It relies on the configured IDL, account lookups, and
// lookup tables to gather the necessary accounts and data. The function retrieves the latest
Expand Down Expand Up @@ -274,6 +350,11 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID)
}

createATAinstructions, err := CreateATAs(ctx, args, methodConfig.ATAs, derivedTableMap, s.reader, programConfig.IDL, feePayer)
if err != nil {
return errorWithDebugID(fmt.Errorf("error resolving account addresses: %w", err), debugID)
}

// Filter the lookup table addresses based on which accounts are actually used
filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap)

Expand Down Expand Up @@ -310,10 +391,13 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)

// Combine the two sets of instructions into one slice
var instructions []solana.Instruction
instructions = append(instructions, createATAinstructions...)
instructions = append(instructions, solana.NewInstruction(programID, accounts, encodedPayload))

tx, err := solana.NewTransaction(
[]solana.Instruction{
solana.NewInstruction(programID, accounts, encodedPayload),
},
instructions,
blockhash.Value.Blockhash,
solana.TransactionPayer(feePayer),
solana.TransactionAddressTables(filteredLookupTableMap),
Expand Down
Loading

0 comments on commit ed183d4

Please sign in to comment.