From e2a9566d39f375b3420da28477e8bb65396d408f Mon Sep 17 00:00:00 2001 From: Juan Farber Date: Fri, 10 Jan 2025 11:25:50 -0300 Subject: [PATCH] [NONEVM-984][solana] - Reorg Detection + lighter rpc call (#951) * refactor so txm owns blockhash assignment * lastValidBlockHeight shouldn't be exported * better comment * refactor sendWithRetry to make it clearer * confirm loop refactor * fix infinite loop * move accountID inside msg * lint fix * base58 does not contain lower l * fix hash errors * fix generate random hash * remove blockhash as we only need block height * expired tx changes without tests * add maybe to mocks * expiration tests * send txes through queue * revert pendingtx leakage of information. overwrite blockhash * fix order of confirm loop and not found signature check * fix mocks * prevent confirmation loop to mark tx as errored when it needs to be rebroadcasted * fix test * fix pointer * add comments * reduce rpc calls + refactors * tests + check to save rpc calls * address feedback + remove redundant impl * iface comment * address feedback on compute unit limit and lastValidBlockHeight assignment * blockhash assignment inside txm.sendWithRetry * address feedback * Merge branch 'develop' into nonevm-706-support-custom-bumping-strategy-rpc-expiration-within-confirmation * refactors after merge * fix interactive rebase * fix whitespace diffs * fix import * fix mocks * add on prebroadcaste error * remove rebroadcast count and fix package * improve docs * track status on each signature to detect reorgs * move things arround + add reorg detection * linting errors * fix some state tracking instances * remove redundant sig update * move state from txes to sigs * fix listAllExpiredBroadcastedTxs * handle reorg after confirm cycle * associate sigs to retry ctx * remove unused ctx * add errored state and remove finalized * comment * Revert "comment" This reverts commit 6bc0c62da48d164e3e13e85edc7116d0f5025b4a. * Revert "remove unused ctx" This reverts commit 2902ec0b9d3450408d3117ce2aa16cbcfcc1195f. * Revert "associate sigs to retry ctx" This reverts commit 8c18891f3d21b922733bcede809fc216f1aba58a. * Revert "fix listAllExpiredBroadcastedTxs" This reverts commit f4c6069a6818bba90c73aed28b07909aac4859ea. * Revert "move state from txes to sigs" This reverts commit 3a6e643c35b0a31332713099dded6b811226f692. * fix tx state * address feedback * fix ci * fix lint * handle multiple sigs case * improve comment * improve logic and comments * fix comparison against blockHeight instead of slotHeight * address feedback * fix lint * fix log * address feedback * remove useless slot height * address feedback * add comment * tests and fix some bugs * address feedback * address feedback * validate that tx doesn't exist in any of maps when adding new tx * get height instead of whole block optimization * fix mocks on expiration * fix test * rebroadcast with new blockhash + add integration tests * fix integration tests * remove unused params and better comments * handle reorg equally for processed and confirmed at a sig level * add comments and rename txHasReorg to IsTxReorged for better readability * change test name to solve github CI failing check * fix ci * fix tests removing parallel * fix integration tests * capture range var * address feedback * address feedback --- pkg/solana/client/client.go | 15 + pkg/solana/client/client_test.go | 32 ++ pkg/solana/client/mocks/reader_writer.go | 56 +++ pkg/solana/config/config.go | 2 +- pkg/solana/txm/pendingtx.go | 257 ++++++++----- pkg/solana/txm/pendingtx_test.go | 448 ++++++++++++++++++----- pkg/solana/txm/txm.go | 177 ++++++--- pkg/solana/txm/txm_integration_test.go | 202 +++++++++- pkg/solana/txm/txm_internal_test.go | 242 ++++++++++-- pkg/solana/txm/utils/utils.go | 4 +- 10 files changed, 1143 insertions(+), 292 deletions(-) diff --git a/pkg/solana/client/client.go b/pkg/solana/client/client.go index a015fdc1f..5eaa37b89 100644 --- a/pkg/solana/client/client.go +++ b/pkg/solana/client/client.go @@ -36,6 +36,8 @@ type Reader interface { ChainID(ctx context.Context) (mn.StringID, error) GetFeeForMessage(ctx context.Context, msg string) (uint64, error) GetLatestBlock(ctx context.Context) (*rpc.GetBlockResult, error) + // GetLatestBlockHeight returns the latest block height of the node based on the configured commitment type + GetLatestBlockHeight(ctx context.Context) (uint64, error) GetTransaction(ctx context.Context, txHash solana.Signature, opts *rpc.GetTransactionOpts) (*rpc.GetTransactionResult, error) GetBlocks(ctx context.Context, startSlot uint64, endSlot *uint64) (rpc.BlocksResult, error) GetBlocksWithLimit(ctx context.Context, startSlot uint64, limit uint64) (*rpc.BlocksResult, error) @@ -331,6 +333,19 @@ func (c *Client) GetLatestBlock(ctx context.Context) (*rpc.GetBlockResult, error return v.(*rpc.GetBlockResult), err } +// GetLatestBlockHeight returns the latest block height of the node based on the configured commitment type +func (c *Client) GetLatestBlockHeight(ctx context.Context) (uint64, error) { + done := c.latency("latest_block_height") + defer done() + ctx, cancel := context.WithTimeout(ctx, c.txTimeout) + defer cancel() + + v, err, _ := c.requestGroup.Do("GetBlockHeight", func() (interface{}, error) { + return c.rpc.GetBlockHeight(ctx, c.commitment) + }) + return v.(uint64), err +} + func (c *Client) GetBlock(ctx context.Context, slot uint64) (*rpc.GetBlockResult, error) { // get block based on slot done := c.latency("get_block") diff --git a/pkg/solana/client/client_test.go b/pkg/solana/client/client_test.go index 8149b0839..6ca3c1727 100644 --- a/pkg/solana/client/client_test.go +++ b/pkg/solana/client/client_test.go @@ -125,6 +125,12 @@ func TestClient_Reader_Integration(t *testing.T) { assert.GreaterOrEqual(t, slot, startSlot) assert.LessOrEqual(t, slot, slot0) } + + // GetLatestBlockHeight + // Test fetching the latest block height + blockHeight, err := c.GetLatestBlockHeight(ctx) + require.NoError(t, err) + require.Greater(t, blockHeight, uint64(0), "Block height should be greater than 0") } func TestClient_Reader_ChainID(t *testing.T) { @@ -288,6 +294,32 @@ func TestClient_GetBlocks(t *testing.T) { requestTimeout, 500*time.Millisecond) } +func TestClient_GetLatestBlockHeight(t *testing.T) { + t.Parallel() + + ctx := tests.Context(t) + url := SetupLocalSolNode(t) + requestTimeout := 5 * time.Second + lggr := logger.Test(t) + cfg := config.NewDefault() + + // Initialize the client + c, err := NewClient(url, cfg, requestTimeout, lggr) + require.NoError(t, err) + + // Get the latest block height + blockHeight, err := c.GetLatestBlockHeight(ctx) + require.NoError(t, err) + require.Greater(t, blockHeight, uint64(0), "Block height should be greater than 0") + + // Wait until the block height increases + require.Eventually(t, func() bool { + newBlockHeight, err := c.GetLatestBlockHeight(ctx) + require.NoError(t, err) + return newBlockHeight > blockHeight + }, 10*time.Second, 1*time.Second, "Block height should eventually increase") +} + func TestClient_SendTxDuplicates_Integration(t *testing.T) { ctx := tests.Context(t) // set up environment diff --git a/pkg/solana/client/mocks/reader_writer.go b/pkg/solana/client/mocks/reader_writer.go index c64a4a9ad..7c72ca183 100644 --- a/pkg/solana/client/mocks/reader_writer.go +++ b/pkg/solana/client/mocks/reader_writer.go @@ -492,6 +492,62 @@ func (_c *ReaderWriter_GetLatestBlock_Call) RunAndReturn(run func(context.Contex return _c } +// GetLatestBlockHeight provides a mock function with given fields: ctx +func (_m *ReaderWriter) GetLatestBlockHeight(ctx context.Context) (uint64, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetLatestBlockHeight") + } + + var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (uint64, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) uint64); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReaderWriter_GetLatestBlockHeight_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestBlockHeight' +type ReaderWriter_GetLatestBlockHeight_Call struct { + *mock.Call +} + +// GetLatestBlockHeight is a helper method to define mock.On call +// - ctx context.Context +func (_e *ReaderWriter_Expecter) GetLatestBlockHeight(ctx interface{}) *ReaderWriter_GetLatestBlockHeight_Call { + return &ReaderWriter_GetLatestBlockHeight_Call{Call: _e.mock.On("GetLatestBlockHeight", ctx)} +} + +func (_c *ReaderWriter_GetLatestBlockHeight_Call) Run(run func(ctx context.Context)) *ReaderWriter_GetLatestBlockHeight_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *ReaderWriter_GetLatestBlockHeight_Call) Return(_a0 uint64, _a1 error) *ReaderWriter_GetLatestBlockHeight_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ReaderWriter_GetLatestBlockHeight_Call) RunAndReturn(run func(context.Context) (uint64, error)) *ReaderWriter_GetLatestBlockHeight_Call { + _c.Call.Return(run) + return _c +} + // GetSignaturesForAddressWithOpts provides a mock function with given fields: ctx, addr, opts func (_m *ReaderWriter) GetSignaturesForAddressWithOpts(ctx context.Context, addr solana.PublicKey, opts *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error) { ret := _m.Called(ctx, addr, opts) diff --git a/pkg/solana/config/config.go b/pkg/solana/config/config.go index 202b8fca8..7c5648dab 100644 --- a/pkg/solana/config/config.go +++ b/pkg/solana/config/config.go @@ -25,7 +25,7 @@ var defaultConfigSet = Chain{ MaxRetries: ptr(int64(0)), // max number of retries (default = 0). when config.MaxRetries < 0), interpreted as MaxRetries = nil and rpc node will do a reasonable number of retries // fee estimator - FeeEstimatorMode: ptr("fixed"), + FeeEstimatorMode: ptr("fixed"), // "fixed" or "blockhistory" ComputeUnitPriceMax: ptr(uint64(1_000)), ComputeUnitPriceMin: ptr(uint64(0)), ComputeUnitPriceDefault: ptr(uint64(0)), diff --git a/pkg/solana/txm/pendingtx.go b/pkg/solana/txm/pendingtx.go index 7784c47cd..181704a21 100644 --- a/pkg/solana/txm/pendingtx.go +++ b/pkg/solana/txm/pendingtx.go @@ -23,10 +23,11 @@ var ( ) type PendingTxContext interface { - // New adds a new tranasction in Broadcasted state to the storage - New(msg pendingTx, sig solana.Signature, cancel context.CancelFunc) error - // AddSignature adds a new signature for an existing transaction ID - AddSignature(id string, sig solana.Signature) error + // New adds a new transaction in Broadcasted state to the storage + New(msg pendingTx) error + // AddSignature adds a new signature to a broadcasted transaction in the pending transaction context. + // It associates the provided context and cancel function with the signature to manage retry and bumping cycles. + AddSignature(cancel context.CancelFunc, id string, sig solana.Signature) error // Remove removes transaction, context and related signatures from storage associated to given tx id if not in finalized or errored state Remove(id string) (string, error) // ListAllSigs returns all of the signatures being tracked for all transactions not yet finalized or errored @@ -50,6 +51,14 @@ type PendingTxContext interface { GetTxState(id string) (utils.TxState, error) // TrimFinalizedErroredTxs removes transactions that have reached their retention time TrimFinalizedErroredTxs() int + // IsTxReorged determines whether the given signature has experienced a re-org by comparing its in-memory state with its current on-chain state. + // A re-org is identified when the state of a signature regresses as follows: + // - Confirmed -> Processed || Broadcasted || Not Found + // - Processed -> Broadcasted || Not Found + // The function returns the transaction ID associated with the signature and a boolean indicating whether a re-org has occurred. + IsTxReorged(sig solana.Signature, currentState txmutils.TxState) (string, bool) + // GetPendingTx returns the pendingTx for the given ID if it exists + GetPendingTx(id string) (pendingTx, error) } // finishedTx is used to store info required to track transactions to finality or error @@ -66,14 +75,21 @@ type pendingTx struct { // finishedTx is used to store minimal info specifically for finalized or errored transactions for external status checks type finishedTx struct { retentionTs time.Time - state utils.TxState + state txmutils.TxState +} + +type txInfo struct { + // id of the transaction + id string + // state of the signature + state txmutils.TxState } var _ PendingTxContext = &pendingTxContext{} type pendingTxContext struct { - cancelBy map[string]context.CancelFunc - sigToID map[solana.Signature]string + cancelBy map[string]context.CancelFunc + sigToTxInfo map[solana.Signature]txInfo broadcastedProcessedTxs map[string]pendingTx // broadcasted and processed transactions that may require retry and bumping confirmedTxs map[string]pendingTx // transactions that require monitoring for re-org @@ -84,8 +100,8 @@ type pendingTxContext struct { func newPendingTxContext() *pendingTxContext { return &pendingTxContext{ - cancelBy: map[string]context.CancelFunc{}, - sigToID: map[solana.Signature]string{}, + cancelBy: map[string]context.CancelFunc{}, + sigToTxInfo: map[solana.Signature]txInfo{}, broadcastedProcessedTxs: map[string]pendingTx{}, confirmedTxs: map[string]pendingTx{}, @@ -93,12 +109,8 @@ func newPendingTxContext() *pendingTxContext { } } -func (c *pendingTxContext) New(tx pendingTx, sig solana.Signature, cancel context.CancelFunc) error { +func (c *pendingTxContext) New(tx pendingTx) error { err := c.withReadLock(func() error { - // validate signature does not exist - if _, exists := c.sigToID[sig]; exists { - return ErrSigAlreadyExists - } // Check if ID already exists in any of the maps if _, exists := c.broadcastedProcessedTxs[tx.id]; exists { return ErrIDAlreadyExists @@ -115,11 +127,8 @@ func (c *pendingTxContext) New(tx pendingTx, sig solana.Signature, cancel contex return err } - // upgrade to write lock if sig or id do not exist + // upgrade to write lock if id does not exist _, err = c.withWriteLock(func() (string, error) { - if _, exists := c.sigToID[sig]; exists { - return "", ErrSigAlreadyExists - } // Check if ID already exists in any of the maps if _, exists := c.broadcastedProcessedTxs[tx.id]; exists { return "", ErrIDAlreadyExists @@ -130,11 +139,7 @@ func (c *pendingTxContext) New(tx pendingTx, sig solana.Signature, cancel contex if _, exists := c.finalizedErroredTxs[tx.id]; exists { return "", ErrIDAlreadyExists } - // save cancel func - c.cancelBy[tx.id] = cancel - c.sigToID[sig] = tx.id - // add signature to tx - tx.signatures = append(tx.signatures, sig) + tx.signatures = []solana.Signature{} tx.createTs = time.Now() tx.state = utils.Broadcasted // save to the broadcasted map since transaction was just broadcasted @@ -144,10 +149,10 @@ func (c *pendingTxContext) New(tx pendingTx, sig solana.Signature, cancel contex return err } -func (c *pendingTxContext) AddSignature(id string, sig solana.Signature) error { +func (c *pendingTxContext) AddSignature(cancel context.CancelFunc, id string, sig solana.Signature) error { err := c.withReadLock(func() error { // signature already exists - if _, exists := c.sigToID[sig]; exists { + if _, exists := c.sigToTxInfo[sig]; exists { return ErrSigAlreadyExists } // new signatures should only be added for broadcasted transactions @@ -163,18 +168,24 @@ func (c *pendingTxContext) AddSignature(id string, sig solana.Signature) error { // upgrade to write lock if sig does not exist _, err = c.withWriteLock(func() (string, error) { - if _, exists := c.sigToID[sig]; exists { + if _, exists := c.sigToTxInfo[sig]; exists { return "", ErrSigAlreadyExists } if _, exists := c.broadcastedProcessedTxs[id]; !exists { return "", ErrTransactionNotFound } - c.sigToID[sig] = id + c.sigToTxInfo[sig] = txInfo{id: id, state: txmutils.Broadcasted} tx := c.broadcastedProcessedTxs[id] // save new signature tx.signatures = append(tx.signatures, sig) // save updated tx to broadcasted map c.broadcastedProcessedTxs[id] = tx + // set cancel context if not already set to handle reorgs when regressing from confirmed state + // previous context was removed so we associate a new context to our transaction to restart the retry/bumping cycle + if _, exists := c.cancelBy[id]; !exists { + c.cancelBy[id] = cancel + } + return "", nil }) return err @@ -186,7 +197,7 @@ func (c *pendingTxContext) Remove(id string) (string, error) { err := c.withReadLock(func() error { _, broadcastedIDExists := c.broadcastedProcessedTxs[id] _, confirmedIDExists := c.confirmedTxs[id] - // transcation does not exist in tx maps + // transaction does not exist in tx maps if !broadcastedIDExists && !confirmedIDExists { return ErrTransactionNotFound } @@ -216,7 +227,7 @@ func (c *pendingTxContext) Remove(id string) (string, error) { // remove all signatures associated with transaction from sig map for _, s := range tx.signatures { - delete(c.sigToID, s) + delete(c.sigToTxInfo, s) } return id, nil }) @@ -225,7 +236,7 @@ func (c *pendingTxContext) Remove(id string) (string, error) { func (c *pendingTxContext) ListAllSigs() []solana.Signature { c.lock.RLock() defer c.lock.RUnlock() - return maps.Keys(c.sigToID) + return maps.Keys(c.sigToTxInfo) } // ListAllExpiredBroadcastedTxs returns all the txes that are in broadcasted state and have expired for given block number compared against lastValidBlockHeight (last valid block number) @@ -250,14 +261,14 @@ func (c *pendingTxContext) Expired(sig solana.Signature, confirmationTimeout tim if confirmationTimeout == 0 { return false } - id, exists := c.sigToID[sig] + info, exists := c.sigToTxInfo[sig] if !exists { return false // return expired = false if timestamp does not exist (likely cleaned up by something else previously) } - if tx, exists := c.broadcastedProcessedTxs[id]; exists { + if tx, exists := c.broadcastedProcessedTxs[info.id]; exists { return time.Since(tx.createTs) > confirmationTimeout } - if tx, exists := c.confirmedTxs[id]; exists { + if tx, exists := c.confirmedTxs[info.id]; exists { return time.Since(tx.createTs) > confirmationTimeout } return false // return expired = false if tx does not exist (likely cleaned up by something else previously) @@ -266,12 +277,12 @@ func (c *pendingTxContext) Expired(sig solana.Signature, confirmationTimeout tim func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { err := c.withReadLock(func() error { // validate if sig exists - id, sigExists := c.sigToID[sig] + info, sigExists := c.sigToTxInfo[sig] if !sigExists { return ErrSigDoesNotExist } // Transactions should only move to processed from broadcasted - tx, exists := c.broadcastedProcessedTxs[id] + tx, exists := c.broadcastedProcessedTxs[info.id] if !exists { return ErrTransactionNotFound } @@ -287,35 +298,36 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { // upgrade to write lock if sig and id exist return c.withWriteLock(func() (string, error) { - id, sigExists := c.sigToID[sig] + info, sigExists := c.sigToTxInfo[sig] if !sigExists { - return id, ErrSigDoesNotExist + return info.id, ErrSigDoesNotExist } - tx, exists := c.broadcastedProcessedTxs[id] + tx, exists := c.broadcastedProcessedTxs[info.id] if !exists { - return id, ErrTransactionNotFound - } - // update tx state to Processed - tx.state = utils.Processed - // save updated tx back to the broadcasted map - c.broadcastedProcessedTxs[id] = tx - return id, nil + return info.id, ErrTransactionNotFound + } + // update sig and tx to Processed + info.state, tx.state = txmutils.Processed, txmutils.Processed + // save updated sig and tx back to the maps + c.sigToTxInfo[sig] = info + c.broadcastedProcessedTxs[info.id] = tx + return info.id, nil }) } func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { err := c.withReadLock(func() error { // validate if sig exists - id, sigExists := c.sigToID[sig] + info, sigExists := c.sigToTxInfo[sig] if !sigExists { return ErrSigDoesNotExist } // Check if transaction already in confirmed state - if tx, exists := c.confirmedTxs[id]; exists && tx.state == utils.Confirmed { + if tx, exists := c.confirmedTxs[info.id]; exists && tx.state == txmutils.Confirmed { return ErrAlreadyInExpectedState } // Transactions should only move to confirmed from broadcasted/processed - if _, exists := c.broadcastedProcessedTxs[id]; !exists { + if _, exists := c.broadcastedProcessedTxs[info.id]; !exists { return ErrTransactionNotFound } return nil @@ -326,38 +338,39 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { // upgrade to write lock if id exists return c.withWriteLock(func() (string, error) { - id, sigExists := c.sigToID[sig] + info, sigExists := c.sigToTxInfo[sig] if !sigExists { - return id, ErrSigDoesNotExist + return info.id, ErrSigDoesNotExist } - tx, exists := c.broadcastedProcessedTxs[id] + tx, exists := c.broadcastedProcessedTxs[info.id] if !exists { - return id, ErrTransactionNotFound + return info.id, ErrTransactionNotFound } // call cancel func + remove from map to stop the retry/bumping cycle for this transaction - if cancel, exists := c.cancelBy[id]; exists { + if cancel, exists := c.cancelBy[info.id]; exists { cancel() // cancel context - delete(c.cancelBy, id) + delete(c.cancelBy, info.id) } - // update tx state to Confirmed - tx.state = utils.Confirmed + // update sig and tx state to Confirmed + info.state, tx.state = txmutils.Confirmed, txmutils.Confirmed + c.sigToTxInfo[sig] = info // move tx to confirmed map - c.confirmedTxs[id] = tx + c.confirmedTxs[info.id] = tx // remove tx from broadcasted map - delete(c.broadcastedProcessedTxs, id) - return id, nil + delete(c.broadcastedProcessedTxs, info.id) + return info.id, nil }) } func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout time.Duration) (string, error) { err := c.withReadLock(func() error { - id, sigExists := c.sigToID[sig] + info, sigExists := c.sigToTxInfo[sig] if !sigExists { return ErrSigDoesNotExist } // Allow transactions to transition from broadcasted, processed, or confirmed state in case there are delays between status checks - _, broadcastedExists := c.broadcastedProcessedTxs[id] - _, confirmedExists := c.confirmedTxs[id] + _, broadcastedExists := c.broadcastedProcessedTxs[info.id] + _, confirmedExists := c.confirmedTxs[info.id] if !broadcastedExists && !confirmedExists { return ErrTransactionNotFound } @@ -369,47 +382,47 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti // upgrade to write lock if id exists return c.withWriteLock(func() (string, error) { - id, exists := c.sigToID[sig] + info, exists := c.sigToTxInfo[sig] if !exists { - return id, ErrSigDoesNotExist + return info.id, ErrSigDoesNotExist } var tx, tempTx pendingTx var broadcastedExists, confirmedExists bool - if tempTx, broadcastedExists = c.broadcastedProcessedTxs[id]; broadcastedExists { + if tempTx, broadcastedExists = c.broadcastedProcessedTxs[info.id]; broadcastedExists { tx = tempTx } - if tempTx, confirmedExists = c.confirmedTxs[id]; confirmedExists { + if tempTx, confirmedExists = c.confirmedTxs[info.id]; confirmedExists { tx = tempTx } if !broadcastedExists && !confirmedExists { - return id, ErrTransactionNotFound + return info.id, ErrTransactionNotFound } // call cancel func + remove from map to stop the retry/bumping cycle for this transaction // cancel is expected to be called and removed when tx is confirmed but checked here too in case state is skipped - if cancel, exists := c.cancelBy[id]; exists { + if cancel, exists := c.cancelBy[info.id]; exists { cancel() // cancel context - delete(c.cancelBy, id) + delete(c.cancelBy, info.id) } // delete from broadcasted map, if exists - delete(c.broadcastedProcessedTxs, id) + delete(c.broadcastedProcessedTxs, info.id) // delete from confirmed map, if exists - delete(c.confirmedTxs, id) - // remove all related signatures from the sigToID map to skip picking up this tx in the confirmation logic + delete(c.confirmedTxs, info.id) + // remove all related signatures from the sigToTxInfo map to skip picking up this tx in the confirmation logic for _, s := range tx.signatures { - delete(c.sigToID, s) + delete(c.sigToTxInfo, s) } // if retention duration is set to 0, delete transaction from storage // otherwise, move to finalized map if retentionTimeout == 0 { - return id, nil + return info.id, nil } finalizedTx := finishedTx{ state: utils.Finalized, retentionTs: time.Now().Add(retentionTimeout), } // move transaction from confirmed to finalized map - c.finalizedErroredTxs[id] = finalizedTx - return id, nil + c.finalizedErroredTxs[info.id] = finalizedTx + return info.id, nil }) } @@ -457,14 +470,14 @@ func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time. func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.Duration, txState utils.TxState, _ TxErrType) (string, error) { err := c.withReadLock(func() error { - id, sigExists := c.sigToID[sig] + info, sigExists := c.sigToTxInfo[sig] if !sigExists { return ErrSigDoesNotExist } // transaction can transition from any non-finalized state var broadcastedExists, confirmedExists bool - _, broadcastedExists = c.broadcastedProcessedTxs[id] - _, confirmedExists = c.confirmedTxs[id] + _, broadcastedExists = c.broadcastedProcessedTxs[info.id] + _, confirmedExists = c.confirmedTxs[info.id] // transcation does not exist in any tx maps if !broadcastedExists && !confirmedExists { return ErrTransactionNotFound @@ -477,16 +490,16 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D // upgrade to write lock if sig exists return c.withWriteLock(func() (string, error) { - id, exists := c.sigToID[sig] + info, exists := c.sigToTxInfo[sig] if !exists { return "", ErrSigDoesNotExist } var tx, tempTx pendingTx var broadcastedExists, confirmedExists bool - if tempTx, broadcastedExists = c.broadcastedProcessedTxs[id]; broadcastedExists { + if tempTx, broadcastedExists = c.broadcastedProcessedTxs[info.id]; broadcastedExists { tx = tempTx } - if tempTx, confirmedExists = c.confirmedTxs[id]; confirmedExists { + if tempTx, confirmedExists = c.confirmedTxs[info.id]; confirmedExists { tx = tempTx } // transcation does not exist in any non-finalized maps @@ -494,29 +507,29 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D return "", ErrTransactionNotFound } // call cancel func + remove from map - if cancel, exists := c.cancelBy[id]; exists { + if cancel, exists := c.cancelBy[info.id]; exists { cancel() // cancel context - delete(c.cancelBy, id) + delete(c.cancelBy, info.id) } // delete from broadcasted map, if exists - delete(c.broadcastedProcessedTxs, id) + delete(c.broadcastedProcessedTxs, info.id) // delete from confirmed map, if exists - delete(c.confirmedTxs, id) - // remove all related signatures from the sigToID map to skip picking up this tx in the confirmation logic + delete(c.confirmedTxs, info.id) + // remove all related signatures from the sigToTxInfo map to skip picking up this tx in the confirmation logic for _, s := range tx.signatures { - delete(c.sigToID, s) + delete(c.sigToTxInfo, s) } // if retention duration is set to 0, skip adding transaction to the errored map if retentionTimeout == 0 { - return id, nil + return info.id, nil } erroredTx := finishedTx{ state: txState, retentionTs: time.Now().Add(retentionTimeout), } // move transaction from broadcasted to error map - c.finalizedErroredTxs[id] = erroredTx - return id, nil + c.finalizedErroredTxs[info.id] = erroredTx + return info.id, nil }) } @@ -563,6 +576,52 @@ func (c *pendingTxContext) TrimFinalizedErroredTxs() int { return len(expiredIDs) } +func (c *pendingTxContext) IsTxReorged(sig solana.Signature, sigOnChainState txmutils.TxState) (string, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + + // Grab in memory state of the signature + txInfo, exists := c.sigToTxInfo[sig] + if !exists { + return "", false + } + + // Compare our in-memory state of the sig with the current on-chain state to determine if the sig had a regression + sigInMemoryState := txInfo.state + var hasReorg bool + switch sigInMemoryState { + case txmutils.Confirmed: + if sigOnChainState == txmutils.Processed || sigOnChainState == txmutils.Broadcasted || sigOnChainState == txmutils.NotFound { + hasReorg = true + } + case txmutils.Processed: + if sigOnChainState == txmutils.Broadcasted || sigOnChainState == txmutils.NotFound { + hasReorg = true + } + default: // No reorg if the signature is not in a state that can be reorged + } + + return txInfo.id, hasReorg +} + +func (c *pendingTxContext) GetPendingTx(id string) (pendingTx, error) { + c.lock.RLock() + defer c.lock.RUnlock() + var tx, tempTx pendingTx + var broadcastedExists, confirmedExists bool + if tempTx, broadcastedExists = c.broadcastedProcessedTxs[id]; broadcastedExists { + tx = tempTx + } + if tempTx, confirmedExists = c.confirmedTxs[id]; confirmedExists { + tx = tempTx + } + + if !broadcastedExists && !confirmedExists { + return pendingTx{}, ErrTransactionNotFound + } + return tx, nil +} + func (c *pendingTxContext) withReadLock(fn func() error) error { c.lock.RLock() defer c.lock.RUnlock() @@ -600,12 +659,12 @@ func newPendingTxContextWithProm(id string) *pendingTxContextWithProm { } } -func (c *pendingTxContextWithProm) New(msg pendingTx, sig solana.Signature, cancel context.CancelFunc) error { - return c.pendingTx.New(msg, sig, cancel) +func (c *pendingTxContextWithProm) New(msg pendingTx) error { + return c.pendingTx.New(msg) } -func (c *pendingTxContextWithProm) AddSignature(id string, sig solana.Signature) error { - return c.pendingTx.AddSignature(id, sig) +func (c *pendingTxContextWithProm) AddSignature(cancel context.CancelFunc, id string, sig solana.Signature) error { + return c.pendingTx.AddSignature(cancel, id, sig) } func (c *pendingTxContextWithProm) OnProcessed(sig solana.Signature) (string, error) { @@ -689,3 +748,11 @@ func (c *pendingTxContextWithProm) GetTxState(id string) (utils.TxState, error) func (c *pendingTxContextWithProm) TrimFinalizedErroredTxs() int { return c.pendingTx.TrimFinalizedErroredTxs() } + +func (c *pendingTxContextWithProm) IsTxReorged(sig solana.Signature, currentSigState txmutils.TxState) (string, bool) { + return c.pendingTx.IsTxReorged(sig, currentSigState) +} + +func (c *pendingTxContextWithProm) GetPendingTx(id string) (pendingTx, error) { + return c.pendingTx.GetPendingTx(id) +} diff --git a/pkg/solana/txm/pendingtx_test.go b/pkg/solana/txm/pendingtx_test.go index 183f00c8e..9b8a050cc 100644 --- a/pkg/solana/txm/pendingtx_test.go +++ b/pkg/solana/txm/pendingtx_test.go @@ -15,6 +15,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) func TestPendingTxContext_add_remove_multiple(t *testing.T) { @@ -42,17 +43,19 @@ func TestPendingTxContext_add_remove_multiple(t *testing.T) { for i := 0; i < n; i++ { sig, cancel := newProcess() msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + assert.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) assert.NoError(t, err) ids[sig] = msg.id } // cannot add signature for non existent ID - require.Error(t, txs.AddSignature(uuid.New().String(), solana.Signature{})) + require.Error(t, txs.AddSignature(func() {}, uuid.New().String(), solana.Signature{})) list := make([]string, 0, n) - for _, id := range txs.sigToID { - list = append(list, id) + for _, info := range txs.sigToTxInfo { + list = append(list, info.id) } assert.Equal(t, n, len(list)) @@ -79,13 +82,15 @@ func TestPendingTxContext_new(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err, "expected no error when adding a new transaction") // Check it exists in signature map and mapped to the correct txID - id, exists := txs.sigToID[sig] + txInfo, exists := txs.sigToTxInfo[sig] require.True(t, exists, "signature should exist in sigToID map") - require.Equal(t, msg.id, id, "signature should map to correct transaction ID") + require.Equal(t, msg.id, txInfo.id, "signature should map to correct transaction ID") // Check it exists in broadcasted map and that sigs match tx, exists := txs.broadcastedProcessedTxs[msg.id] @@ -102,32 +107,24 @@ func TestPendingTxContext_new(t *testing.T) { _, exists = txs.finalizedErroredTxs[msg.id] require.False(t, exists, "transaction should not exist in finalizedErroredTxs map") - // Attempt to add the same transaction again with the same signature - err = txs.New(msg, sig, cancel) - require.ErrorIs(t, err, ErrSigAlreadyExists, "expected ErrSigAlreadyExists when adding duplicate signature") - - // Attempt to add a new transaction with the same transaction ID but different signature - err = txs.New(pendingTx{id: msg.id}, randomSignature(t), cancel) + // Attempt to add the same transaction again + err = txs.New(msg) require.ErrorIs(t, err, ErrIDAlreadyExists, "expected ErrIDAlreadyExists when adding duplicate transaction ID") - // Attempt to add a new transaction with a different transaction ID but same signature - err = txs.New(pendingTx{id: uuid.NewString()}, sig, cancel) - require.ErrorIs(t, err, ErrSigAlreadyExists, "expected ErrSigAlreadyExists when adding duplicate signature") - // Simulate moving the transaction to confirmedTxs map _, err = txs.OnConfirmed(sig) require.NoError(t, err, "expected no error when confirming transaction") - // Attempt to add a new transaction with the same ID (now in confirmedTxs) and new signature - err = txs.New(pendingTx{id: msg.id}, randomSignature(t), cancel) + // Attempt to add a new transaction with the same ID (now in confirmedTxs) + err = txs.New(pendingTx{id: msg.id}) require.ErrorIs(t, err, ErrIDAlreadyExists, "expected ErrIDAlreadyExists when adding transaction ID that exists in confirmedTxs") // Simulate moving the transaction to finalizedErroredTxs map _, err = txs.OnFinalized(sig, 10*time.Second) require.NoError(t, err, "expected no error when finalizing transaction") - // Attempt to add a new transaction with the same ID (now in finalizedErroredTxs) and new signature - err = txs.New(pendingTx{id: msg.id}, randomSignature(t), cancel) + // Attempt to add a new transaction with the same ID (now in finalizedErroredTxs) + err = txs.New(pendingTx{id: msg.id}) require.ErrorIs(t, err, ErrIDAlreadyExists, "expected ErrIDAlreadyExists when adding transaction ID that exists in finalizedErroredTxs") } @@ -142,19 +139,21 @@ func TestPendingTxContext_add_signature(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig1, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig1) require.NoError(t, err) - err = txs.AddSignature(msg.id, sig2) + err = txs.AddSignature(cancel, msg.id, sig2) require.NoError(t, err) // Check signature map - id, exists := txs.sigToID[sig1] + txInfo, exists := txs.sigToTxInfo[sig1] require.True(t, exists) - require.Equal(t, msg.id, id) - id, exists = txs.sigToID[sig2] + require.Equal(t, msg.id, txInfo.id) + txInfo, exists = txs.sigToTxInfo[sig2] require.True(t, exists) - require.Equal(t, msg.id, id) + require.Equal(t, msg.id, txInfo.id) // Check broadcasted map tx, exists := txs.broadcastedProcessedTxs[msg.id] @@ -177,10 +176,12 @@ func TestPendingTxContext_add_signature(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) - err = txs.AddSignature(msg.id, sig) + err = txs.AddSignature(cancel, msg.id, sig) require.ErrorIs(t, err, ErrSigAlreadyExists) }) @@ -190,10 +191,12 @@ func TestPendingTxContext_add_signature(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig1, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig1) require.NoError(t, err) - err = txs.AddSignature("bad id", sig2) + err = txs.AddSignature(cancel, "bad id", sig2) require.ErrorIs(t, err, ErrTransactionNotFound) }) @@ -203,7 +206,9 @@ func TestPendingTxContext_add_signature(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig1, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig1) require.NoError(t, err) // Transition to processed state @@ -216,7 +221,7 @@ func TestPendingTxContext_add_signature(t *testing.T) { require.NoError(t, err) require.Equal(t, msg.id, id) - err = txs.AddSignature(msg.id, sig2) + err = txs.AddSignature(cancel, msg.id, sig2) require.ErrorIs(t, err, ErrTransactionNotFound) }) } @@ -232,7 +237,9 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to processed state @@ -241,9 +248,9 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { require.Equal(t, msg.id, id) // Check it exists in signature map - id, exists := txs.sigToID[sig] + txInfo, exists := txs.sigToTxInfo[sig] require.True(t, exists) - require.Equal(t, msg.id, id) + require.Equal(t, msg.id, txInfo.id) // Check it exists in broadcasted map tx, exists := txs.broadcastedProcessedTxs[msg.id] @@ -268,7 +275,9 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to processed state @@ -291,7 +300,9 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to processed state @@ -319,7 +330,9 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to errored state @@ -337,7 +350,9 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to processed state @@ -362,7 +377,9 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to processed state @@ -376,9 +393,9 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { require.Equal(t, msg.id, id) // Check it exists in signature map - id, exists := txs.sigToID[sig] + txInfo, exists := txs.sigToTxInfo[sig] require.True(t, exists) - require.Equal(t, msg.id, id) + require.Equal(t, msg.id, txInfo.id) // Check it does not exist in broadcasted map _, exists = txs.broadcastedProcessedTxs[msg.id] @@ -403,7 +420,9 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to processed state @@ -431,7 +450,9 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to errored state @@ -449,7 +470,9 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to processed state @@ -480,11 +503,13 @@ func TestPendingTxContext_on_finalized(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig1, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig1) require.NoError(t, err) // Add second signature - err = txs.AddSignature(msg.id, sig2) + err = txs.AddSignature(cancel, msg.id, sig2) require.NoError(t, err) // Transition to finalized state @@ -508,9 +533,9 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.Equal(t, utils.Finalized, tx.state) // Check sigs do no exist in signature map - _, exists = txs.sigToID[sig1] + _, exists = txs.sigToTxInfo[sig1] require.False(t, exists) - _, exists = txs.sigToID[sig2] + _, exists = txs.sigToTxInfo[sig2] require.False(t, exists) }) @@ -520,11 +545,13 @@ func TestPendingTxContext_on_finalized(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig1, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig1) require.NoError(t, err) // Add second signature - err = txs.AddSignature(msg.id, sig2) + err = txs.AddSignature(cancel, msg.id, sig2) require.NoError(t, err) // Transition to processed state @@ -558,9 +585,9 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.Equal(t, utils.Finalized, tx.state) // Check sigs do no exist in signature map - _, exists = txs.sigToID[sig1] + _, exists = txs.sigToTxInfo[sig1] require.False(t, exists) - _, exists = txs.sigToID[sig2] + _, exists = txs.sigToTxInfo[sig2] require.False(t, exists) }) @@ -569,7 +596,9 @@ func TestPendingTxContext_on_finalized(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig1, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig1) require.NoError(t, err) // Transition to processed state @@ -600,7 +629,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.False(t, exists) // Check sigs do no exist in signature map - _, exists = txs.sigToID[sig1] + _, exists = txs.sigToTxInfo[sig1] require.False(t, exists) }) @@ -609,7 +638,9 @@ func TestPendingTxContext_on_finalized(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to errored state @@ -634,7 +665,9 @@ func TestPendingTxContext_on_error(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to errored state @@ -658,7 +691,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, utils.Errored, tx.state) // Check sigs do no exist in signature map - _, exists = txs.sigToID[sig] + _, exists = txs.sigToTxInfo[sig] require.False(t, exists) }) @@ -667,7 +700,9 @@ func TestPendingTxContext_on_error(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to errored state @@ -696,7 +731,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, utils.Errored, tx.state) // Check sigs do no exist in signature map - _, exists = txs.sigToID[sig] + _, exists = txs.sigToTxInfo[sig] require.False(t, exists) }) @@ -705,7 +740,9 @@ func TestPendingTxContext_on_error(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to fatally errored state @@ -725,7 +762,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, utils.FatallyErrored, tx.state) // Check sigs do no exist in signature map - _, exists = txs.sigToID[sig] + _, exists = txs.sigToTxInfo[sig] require.False(t, exists) }) @@ -734,10 +771,12 @@ func TestPendingTxContext_on_error(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) - // Transition to errored state + // Transition to confirmed state id, err := txs.OnConfirmed(sig) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -760,7 +799,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.False(t, exists) // Check sigs do no exist in signature map - _, exists = txs.sigToID[sig] + _, exists = txs.sigToTxInfo[sig] require.False(t, exists) }) @@ -769,16 +808,18 @@ func TestPendingTxContext_on_error(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) - // Transition to confirmed state + // Transition to finalized state id, err := txs.OnFinalized(sig, retentionTimeout) require.NoError(t, err) require.Equal(t, msg.id, id) - // Transition back to confirmed state - id, err = txs.OnError(sig, retentionTimeout, utils.Errored, 0) + // Transition to errored state + id, err = txs.OnError(sig, retentionTimeout, txmutils.Errored, 0) require.Error(t, err) require.Equal(t, "", id) }) @@ -827,7 +868,9 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} // Add transaction to broadcasted map - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + require.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) require.NoError(t, err) // Transition to errored state @@ -869,14 +912,18 @@ func TestPendingTxContext_remove(t *testing.T) { // Create new broadcasted transaction with extra sig broadcastedMsg := pendingTx{id: broadcastedID} - err := txs.New(broadcastedMsg, broadcastedSig1, cancel) + err := txs.New(broadcastedMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, broadcastedMsg.id, broadcastedSig1) require.NoError(t, err) - err = txs.AddSignature(broadcastedMsg.id, broadcastedSig2) + err = txs.AddSignature(cancel, broadcastedMsg.id, broadcastedSig2) require.NoError(t, err) // Create new processed transaction processedMsg := pendingTx{id: processedID} - err = txs.New(processedMsg, processedSig, cancel) + err = txs.New(processedMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, processedMsg.id, processedSig) require.NoError(t, err) id, err := txs.OnProcessed(processedSig) require.NoError(t, err) @@ -884,7 +931,9 @@ func TestPendingTxContext_remove(t *testing.T) { // Create new confirmed transaction confirmedMsg := pendingTx{id: confirmedID} - err = txs.New(confirmedMsg, confirmedSig, cancel) + err = txs.New(confirmedMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, confirmedMsg.id, confirmedSig) require.NoError(t, err) id, err = txs.OnConfirmed(confirmedSig) require.NoError(t, err) @@ -892,7 +941,9 @@ func TestPendingTxContext_remove(t *testing.T) { // Create new finalized transaction finalizedMsg := pendingTx{id: finalizedID} - err = txs.New(finalizedMsg, finalizedSig, cancel) + err = txs.New(finalizedMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, finalizedMsg.id, finalizedSig) require.NoError(t, err) id, err = txs.OnFinalized(finalizedSig, retentionTimeout) require.NoError(t, err) @@ -900,9 +951,11 @@ func TestPendingTxContext_remove(t *testing.T) { // Create new errored transaction erroredMsg := pendingTx{id: erroredID} - err = txs.New(erroredMsg, erroredSig, cancel) + err = txs.New(erroredMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, erroredMsg.id, erroredSig) require.NoError(t, err) - id, err = txs.OnError(erroredSig, retentionTimeout, utils.Errored, 0) + id, err = txs.OnError(erroredSig, retentionTimeout, txmutils.Errored, 0) require.NoError(t, err) require.Equal(t, erroredMsg.id, id) @@ -910,35 +963,41 @@ func TestPendingTxContext_remove(t *testing.T) { id, err = txs.Remove(broadcastedID) require.NoError(t, err) require.Equal(t, broadcastedMsg.id, id) + // Check removed from broadcasted map _, exists := txs.broadcastedProcessedTxs[broadcastedMsg.id] require.False(t, exists) + // Check all signatures removed from sig map - _, exists = txs.sigToID[broadcastedSig1] + _, exists = txs.sigToTxInfo[broadcastedSig1] require.False(t, exists) - _, exists = txs.sigToID[broadcastedSig2] + _, exists = txs.sigToTxInfo[broadcastedSig2] require.False(t, exists) // Remove processed transaction id, err = txs.Remove(processedID) require.NoError(t, err) require.Equal(t, processedMsg.id, id) + // Check removed from broadcasted map _, exists = txs.broadcastedProcessedTxs[processedMsg.id] require.False(t, exists) + // Check all signatures removed from sig map - _, exists = txs.sigToID[processedSig] + _, exists = txs.sigToTxInfo[processedSig] require.False(t, exists) // Remove confirmed transaction id, err = txs.Remove(confirmedID) require.NoError(t, err) require.Equal(t, confirmedMsg.id, id) + // Check removed from confirmed map _, exists = txs.confirmedTxs[confirmedMsg.id] require.False(t, exists) + // Check all signatures removed from sig map - _, exists = txs.sigToID[confirmedSig] + _, exists = txs.sigToTxInfo[confirmedSig] require.False(t, exists) // Check remove cannot be called on finalized transaction @@ -997,7 +1056,9 @@ func TestPendingTxContext_expired(t *testing.T) { txID := uuid.NewString() msg := pendingTx{id: txID} - err := txs.New(msg, sig, cancel) + err := txs.New(msg) + assert.NoError(t, err) + err = txs.AddSignature(cancel, msg.id, sig) assert.NoError(t, err) msg, exists := txs.broadcastedProcessedTxs[msg.id] @@ -1021,15 +1082,16 @@ func TestPendingTxContext_race(t *testing.T) { t.Run("new", func(t *testing.T) { txCtx := newPendingTxContext() var wg sync.WaitGroup + txID := uuid.NewString() wg.Add(2) var err [2]error go func() { - err[0] = txCtx.New(pendingTx{id: uuid.NewString()}, solana.Signature{}, func() {}) + err[0] = txCtx.New(pendingTx{id: txID}) wg.Done() }() go func() { - err[1] = txCtx.New(pendingTx{id: uuid.NewString()}, solana.Signature{}, func() {}) + err[1] = txCtx.New(pendingTx{id: txID}) wg.Done() }() @@ -1040,18 +1102,18 @@ func TestPendingTxContext_race(t *testing.T) { t.Run("add signature", func(t *testing.T) { txCtx := newPendingTxContext() msg := pendingTx{id: uuid.NewString()} - createErr := txCtx.New(msg, solana.Signature{}, func() {}) + createErr := txCtx.New(msg) require.NoError(t, createErr) var wg sync.WaitGroup wg.Add(2) var err [2]error go func() { - err[0] = txCtx.AddSignature(msg.id, solana.Signature{1}) + err[0] = txCtx.AddSignature(func() {}, msg.id, solana.Signature{1}) wg.Done() }() go func() { - err[1] = txCtx.AddSignature(msg.id, solana.Signature{1}) + err[1] = txCtx.AddSignature(func() {}, msg.id, solana.Signature{1}) wg.Done() }() @@ -1063,16 +1125,18 @@ func TestPendingTxContext_race(t *testing.T) { txCtx := newPendingTxContext() txID := uuid.NewString() msg := pendingTx{id: txID} - err := txCtx.New(msg, solana.Signature{}, func() {}) + err := txCtx.New(msg) require.NoError(t, err) var wg sync.WaitGroup wg.Add(2) go func() { + assert.NotPanics(t, func() { txCtx.Remove(txID) }) //nolint // no need to check error assert.NotPanics(t, func() { txCtx.Remove(txID) }) //nolint // no need to check error wg.Done() }() go func() { + assert.NotPanics(t, func() { txCtx.Remove(txID) }) //nolint // no need to check error assert.NotPanics(t, func() { txCtx.Remove(txID) }) //nolint // no need to check error wg.Done() }() @@ -1096,17 +1160,22 @@ func TestGetTxState(t *testing.T) { // Create new broadcasted transaction with extra sig broadcastedMsg := pendingTx{id: uuid.NewString()} - err := txs.New(broadcastedMsg, broadcastedSig, cancel) + err := txs.New(broadcastedMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, broadcastedMsg.id, broadcastedSig) require.NoError(t, err) - var state utils.TxState // Create new processed transaction + var state txmutils.TxState processedMsg := pendingTx{id: uuid.NewString()} - err = txs.New(processedMsg, processedSig, cancel) + err = txs.New(processedMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, processedMsg.id, processedSig) require.NoError(t, err) id, err := txs.OnProcessed(processedSig) require.NoError(t, err) require.Equal(t, processedMsg.id, id) + // Check Processed state is returned state, err = txs.GetTxState(processedMsg.id) require.NoError(t, err) @@ -1114,11 +1183,14 @@ func TestGetTxState(t *testing.T) { // Create new confirmed transaction confirmedMsg := pendingTx{id: uuid.NewString()} - err = txs.New(confirmedMsg, confirmedSig, cancel) + err = txs.New(confirmedMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, confirmedMsg.id, confirmedSig) require.NoError(t, err) id, err = txs.OnConfirmed(confirmedSig) require.NoError(t, err) require.Equal(t, confirmedMsg.id, id) + // Check Confirmed state is returned state, err = txs.GetTxState(confirmedMsg.id) require.NoError(t, err) @@ -1126,11 +1198,14 @@ func TestGetTxState(t *testing.T) { // Create new finalized transaction finalizedMsg := pendingTx{id: uuid.NewString()} - err = txs.New(finalizedMsg, finalizedSig, cancel) + err = txs.New(finalizedMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, finalizedMsg.id, finalizedSig) require.NoError(t, err) id, err = txs.OnFinalized(finalizedSig, retentionTimeout) require.NoError(t, err) require.Equal(t, finalizedMsg.id, id) + // Check Finalized state is returned state, err = txs.GetTxState(finalizedMsg.id) require.NoError(t, err) @@ -1138,11 +1213,14 @@ func TestGetTxState(t *testing.T) { // Create new errored transaction erroredMsg := pendingTx{id: uuid.NewString()} - err = txs.New(erroredMsg, erroredSig, cancel) + err = txs.New(erroredMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, erroredMsg.id, erroredSig) require.NoError(t, err) - id, err = txs.OnError(erroredSig, retentionTimeout, utils.Errored, 0) + id, err = txs.OnError(erroredSig, retentionTimeout, txmutils.Errored, 0) require.NoError(t, err) require.Equal(t, erroredMsg.id, id) + // Check Errored state is returned state, err = txs.GetTxState(erroredMsg.id) require.NoError(t, err) @@ -1150,11 +1228,14 @@ func TestGetTxState(t *testing.T) { // Create new fatally errored transaction fatallyErroredMsg := pendingTx{id: uuid.NewString()} - err = txs.New(fatallyErroredMsg, fatallyErroredSig, cancel) + err = txs.New(fatallyErroredMsg) + require.NoError(t, err) + err = txs.AddSignature(cancel, fatallyErroredMsg.id, fatallyErroredSig) require.NoError(t, err) - id, err = txs.OnError(fatallyErroredSig, retentionTimeout, utils.FatallyErrored, 0) + id, err = txs.OnError(fatallyErroredSig, retentionTimeout, txmutils.FatallyErrored, 0) require.NoError(t, err) require.Equal(t, fatallyErroredMsg.id, id) + // Check Errored state is returned state, err = txs.GetTxState(fatallyErroredMsg.id) require.NoError(t, err) @@ -1327,3 +1408,174 @@ func TestPendingTxContext_ListAllExpiredBroadcastedTxs(t *testing.T) { }) } } + +func createTxAndAddSig(t *testing.T, txs *pendingTxContext) (string, solana.Signature) { + sig := randomSignature(t) + txID := uuid.NewString() + tx := pendingTx{id: txID} + require.NoError(t, txs.New(tx)) + require.NoError(t, txs.AddSignature(func() {}, txID, sig)) + return txID, sig +} + +func TestPendingTxContext_IsTxReorged(t *testing.T) { + t.Parallel() + txs := newPendingTxContext() + + // This helper creates a brand new transaction/signature, + // then sets the in-memory state to the provided memoryState + setMemoryState := func(t *testing.T, txs *pendingTxContext, memoryState txmutils.TxState) (txID string, sig solana.Signature) { + txID, sig = createTxAndAddSig(t, txs) + + switch memoryState { + case txmutils.Processed: + _, err := txs.OnProcessed(sig) + require.NoError(t, err, "OnProcessed should succeed") + case txmutils.Confirmed: + _, err := txs.OnProcessed(sig) + require.NoError(t, err) + _, err = txs.OnConfirmed(sig) + require.NoError(t, err, "OnConfirmed should succeed") + case txmutils.Broadcasted: // do nothing; newly created sig is in memory=Broadcasted by default + default: + require.FailNowf(t, "unexpected memory state", "%v", memoryState) + } + return + } + + tests := []struct { + name string + memoryState txmutils.TxState + chainState txmutils.TxState + wantReorg bool + }{ + { + name: "non-existent signature => no reorg", + memoryState: txmutils.Broadcasted, // doesn't matter, we'll handle this case specially + chainState: txmutils.Broadcasted, + wantReorg: false, + }, + { + name: "memory=Confirmed, chain=Confirmed => no reorg", + memoryState: txmutils.Confirmed, + chainState: txmutils.Confirmed, + wantReorg: false, + }, + { + name: "memory=Confirmed, chain=Processed => reorg", + memoryState: txmutils.Confirmed, + chainState: txmutils.Processed, + wantReorg: true, + }, + { + name: "memory=Confirmed, chain=NotFound => reorg", + memoryState: txmutils.Confirmed, + chainState: txmutils.NotFound, + wantReorg: true, + }, + { + name: "memory=Processed, chain=Confirmed => no reorg", + memoryState: txmutils.Processed, + chainState: txmutils.Confirmed, + wantReorg: false, + }, + { + name: "memory=Processed, chain=Processed => no reorg", + memoryState: txmutils.Processed, + chainState: txmutils.Processed, + wantReorg: false, + }, + { + name: "memory=Processed, chain=NotFound => reorg", + memoryState: txmutils.Processed, + chainState: txmutils.NotFound, + wantReorg: true, + }, + { + name: "memory=Broadcasted, chain=Confirmed => no reorg", + memoryState: txmutils.Broadcasted, + chainState: txmutils.Confirmed, + wantReorg: false, + }, + { + name: "memory=Broadcasted, chain=Processed => no reorg", + memoryState: txmutils.Broadcasted, + chainState: txmutils.Processed, + wantReorg: false, + }, + { + name: "memory=Broadcasted, chain=NotFound => no reorg", + memoryState: txmutils.Broadcasted, + chainState: txmutils.NotFound, + wantReorg: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // handle special case + if tt.name == "non-existent signature => no reorg" { + // don't create any signature in memory + txID, hasReorg := txs.IsTxReorged(randomSignature(t), tt.chainState) + require.False(t, hasReorg, "expected no reorg for unknown sig") + require.Empty(t, txID, "expected empty txID for unknown sig") + return + } + + // create + set memory state, run IsTxReorged and assert for all other test cases + creationTxID, sig := setMemoryState(t, txs, tt.memoryState) + returnedTxID, hasReorg := txs.IsTxReorged(sig, tt.chainState) + require.Equal(t, creationTxID, returnedTxID, "expected same txID") + if tt.wantReorg { + require.True(t, hasReorg, "expected reorg for memory=%v, chain=%v", tt.memoryState, tt.chainState) + } else { + require.False(t, hasReorg, "expected no reorg for memory=%v, chain=%v", tt.memoryState, tt.chainState) + } + }) + } +} + +func TestPendingTxContext_GetPendingTx(t *testing.T) { + t.Parallel() + txs := newPendingTxContext() + + t.Run("successfully retrieve broadcasted transaction", func(t *testing.T) { + txID, _ := createTxAndAddSig(t, txs) + + tx, err := txs.GetPendingTx(txID) + require.NoError(t, err) + require.Equal(t, txID, tx.id) + require.Equal(t, utils.Broadcasted, tx.state) + }) + + t.Run("successfully retrieve processed transaction", func(t *testing.T) { + txID, sig := createTxAndAddSig(t, txs) + _, err := txs.OnProcessed(sig) + require.NoError(t, err) + + tx, err := txs.GetPendingTx(txID) + require.NoError(t, err) + require.Equal(t, txID, tx.id) + require.Equal(t, utils.Processed, tx.state) + }) + + t.Run("successfully retrieve confirmed transaction", func(t *testing.T) { + txID, sig := createTxAndAddSig(t, txs) + _, err := txs.OnProcessed(sig) + require.NoError(t, err) + _, err = txs.OnConfirmed(sig) + require.NoError(t, err) + + tx, err := txs.GetPendingTx(txID) + require.NoError(t, err) + require.Equal(t, txID, tx.id) + require.Equal(t, utils.Confirmed, tx.state) + }) + + t.Run("fail to retrieve non-existent transaction", func(t *testing.T) { + _, err := txs.GetPendingTx("non-existent-id") + require.ErrorIs(t, err, ErrTransactionNotFound) + }) +} diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index c87089060..1c7e4c54e 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/gagliardetto/solana-go" solanaGo "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" "github.com/google/uuid" @@ -193,26 +194,25 @@ func (txm *Txm) sendWithRetry(ctx context.Context, msg pendingTx) (solanaGo.Tran return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("tx failed initial transmit: %w", errors.Join(initSendErr, stateTransitionErr)) } - // Store tx signature and cancel function - if err := txm.txs.New(msg, sig, cancel); err != nil { - cancel() // Cancel context when exiting early - return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save tx signature (%s) to inflight txs: %w", sig, err) + // Create new transaction in memory + if err := txm.txs.New(msg); err != nil { + cancel() + return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to create new transaction: %w", err) } - txm.lggr.Debugw("tx initial broadcast", "id", msg.id, "fee", msg.cfg.BaseComputeUnitPrice, "signature", sig, "lastValidBlockHeight", msg.lastValidBlockHeight) - - // Initialize signature list with initialTx signature. This list will be used to add new signatures and track retry attempts. - sigs := &txmutils.SignatureList{} - sigs.Allocate() - if initSetErr := sigs.Set(0, sig); initSetErr != nil { - return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save initial signature in signature list: %w", initSetErr) + // Associate initial signature and cancel func to tx + if err := txm.txs.AddSignature(cancel, msg.id, sig); err != nil { + cancel() + return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save initial signature (%s) to inflight txs: %w", sig, err) } + txm.lggr.Debugw("tx initial broadcast", "id", msg.id, "fee", msg.cfg.BaseComputeUnitPrice, "signature", sig, "lastValidBlockHeight", msg.lastValidBlockHeight) + // pass in copy of msg (to build new tx with bumped fee) and broadcasted tx == initTx (to retry tx without bumping) txm.done.Add(1) go func() { defer txm.done.Done() - txm.retryTx(ctx, msg, initTx, sigs) + txm.retryTx(ctx, cancel, msg, initTx, sig) }() // Return signed tx, id, signature for use in simulation @@ -263,7 +263,16 @@ func (txm *Txm) buildTx(ctx context.Context, msg pendingTx, retryCount int) (sol // retryTx contains the logic for retrying the transaction, including exponential backoff and fee bumping. // Retries until context cancelled by timeout or called externally. // It uses handleRetry helper function to handle each retry attempt. -func (txm *Txm) retryTx(ctx context.Context, msg pendingTx, currentTx solanaGo.Transaction, sigs *txmutils.SignatureList) { +func (txm *Txm) retryTx(ctx context.Context, cancel context.CancelFunc, msg pendingTx, currentTx solanaGo.Transaction, sig solanaGo.Signature) { + // Initialize signature list with initialTx signature. This list will be used to add new signatures and track retry attempts. + sigs := &txmutils.SignatureList{} + sigs.Allocate() + if initSetErr := sigs.Set(0, sig); initSetErr != nil { + cancel() + txm.lggr.Errorw("failed to save initial signature in signature list", "error", initSetErr) + return + } + deltaT := 1 // initial delay in ms tick := time.After(0) bumpCount := 0 @@ -304,7 +313,7 @@ func (txm *Txm) retryTx(ctx context.Context, msg pendingTx, currentTx solanaGo.T wg.Add(1) go func(bump bool, count int, retryTx solanaGo.Transaction) { defer wg.Done() - txm.handleRetry(ctx, msg, bump, count, retryTx, sigs) + txm.handleRetry(ctx, cancel, msg, bump, count, retryTx, sigs) }(shouldBump, bumpCount, currentTx) } @@ -318,7 +327,7 @@ func (txm *Txm) retryTx(ctx context.Context, msg pendingTx, currentTx solanaGo.T } // handleRetry handles the logic for each retry attempt, including sending the transaction, updating signatures, and logging. -func (txm *Txm) handleRetry(ctx context.Context, msg pendingTx, bump bool, count int, retryTx solanaGo.Transaction, sigs *txmutils.SignatureList) { +func (txm *Txm) handleRetry(ctx context.Context, cancel context.CancelFunc, msg pendingTx, bump bool, count int, retryTx solanaGo.Transaction, sigs *txmutils.SignatureList) { // send retry transaction retrySig, err := txm.sendTx(ctx, &retryTx) if err != nil { @@ -333,7 +342,7 @@ func (txm *Txm) handleRetry(ctx context.Context, msg pendingTx, bump bool, count // if bump is true, update signature list and set new signature in space already allocated. if bump { - if err := txm.txs.AddSignature(msg.id, retrySig); err != nil { + if err := txm.txs.AddSignature(cancel, msg.id, retrySig); err != nil { txm.lggr.Warnw("error in adding retry transaction", "error", err, "id", msg.id) return } @@ -363,7 +372,7 @@ func (txm *Txm) handleRetry(ctx context.Context, msg pendingTx, bump bool, count } } -// confirm is a goroutine that continuously polls for transaction confirmations and handles rebroadcasts expired transactions if enabled. +// confirm is a goroutine that continuously polls for transaction confirmations. It also handles reorgs and expired transactions rebroadcasting. // The function runs until the chStop channel signals to stop. func (txm *Txm) confirm() { defer txm.done.Done() @@ -376,7 +385,7 @@ func (txm *Txm) confirm() { case <-ctx.Done(): return case <-tick: - // If no signatures to confirm and rebroadcast, we can break loop as there's nothing to process. + // If no signatures to confirm, we can break loop as there's nothing to process. if txm.InflightTxs() == 0 { break } @@ -395,10 +404,10 @@ func (txm *Txm) confirm() { } } -// processConfirmations checks the status of transaction signatures on-chain and updates our in-memory state accordingly. -// It splits the signatures into batches, retrieves their statuses with an RPC call, and processes each status accordingly. -// The function handles transitions, managing expiration, errors, and transitions between different states like broadcasted, processed, confirmed, and finalized. -// It also determines when to end polling based on the status of each signature cancelling the exponential retry. +// processConfirmations checks the on-chain status of transaction signatures and updates their in-memory state accordingly. +// The function splits the signatures into batches, retrieves their statuses using RPC calls, and processes each status. +// It handles various scenarios including expirations, errors, and state transitions (broadcasted, processed, confirmed, finalized). +// Additionally, it detects and manages re-orgs by removing or rebroadcasting transactions as necessary and determines when to end polling cancelling retry loops. func (txm *Txm) processConfirmations(ctx context.Context, client client.ReaderWriter) { sigsBatch, err := utils.BatchSplit(txm.txs.ListAllSigs(), MaxSigsToConfirm) if err != nil { // this should never happen @@ -428,8 +437,10 @@ func (txm *Txm) processConfirmations(ctx context.Context, client client.ReaderWr for j := 0; j < len(sortedRes); j++ { sig, status := sortedSigs[j], sortedRes[j] - // sig not found could mean invalid tx or not picked up yet, keep polling if status == nil { + // sig not found could mean invalid tx or not picked up yet, keep polling + // we also need to check if a re-org has occurred for this sig and handle it + txm.handleReorg(ctx, client, sig, status) txm.handleNotFoundSignatureStatus(sig) continue } @@ -443,19 +454,17 @@ func (txm *Txm) processConfirmations(ctx context.Context, client client.ReaderWr switch status.ConfirmationStatus { case rpc.ConfirmationStatusProcessed: // if signature is processed, keep polling for confirmed or finalized status + // we also need to check if a re-org has occurred for this sig and handle it + txm.handleReorg(ctx, client, sig, status) txm.handleProcessedSignatureStatus(sig) - continue case rpc.ConfirmationStatusConfirmed: // if signature is confirmed, keep polling for finalized status txm.handleConfirmedSignatureStatus(sig) - continue case rpc.ConfirmationStatusFinalized: // if signature is finalized, end polling txm.handleFinalizedSignatureStatus(sig) - continue default: txm.lggr.Warnw("unknown confirmation status", "signature", sig, "status", status.ConfirmationStatus) - continue } } }(i) @@ -499,6 +508,51 @@ func (txm *Txm) handleErrorSignatureStatus(sig solanaGo.Signature, status *rpc.S } } +// handleReorg detects and manages state regressions (re-orgs) for a given signature. +// +// A re-org occurs when the on-chain state of a signature regresses as follows: +// - Confirmed -> Processed || Not Found +// - Processed -> Not Found +// +// When a signature re-org is detected, the following steps are taken: +// - Remove the prior transaction, along with all associated signatures, and cancel the prior context. +// - Rebroadcast the prior transaction with a new blockhash and an updated compute unit price. +func (txm *Txm) handleReorg(ctx context.Context, client client.ReaderWriter, sig solanaGo.Signature, status *rpc.SignatureStatusesResult) { + // Determine if a re-org has occurred + sigState := txmutils.ConvertStatus(status) + txID, hasReorg := txm.txs.IsTxReorged(sig, sigState) + if !hasReorg { + return + } + + // At this point, we have detected a re-org. We need to rebroadcast the tx. + txm.lggr.Debugw("re-org detected for transaction", "txID", txID, "signature", sig) + pTx, err := txm.getPendingTx(txID) + if err != nil { + txm.lggr.Errorw("failed to get pending tx for rebroadcast", "txID", txID, "error", err) + return + } + + // The previous blockhash is invalid. We need to request a new one and rebroadcast the tx with it. + blockhash, err := client.LatestBlockhash(ctx) + if err != nil { + txm.lggr.Errorw("failed to getLatestBlockhash for rebroadcast", "error", err) + return + } + if blockhash == nil || blockhash.Value == nil { + txm.lggr.Errorw("nil pointer returned from getLatestBlockhash for rebroadcast") + return + } + + // Rebroadcasts tx with new blockhash after removing prior tx and signatures associated with it, cancelling prior ctx and updating compute unit price. + newSig, err := txm.rebroadcastWithGivenBlockhash(ctx, pTx, blockhash.Value.Blockhash, blockhash.Value.LastValidBlockHeight) + if err != nil { + return // logging handled inside the func + } + + txm.lggr.Debugw("re-orged tx was rebroadcasted successfully", "id", pTx.id, "newSig", newSig) +} + // handleProcessedSignatureStatus handles the case where a transaction signature is in the "processed" state on-chain. // It updates the transaction state in the local memory and checks if the confirmation timeout has been exceeded. // If the timeout is exceeded, it marks the transaction as errored. @@ -545,17 +599,16 @@ func (txm *Txm) handleFinalizedSignatureStatus(sig solanaGo.Signature) { // rebroadcastExpiredTxs attempts to rebroadcast all transactions that are in broadcasted state and have expired. // An expired tx is one where it's blockhash lastValidBlockHeight (last valid block number) is smaller than the current block height (block number). -// The function loops through all expired txes, rebroadcasts them with a new blockhash, and updates the lastValidBlockHeight. // If any error occurs during rebroadcast attempt, they are discarded, and the function continues with the next transaction. func (txm *Txm) rebroadcastExpiredTxs(ctx context.Context, client client.ReaderWriter) { - currBlock, err := client.GetLatestBlock(ctx) - if err != nil || currBlock == nil || currBlock.BlockHeight == nil { + blockHeight, err := client.GetLatestBlockHeight(ctx) + if err != nil || blockHeight == 0 { txm.lggr.Errorw("failed to get current block height", "error", err) return } // Get all expired broadcasted transactions at current block number. Safe to quit if no txes are found. - expiredBroadcastedTxes := txm.txs.ListAllExpiredBroadcastedTxs(*currBlock.BlockHeight) + expiredBroadcastedTxes := txm.txs.ListAllExpiredBroadcastedTxs(blockHeight) if len(expiredBroadcastedTxes) == 0 { return } @@ -570,33 +623,15 @@ func (txm *Txm) rebroadcastExpiredTxs(ctx context.Context, client client.ReaderW return } - // rebroadcast each expired tx after updating blockhash, lastValidBlockHeight and compute unit price (priority fee) - for _, tx := range expiredBroadcastedTxes { - txm.lggr.Debugw("transaction expired, rebroadcasting", "id", tx.id, "signature", tx.signatures, "lastValidBlockHeight", tx.lastValidBlockHeight, "currentBlockHeight", *currBlock.BlockHeight) - // Removes all signatures associated to prior tx and cancels context. - _, err := txm.txs.Remove(tx.id) + // rebroadcast each expired tx + for _, expiredTx := range expiredBroadcastedTxes { + txm.lggr.Debugw("transaction expired, rebroadcasting", "id", expiredTx.id, "signature", expiredTx.signatures, "lastValidBlockHeight", expiredTx.lastValidBlockHeight, "currentBlockHeight", blockHeight) + newSig, err := txm.rebroadcastWithGivenBlockhash(ctx, expiredTx, blockhash.Value.Blockhash, blockhash.Value.LastValidBlockHeight) if err != nil { - txm.lggr.Errorw("failed to remove expired transaction", "id", tx.id, "error", err) - continue - } - - tx.tx.Message.RecentBlockhash = blockhash.Value.Blockhash - tx.cfg.BaseComputeUnitPrice = txm.fee.BaseComputeUnitPrice() - rebroadcastTx := pendingTx{ - tx: tx.tx, - cfg: tx.cfg, - id: tx.id, // using same id in case it was set by caller and we need to maintain it. - lastValidBlockHeight: blockhash.Value.LastValidBlockHeight, - } - // call sendWithRetry directly to avoid enqueuing - _, _, _, sendErr := txm.sendWithRetry(ctx, rebroadcastTx) - if sendErr != nil { - stateTransitionErr := txm.txs.OnPrebroadcastError(tx.id, txm.cfg.TxRetentionTimeout(), txmutils.Errored, TxFailReject) - txm.lggr.Errorw("failed to rebroadcast transaction", "id", tx.id, "error", errors.Join(sendErr, stateTransitionErr)) - continue + continue // logging handled inside the func } - txm.lggr.Debugw("rebroadcast transaction sent", "id", tx.id) + txm.lggr.Debugw("expired tx was rebroadcasted successfully", "id", expiredTx.id, "newSig", newSig) } } @@ -920,6 +955,34 @@ func (txm *Txm) InflightTxs() int { return len(txm.txs.ListAllSigs()) } +// rebroadcastWithGivenBlockhash attempts to rebroadcast a pending tx with a new blockhash. +// Removes all signatures associated with the prior tx, cancels prior ctx, updates compute unit price and sets given blockhash for rebroadcasting. +// Calls sendWithRetry directly to avoid enqueuing the transaction. It logs the error when rebroadcast fails and returns the new signature when successful. +func (txm *Txm) rebroadcastWithGivenBlockhash(ctx context.Context, pTx pendingTx, blockhash solana.Hash, lastValidBlockHeight uint64) (solana.Signature, error) { + // Remove the previous tx from state + _, err := txm.txs.Remove(pTx.id) + if err != nil { + txm.lggr.Errorw("failed to remove tx", "id", pTx.id, "error", err) + return solana.Signature{}, err + } + + // Set new blockhash, lastValidBlockHeight and update compute unit price for rebroadcast + pTx.tx.Message.RecentBlockhash = blockhash + pTx.cfg.BaseComputeUnitPrice = txm.fee.BaseComputeUnitPrice() + pTx.lastValidBlockHeight = lastValidBlockHeight + + // call sendWithRetry directly to avoid enqueuing + _, _, newSig, sendErr := txm.sendWithRetry(ctx, pTx) + if sendErr != nil { + stateTransitionErr := txm.txs.OnPrebroadcastError(pTx.id, txm.cfg.TxRetentionTimeout(), txmutils.Errored, TxFailReject) + combinedErr := errors.Join(sendErr, stateTransitionErr) + txm.lggr.Errorw("failed to rebroadcast tx with new blockhash", "id", pTx.id, "error", combinedErr) + return solana.Signature{}, combinedErr + } + + return newSig, nil +} + // Close close service func (txm *Txm) Close() error { return txm.StopOnce("Txm", func() error { @@ -943,3 +1006,7 @@ func (txm *Txm) defaultTxConfig() txmutils.TxConfig { EstimateComputeUnitLimit: txm.cfg.EstimateComputeUnitLimit(), } } + +func (txm *Txm) getPendingTx(txID string) (pendingTx, error) { + return txm.txs.GetPendingTx(txID) +} diff --git a/pkg/solana/txm/txm_integration_test.go b/pkg/solana/txm/txm_integration_test.go index 154a42f6a..1f4275982 100644 --- a/pkg/solana/txm/txm_integration_test.go +++ b/pkg/solana/txm/txm_integration_test.go @@ -1,14 +1,19 @@ //go:build integration -package txm_test +package txm import ( + "bytes" "context" + "os" + "os/exec" + "path/filepath" "testing" "time" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/programs/system" + "github.com/gagliardetto/solana-go/rpc" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/zap/zapcore" @@ -24,7 +29,6 @@ import ( solanaClient "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" ) @@ -65,8 +69,8 @@ func TestTxm_Integration_ExpirationRebroadcast(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { + tc := tc t.Parallel() ctx, client, txmInstance, senderPubKey, receiverPubKey, observer := setup(t, url, tc.txExpirationRebroadcast) @@ -104,8 +108,8 @@ func TestTxm_Integration_ExpirationRebroadcast(t *testing.T) { } // Verify rebroadcast logs - rebroadcastLogs := observer.FilterMessageSnippet("rebroadcast transaction sent").Len() - rebroadcastLogs2 := observer.FilterMessageSnippet("transaction expired, rebroadcasting").Len() + rebroadcastLogs := observer.FilterMessageSnippet("transaction expired, rebroadcasting").Len() + rebroadcastLogs2 := observer.FilterMessageSnippet("expired tx was rebroadcasted successfully").Len() if tc.expectRebroadcast { require.Equal(t, 1, rebroadcastLogs, "Expected rebroadcast log message not found") require.Equal(t, 1, rebroadcastLogs2, "Expected rebroadcast log message not found") @@ -117,7 +121,7 @@ func TestTxm_Integration_ExpirationRebroadcast(t *testing.T) { } } -func setup(t *testing.T, url string, txExpirationRebroadcast bool) (context.Context, *solanaClient.Client, *txm.Txm, solana.PublicKey, solana.PublicKey, *observer.ObservedLogs) { +func setup(t *testing.T, url string, txExpirationRebroadcast bool) (context.Context, *solanaClient.Client, *Txm, solana.PublicKey, solana.PublicKey, *observer.ObservedLogs) { ctx := tests.Context(t) // Generate sender and receiver keys and fund sender account @@ -139,14 +143,14 @@ func setup(t *testing.T, url string, txExpirationRebroadcast bool) (context.Cont // Set configs cfg := config.NewDefault() cfg.Chain.TxExpirationRebroadcast = &txExpirationRebroadcast - cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(10 * time.Second) // to get the finalized tx status + cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(1 * time.Minute) // to get the finalized tx status // Initialize the Solana client and TXM lggr, obs := logger.TestObserved(t, zapcore.DebugLevel) client, err := solanaClient.NewClient(url, cfg, 2*time.Second, lggr) require.NoError(t, err) loader := utils.NewLazyLoad(func() (solanaClient.ReaderWriter, error) { return client, nil }) - txmInstance := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr) + txmInstance := NewTxm("localnet", loader, nil, cfg, mkey, lggr) servicetest.Run(t, txmInstance) return ctx, client, txmInstance, senderPubKey, receiverPubKey, obs @@ -185,3 +189,185 @@ func createTransaction(ctx context.Context, t *testing.T, client *solanaClient.C return tx, lastValidBlockHeight } + +func TestTxm_Integration_Reorg(t *testing.T) { + t.Parallel() + t.Run("no reorg", func(t *testing.T) { + // Setup live validator and test environment + t.Parallel() + url := solanaClient.SetupLocalSolNode(t) + ctx, client, txmInstance, senderPubKey, receiverPubKey, observer := setup(t, url, true) + + // Record initial balance + initSenderBalance, err := client.Balance(ctx, senderPubKey) + require.NoError(t, err) + const amount = 1 * solana.LAMPORTS_PER_SOL + + // Create, enqueue and wait for tx finalization + txID := "no-reorg" + tx, lastValidBlockHeight := createTransaction(ctx, t, client, senderPubKey, receiverPubKey, amount, true) + require.NoError(t, txmInstance.Enqueue(ctx, "", tx, &txID, lastValidBlockHeight)) + require.Eventually(t, func() bool { + status, errGetStatus := txmInstance.GetTransactionStatus(ctx, txID) + if errGetStatus != nil { + return false + } + return status == types.Finalized + }, 60*time.Second, 1*time.Second, "Transaction should eventually reach Finalized status") + + // Verify that reorg was not detected and final balances are correct + reorgLogs := observer.FilterMessageSnippet("re-org detected for transaction").Len() + require.Equal(t, 0, reorgLogs, "Re-org should not occur") + finalSenderBalance, err := client.Balance(ctx, senderPubKey) + require.NoError(t, err) + finalReceiverBalance, err := client.Balance(ctx, receiverPubKey) + require.NoError(t, err) + require.Less(t, finalSenderBalance, initSenderBalance, "Sender balance should decrease") + require.Equal(t, amount, finalReceiverBalance, "Receiver should receive the transferred amount") + }) + + t.Run("confirmed reorg: previous tx is replaced and new one is finalized", func(t *testing.T) { + // Start live validator and setup test environment + t.Parallel() + ledgerDir := t.TempDir() + port := utils.MustRandomPort(t) + faucetPort := utils.MustRandomPort(t) + cmd, url := startValidator(t, ledgerDir, port, faucetPort, true) + ctx, cl, txmInstance, senderPubKey, receiverPubKey, obs := setup(t, url, true) + + // Back up the ledger after transferring funds + cleanLedgerBackupDir := t.TempDir() + require.NoError(t, copyDir(ledgerDir, cleanLedgerBackupDir)) + initSenderBalance, err := cl.Balance(ctx, senderPubKey) + require.NoError(t, err) + + // Create TX and wait for it to be confirmed + const amount = 1 * solana.LAMPORTS_PER_SOL + txID := "reorg-test-tx" + tx, lastValidBlockHeight := createTransaction(ctx, t, cl, senderPubKey, receiverPubKey, amount, true) + require.NoError(t, txmInstance.Enqueue(ctx, "", tx, &txID, lastValidBlockHeight)) + require.Eventually(t, func() bool { + status, errGetStatus := txmInstance.GetTransactionStatus(ctx, txID) + if errGetStatus != nil { + return false + } + if status == types.Unconfirmed { + pTx, errPtx := txmInstance.getPendingTx(txID) + if errPtx != nil || len(pTx.signatures) == 0 { + return false + } + + sigStatus, errStat := cl.SignatureStatuses(ctx, pTx.signatures) + if errStat != nil || len(sigStatus) == 0 || sigStatus[0] == nil { + return false + } + return sigStatus[0].ConfirmationStatus == rpc.ConfirmationStatusConfirmed + } + return false + }, 60*time.Second, 1*time.Second, "Transaction should reach Confirmed status") + + // Simulate reorg: kill current validator and restart validator with backuped ledger before the tx. + // we want ledger as provided, omit --reset + require.NoError(t, cmd.Process.Kill()) + _ = cmd.Wait() + require.NoError(t, os.RemoveAll(ledgerDir)) + require.NoError(t, copyDir(cleanLedgerBackupDir, ledgerDir)) + startValidator(t, ledgerDir, port, faucetPort, false) + + // Check tx is not finalized yet and reorg is detected + status, errGetStatus := txmInstance.GetTransactionStatus(ctx, txID) + require.NoError(t, errGetStatus) + require.NotEqual(t, types.Finalized, status, "tx should not be finalized after reorg") + reorgLogs := obs.FilterMessageSnippet("re-org detected for transaction").Len() + require.Equal(t, reorgLogs, 1, "Re-org should be detected") + rebroadcastReorgLogs := obs.FilterMessageSnippet("re-orged tx was rebroadcasted successfully").Len() + require.Equal(t, rebroadcastReorgLogs, 1, "re-org tx should be rebroadcasted with new blockhash") + + // Wait rebroadcasted tx to be finalized and check final balances + require.Eventually(t, func() bool { + finalStatus, errAgain := txmInstance.GetTransactionStatus(ctx, txID) + if errAgain != nil { + return false + } + return finalStatus == types.Finalized + }, 120*time.Second, 5*time.Second, "tx should finalize again after reorg handling") + finalSenderBalance, err := cl.Balance(ctx, senderPubKey) + require.NoError(t, err) + finalReceiverBalance, err := cl.Balance(ctx, receiverPubKey) + require.NoError(t, err) + require.Less(t, finalSenderBalance, initSenderBalance, "Sender balance should decrease after re-finalization") + require.Equal(t, amount, finalReceiverBalance, "Receiver should receive transferred amount after re-finalization") + status, errGetStatus = txmInstance.GetTransactionStatus(ctx, txID) + require.NoError(t, errGetStatus) + require.Equal(t, types.Finalized, status, "tx should be finalized after reorg") + }) +} + +// startValidator starts a local solana-test-validator and return the cmd to control it. +func startValidator( + t *testing.T, + ledgerDir, port, faucetPort string, + reset bool, +) (*exec.Cmd, string) { + t.Helper() + + args := []string{ + "--rpc-port", port, + "--faucet-port", faucetPort, + "--ledger", ledgerDir, + } + if reset { + args = append([]string{"--reset"}, args...) + } + + cmd := exec.Command("solana-test-validator", args...) + + var stdErr, stdOut bytes.Buffer + cmd.Stderr = &stdErr + cmd.Stdout = &stdOut + + require.NoError(t, cmd.Start(), "failed to start solana-test-validator") + + // The RPC URL + url := "http://127.0.0.1:" + port + + // Ensure validator is killed after the test finishes + t.Cleanup(func() { + _ = cmd.Process.Kill() + _ = cmd.Wait() + }) + + // Wait until it's healthy + client := rpc.New(url) + require.Eventually(t, func() bool { + out, err := client.GetHealth(context.Background()) + return err == nil && out == rpc.HealthOk + }, 30*time.Second, 1*time.Second, "Validator should become healthy") + + return cmd, url +} + +// copyDir copies the directory tree. +func copyDir(src, dst string) error { + return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + relPath, err := filepath.Rel(src, path) + if err != nil { + return err + } + dstPath := filepath.Join(dst, relPath) + + if info.IsDir() { + return os.MkdirAll(dstPath, info.Mode()) + } + + data, err := os.ReadFile(path) + if err != nil { + return err + } + + return os.WriteFile(dstPath, data, info.Mode()) + }) +} diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index 15e4631a3..0ef477f95 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -5,6 +5,7 @@ package txm import ( "context" "errors" + "fmt" "math/big" "sync" "testing" @@ -510,8 +511,8 @@ func TestTxm(t *testing.T) { mc.On("SendTx", mock.Anything, tx).Panic("SendTx should not be called anymore").Maybe() }) - // tx passes sim, shows processed, moves to nil (timeout should cleanup) - t.Run("fail_confirm_processedToNil", func(t *testing.T) { + // tx passes sim, gets processed, regresses to not found, gets rebroadcasted by re-org logic and stuck on processed. Eventually cleaned up by timeout. + t.Run("reorged_tx_stucked_on_processed_is_eventually_cleaned_up", func(t *testing.T) { tx, signed := getTx(t, 8, mkey) sig := randomSignature(t) retry0 := randomSignature(t) @@ -530,10 +531,21 @@ func TestTxm(t *testing.T) { wg.Done() }).Return(&rpc.SimulateTransactionResult{}, nil).Once() + mc.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil).Once() + // handle signature status calls (initial stays processed => nil, others don't exist) count := 0 statuses[sig] = func() (out *rpc.SignatureStatusesResult) { defer func() { count++ }() + if count > 4 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusProcessed, + } + } if count > 2 { return nil @@ -557,7 +569,7 @@ func TestTxm(t *testing.T) { prom.assertEqual(t) _, err := txm.GetTransactionStatus(ctx, testTxID) - require.Error(t, err) // transaction cleared from storage after finalized should not return status + require.Error(t, err) // transaction cleared from storage // panic if sendTx called after context cancelled mc.On("SendTx", mock.Anything, tx).Panic("SendTx should not be called anymore").Maybe() @@ -810,7 +822,6 @@ func TestTxm_disabled_confirm_timeout_with_retention(t *testing.T) { return out }, nil, ) - t.Run("happyPath", func(t *testing.T) { // Test tx is not discarded due to confirm timeout and tracked to finalization // use unique val across tests to avoid collision during mocking @@ -1223,7 +1234,7 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { setupTxmTest := func( txExpirationRebroadcast bool, latestBlockhashFunc func() (*rpc.GetLatestBlockhashResult, error), - getLatestBlockFunc func() (*rpc.GetBlockResult, error), + getLatestBlockHeightFunc func() (uint64, error), sendTxFunc func() (solana.Signature, error), statuses map[solana.Signature]func() *rpc.SignatureStatusesResult, ) (*Txm, *mocks.ReaderWriter, *keyMocks.SimpleKeystore) { @@ -1237,10 +1248,10 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { }, ).Maybe() } - if getLatestBlockFunc != nil { - mc.On("GetLatestBlock", mock.Anything).Return( - func(_ context.Context) (*rpc.GetBlockResult, error) { - return getLatestBlockFunc() + if getLatestBlockHeightFunc != nil { + mc.On("GetLatestBlockHeight", mock.Anything).Return( + func(_ context.Context) (uint64, error) { + return getLatestBlockHeightFunc() }, ).Maybe() } @@ -1289,11 +1300,8 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} // Mock getLatestBlock to return a value greater than 0 for blockHeight - getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { - val := uint64(1500) - return &rpc.GetBlockResult{ - BlockHeight: &val, - }, nil + getLatestBlockHeightFunc := func() (uint64, error) { + return 1500, nil } rebroadcastCount := 0 @@ -1339,7 +1347,7 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { } } - txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockHeightFunc, sendTxFunc, statuses) tx, _ := getTx(t, 0, mkey) txID := "test-rebroadcast" @@ -1408,11 +1416,8 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} // Mock getLatestBlock to return a value greater than 0 - getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { - val := uint64(1500) - return &rpc.GetBlockResult{ - BlockHeight: &val, - }, nil + getLatestBlockHeightFunc := func() (uint64, error) { + return 1500, nil } // Mock LatestBlockhash to return an invalid blockhash in the first 2 attempts to rebroadcast. @@ -1464,7 +1469,7 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { } } - txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockHeightFunc, sendTxFunc, statuses) tx, _ := getTx(t, 0, mkey) txID := "test-rebroadcast" lastValidBlockHeight := uint64(100) // lastValidBlockHeight is smaller than blockHeight @@ -1493,11 +1498,8 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { } // Mock getLatestBlock to return a value greater than 0 - getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { - val := uint64(1500) - return &rpc.GetBlockResult{ - BlockHeight: &val, - }, nil + getLatestBlockHeightFunc := func() (uint64, error) { + return 1500, nil } rebroadcastCount := 0 @@ -1530,7 +1532,7 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { return out } - txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockHeightFunc, sendTxFunc, statuses) tx, _ := getTx(t, 0, mkey) txID := "test-confirmed-before-rebroadcast" lastValidBlockHeight := uint64(1500) // original lastValidBlockHeight is valid @@ -1556,11 +1558,8 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { // To force rebroadcast, first call needs to be smaller than blockHeight // following rebroadcast call will go through because lastValidBlockHeight will be bigger than blockHeight - getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { - val := uint64(1500) - return &rpc.GetBlockResult{ - BlockHeight: &val, - }, nil + getLatestBlockHeightFunc := func() (uint64, error) { + return 1500, nil } rebroadcastCount := 0 @@ -1590,7 +1589,7 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { return nil } - txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockHeightFunc, sendTxFunc, statuses) tx, _ := getTx(t, 0, mkey) txID := "test-rebroadcast-error" lastValidBlockHeight := uint64(100) // lastValidBlockHeight is smaller than blockHeight @@ -1610,3 +1609,180 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { require.Equal(t, 1, rebroadcastCount) }) } + +func TestTxm_OnReorg(t *testing.T) { + t.Parallel() + estimator := "fixed" + id := "mocknet-" + estimator + "-" + uuid.NewString() + cfg := config.NewDefault() + cfg.Chain.FeeEstimatorMode = &estimator + cfg.Chain.TxConfirmTimeout = relayconfig.MustNewDuration(5 * time.Second) + // Enable retention to keep transactions after finality and be able to check their statuses. + cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(10 * time.Second) + lggr := logger.Test(t) + ctx := tests.Context(t) + + // Helper that sets up a Txm and mocks. + setupTxmTest := func( + txExpirationRebroadcast bool, + latestBlockhashFunc func() (*rpc.GetLatestBlockhashResult, error), + getLatestBlockHeightFunc func() (uint64, error), + sendTxFunc func() (solana.Signature, error), + statuses map[solana.Signature]func() *rpc.SignatureStatusesResult, + ) (*Txm, *mocks.ReaderWriter, *keyMocks.SimpleKeystore) { + cfg.Chain.TxExpirationRebroadcast = &txExpirationRebroadcast + + mc := mocks.NewReaderWriter(t) + if latestBlockhashFunc != nil { + mc.On("LatestBlockhash", mock.Anything).Return(func(_ context.Context) (*rpc.GetLatestBlockhashResult, error) { + return latestBlockhashFunc() + }).Maybe() + } + if getLatestBlockHeightFunc != nil { + mc.On("GetLatestBlockHeight", mock.Anything).Return(func(_ context.Context) (uint64, error) { + return getLatestBlockHeightFunc() + }).Maybe() + } + if sendTxFunc != nil { + mc.On("SendTx", mock.Anything, mock.Anything).Return(func(_ context.Context, _ *solana.Transaction) (solana.Signature, error) { + return sendTxFunc() + }).Maybe() + } + mc.On("SimulateTx", mock.Anything, mock.Anything, mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe() + if statuses != nil { + mc.On("SignatureStatuses", mock.Anything, mock.AnythingOfType("[]solana.Signature")).Return( + func(_ context.Context, sigs []solana.Signature) ([]*rpc.SignatureStatusesResult, error) { + var out []*rpc.SignatureStatusesResult + for _, sig := range sigs { + getStatus, exists := statuses[sig] + if !exists { + out = append(out, nil) + } else { + out = append(out, getStatus()) + } + } + return out, nil + }, + ).Maybe() + } + + mkey := keyMocks.NewSimpleKeystore(t) + mkey.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil) + + loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) + txm := NewTxm(id, loader, nil, cfg, mkey, lggr) + require.NoError(t, txm.Start(ctx)) + t.Cleanup(func() { require.NoError(t, txm.Close()) }) + + return txm, mc, mkey + } + + prom := soltxmProm{id: id} // track Prometheus metrics across runs + + type scenario struct { + name string + initialSigFrom rpc.ConfirmationStatusType // e.g. "processed", "confirmed" + } + + tests := []scenario{ + { + name: "confirmed => re-org => new tx finalizes", + initialSigFrom: rpc.ConfirmationStatusConfirmed, + }, + { + name: "processed => re-org => new tx finalizes", + initialSigFrom: rpc.ConfirmationStatusProcessed, + }, + } + + for _, sc := range tests { + t.Run(sc.name, func(t *testing.T) { + // mock latest blockhash. Re-orged tx needs to be broadcasted with a new blockhash + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + Blockhash: solana.HashFromBytes([]byte{2}), + LastValidBlockHeight: 2001, + }, + }, nil + } + + initialSig := randomSignature(t) + var initialTxCtxStopped bool + + retrySig := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + // We will return the initialSig until the re-org happens + // After that, we'll return the retrySig as the prior tx is replaced and context cancelled + if !initialTxCtxStopped { + return initialSig, nil + } + return retrySig, nil + } + + // Mock the on-chain status of the initial tx + var initialStatusCallCount int + var wg sync.WaitGroup + wg.Add(1) + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{ + initialSig: func() *rpc.SignatureStatusesResult { + defer func() { initialStatusCallCount++ }() + if initialStatusCallCount < 2 { + // keep returning sc.initialSigFrom (e.g. Confirmed or Processed) + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: sc.initialSigFrom, + } + } + + // simulate re-org => NotFound + initialTxCtxStopped = true + wg.Done() + return nil + }, + } + + // Mock the on-chain status of the re-orged tx. It will eventually finalize + var retryStatusCallCount int + wg.Add(1) + statuses[retrySig] = func() *rpc.SignatureStatusesResult { + defer func() { retryStatusCallCount++ }() + switch retryStatusCallCount { + case 0: + return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusProcessed} + case 1, 2: + return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusConfirmed} + default: + wg.Done() + return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusFinalized} + } + } + + txm, _, mkey := setupTxmTest(false, latestBlockhashFunc, nil, sendTxFunc, statuses) + + // Enqueue our transaction to the Txm + tx, _ := getTx(t, 0, mkey) + txID := fmt.Sprintf("reorg-from-%s", sc.initialSigFrom) + lastValidBlockHeight := uint64(100) + require.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + + // Wait for the states to move from initial => re-org => new => finalized + // Wait txm get the final states or timeouts of the transactions. + wg.Wait() + waitFor(t, txm.cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric increments + if sc.initialSigFrom == rpc.ConfirmationStatusConfirmed { + prom.confirmed++ // in case initial tx was confirmed + } + // re-orged tx should always be confirmed and finalized + prom.confirmed++ + prom.finalized++ + prom.assertEqual(t) + + // Ensure the TX is Finalized in memory + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + }) + } +} diff --git a/pkg/solana/txm/utils/utils.go b/pkg/solana/txm/utils/utils.go index 7f3ffb9e2..58c79e0c2 100644 --- a/pkg/solana/txm/utils/utils.go +++ b/pkg/solana/txm/utils/utils.go @@ -66,7 +66,7 @@ func (s statuses) Swap(i, j int) { } func (s statuses) Less(i, j int) bool { - return convertStatus(s.res[i]) > convertStatus(s.res[j]) // returns list with highest first + return ConvertStatus(s.res[i]) > ConvertStatus(s.res[j]) // returns list with highest first } func SortSignaturesAndResults(sigs []solana.Signature, res []*rpc.SignatureStatusesResult) ([]solana.Signature, []*rpc.SignatureStatusesResult, error) { @@ -82,7 +82,7 @@ func SortSignaturesAndResults(sigs []solana.Signature, res []*rpc.SignatureStatu return s.sigs, s.res, nil } -func convertStatus(res *rpc.SignatureStatusesResult) TxState { +func ConvertStatus(res *rpc.SignatureStatusesResult) TxState { if res == nil { return NotFound }