Skip to content

Commit

Permalink
CallMany per txn gran (#13399)
Browse files Browse the repository at this point in the history
closes #13220

---------

Co-authored-by: JkLondon <[email protected]>
  • Loading branch information
JkLondon and JkLondon authored Jan 16, 2025
1 parent 9c76900 commit 19cb3b0
Show file tree
Hide file tree
Showing 2 changed files with 388 additions and 83 deletions.
281 changes: 237 additions & 44 deletions turbo/jsonrpc/trace_adhoc.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/erigontech/erigon-lib/kv/rawdbv3"
"github.com/erigontech/erigon/turbo/snapshotsync/freezeblocks"
"math"
"strings"

Expand Down Expand Up @@ -817,7 +819,7 @@ func (api *TraceAPIImpl) ReplayTransaction(ctx context.Context, txHash libcommon
}

var isBorStateSyncTxn bool
blockNum, _, ok, err := api.txnLookup(ctx, tx, txHash)
blockNum, txNum, ok, err := api.txnLookup(ctx, tx, txHash)
if err != nil {
return nil, err
}
Expand All @@ -843,30 +845,31 @@ func (api *TraceAPIImpl) ReplayTransaction(ctx context.Context, txHash libcommon
isBorStateSyncTxn = true
}

block, err := api.blockByNumberWithSenders(ctx, tx, blockNum)
header, err := api.headerByRPCNumber(ctx, rpc.BlockNumber(blockNum), tx)
if err != nil {
return nil, err
}
if block == nil {
return nil, nil

txNumsReader := rawdbv3.TxNums.WithCustomReadTxNumFunc(freezeblocks.ReadTxNumFuncFromBlockReader(ctx, api._blockReader))

txNumMin, err := txNumsReader.Min(tx, blockNum)
if err != nil {
return nil, err
}

var txnIndex int
for idx := 0; idx < block.Transactions().Len() && !isBorStateSyncTxn; idx++ {
txn := block.Transactions()[idx]
if txn.Hash() == txHash {
txnIndex = idx
break
}
if txNumMin+2 > txNum {
return nil, fmt.Errorf("uint underflow txnums error txNum: %d, txNumMin: %d, blockNum: %d", txNum, txNumMin, blockNum)
}

var txnIndex = int(txNum - txNumMin - 2)

if isBorStateSyncTxn {
txnIndex = block.Transactions().Len()
txnIndex = -1
}

signer := types.MakeSigner(chainConfig, blockNum, block.Time())
signer := types.MakeSigner(chainConfig, blockNum, header.Time)
// Returns an array of trace arrays, one trace array for each transaction
traces, _, err := api.callManyTransactions(ctx, tx, block, traceTypes, txnIndex, *gasBailOut, signer, chainConfig, traceConfig)
trace, _, err := api.callTransaction(ctx, tx, header, traceTypes, txnIndex, *gasBailOut, signer, chainConfig, traceConfig)
if err != nil {
return nil, err
}
Expand All @@ -886,25 +889,18 @@ func (api *TraceAPIImpl) ReplayTransaction(ctx context.Context, txHash libcommon
}
result := &TraceCallResult{}

for txno, trace := range traces {
// We're only looking for a specific transaction
if txno == txnIndex {
result.Output = trace.Output
if traceTypeTrace {
result.Trace = trace.Trace
}
if traceTypeStateDiff {
result.StateDiff = trace.StateDiff
}
if traceTypeVmTrace {
result.VmTrace = trace.VmTrace
}

return trace, nil
}
result.Output = trace.Output
if traceTypeTrace {
result.Trace = trace.Trace
}
if traceTypeStateDiff {
result.StateDiff = trace.StateDiff
}
if traceTypeVmTrace {
result.VmTrace = trace.VmTrace
}

return result, nil
return trace, nil
}

func (api *TraceAPIImpl) ReplayBlockTransactions(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash, traceTypes []string, gasBailOut *bool, traceConfig *config.TraceConfig) ([]*TraceCallResult, error) {
Expand Down Expand Up @@ -950,7 +946,7 @@ func (api *TraceAPIImpl) ReplayBlockTransactions(ctx context.Context, blockNrOrH

signer := types.MakeSigner(chainConfig, blockNumber, block.Time())
// Returns an array of trace arrays, one trace array for each transaction
traces, _, err := api.callManyTransactions(ctx, tx, block, traceTypes, -1 /* all txn indices */, *gasBailOut, signer, chainConfig, traceConfig)
traces, _, err := api.callBlock(ctx, tx, block, traceTypes, *gasBailOut, signer, chainConfig, traceConfig)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1212,14 +1208,14 @@ func (api *TraceAPIImpl) CallMany(ctx context.Context, calls json.RawMessage, pa
cachedWriter := state.NewCachedWriter(noop, stateCache)
ibs := state.New(cachedReader)

return api.doCallMany(ctx, dbtx, stateReader, stateCache, cachedWriter, ibs,
msgs, callParams, parentNrOrHash, nil, true /* gasBailout */, -1 /* all txn indices */, traceConfig)
return api.doCallBlock(ctx, dbtx, stateReader, stateCache, cachedWriter, ibs,
msgs, callParams, parentNrOrHash, nil, true /* gasBailout */, traceConfig)
}

func (api *TraceAPIImpl) doCallMany(ctx context.Context, dbtx kv.Tx, stateReader state.StateReader,
func (api *TraceAPIImpl) doCallBlock(ctx context.Context, dbtx kv.Tx, stateReader state.StateReader,
stateCache *shards.StateCache, cachedWriter state.StateWriter, ibs *state.IntraBlockState,
msgs []types.Message, callParams []TraceCallParam,
parentNrOrHash *rpc.BlockNumberOrHash, header *types.Header, gasBailout bool, txIndexNeeded int,
parentNrOrHash *rpc.BlockNumberOrHash, header *types.Header, gasBailout bool,
traceConfig *config.TraceConfig,
) ([]*TraceCallResult, error) {
chainConfig, err := api.chainConfig(ctx, dbtx)
Expand Down Expand Up @@ -1299,7 +1295,7 @@ func (api *TraceAPIImpl) doCallMany(ctx context.Context, dbtx kv.Tx, stateReader

traceResult := &TraceCallResult{Trace: []*ParityTrace{}, TransactionHash: args.txHash}
vmConfig := vm.Config{}
if (traceTypeTrace && (txIndexNeeded == -1 || txIndex == txIndexNeeded)) || traceTypeVmTrace {
if traceTypeTrace || traceTypeVmTrace {
var ot OeTracer
ot.config, err = parseOeTracerConfig(traceConfig)
if err != nil {
Expand All @@ -1308,7 +1304,7 @@ func (api *TraceAPIImpl) doCallMany(ctx context.Context, dbtx kv.Tx, stateReader
ot.compat = api.compatibility
ot.r = traceResult
ot.idx = []string{fmt.Sprintf("%d-", txIndex)}
if traceTypeTrace && (txIndexNeeded == -1 || txIndex == txIndexNeeded) {
if traceTypeTrace {
ot.traceAddr = []int{}
}
if traceTypeVmTrace {
Expand Down Expand Up @@ -1389,7 +1385,11 @@ func (api *TraceAPIImpl) doCallMany(ctx context.Context, dbtx kv.Tx, stateReader
return nil, err
}
}
sd.CompareStates(initialIbs, ibs)
if sd != nil {
if err = sd.CompareStates(initialIbs, ibs); err != nil {
return nil, err
}
}
if err = ibs.CommitBlock(chainRules, cachedWriter); err != nil {
return nil, err
}
Expand All @@ -1407,16 +1407,209 @@ func (api *TraceAPIImpl) doCallMany(ctx context.Context, dbtx kv.Tx, stateReader
traceResult.Trace = []*ParityTrace{}
}
results = append(results, traceResult)
// When txIndexNeeded is not -1, we are tracing specific transaction in the block and not the entire block, so we stop after we've traced
// the required transaction
if txIndexNeeded != -1 && txIndex == txIndexNeeded {
break
}
}

return results, nil
}

func (api *TraceAPIImpl) doCall(ctx context.Context, dbtx kv.Tx, stateReader state.StateReader,
stateCache *shards.StateCache, cachedWriter state.StateWriter, ibs *state.IntraBlockState,
msg types.Message, callParam TraceCallParam,
parentNrOrHash *rpc.BlockNumberOrHash, header *types.Header, gasBailout bool, txIndex int,
traceConfig *config.TraceConfig,
) (*TraceCallResult, error) {
chainConfig, err := api.chainConfig(ctx, dbtx)
if err != nil {
return nil, err
}
engine := api.engine()

if parentNrOrHash == nil {
var num = rpc.LatestBlockNumber
parentNrOrHash = &rpc.BlockNumberOrHash{BlockNumber: &num}
}
blockNumber, hash, _, err := rpchelper.GetBlockNumber(ctx, *parentNrOrHash, dbtx, api._blockReader, api.filters)
if err != nil {
return nil, err
}
noop := state.NewNoopWriter()

parentHeader, err := api.headerByRPCNumber(ctx, rpc.BlockNumber(blockNumber), dbtx)
if err != nil {
return nil, err
}
if parentHeader == nil {
return nil, fmt.Errorf("parent header %d(%x) not found", blockNumber, hash)
}

// Setup context so it may be cancelled the call has completed
// or, in case of unmetered gas, setup a context with a timeout.
var cancel context.CancelFunc
if api.evmCallTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, api.evmCallTimeout)
} else {
ctx, cancel = context.WithCancel(ctx)
}

// Make sure the context is cancelled when the call has completed
// this makes sure resources are cleaned up.
defer cancel()

useParent := false
if header == nil {
header = parentHeader
useParent = true
}

var baseTxNum uint64
historicalStateReader, isHistoricalStateReader := stateReader.(state.HistoricalStateReader)
if isHistoricalStateReader {
baseTxNum = historicalStateReader.GetTxNum()
}

blockCtx := transactions.NewEVMBlockContext(engine, header, parentNrOrHash.RequireCanonical, dbtx, api._blockReader, chainConfig)

if isHistoricalStateReader {
historicalStateReader.SetTxNum(baseTxNum + uint64(txIndex))
}
if err := libcommon.Stopped(ctx.Done()); err != nil {
return nil, err
}

var traceTypeTrace, traceTypeStateDiff, traceTypeVmTrace bool
args := callParam
for _, traceType := range args.traceTypes {
switch traceType {
case TraceTypeTrace:
traceTypeTrace = true
case TraceTypeStateDiff:
traceTypeStateDiff = true
case TraceTypeVmTrace:
traceTypeVmTrace = true
default:
return nil, fmt.Errorf("unrecognized trace type: %s", traceType)
}
}

traceResult := &TraceCallResult{Trace: []*ParityTrace{}, TransactionHash: args.txHash}
vmConfig := vm.Config{}
if traceTypeTrace || traceTypeVmTrace {
var ot OeTracer
ot.config, err = parseOeTracerConfig(traceConfig)
if err != nil {
return nil, err
}
ot.compat = api.compatibility
ot.r = traceResult
ot.idx = []string{fmt.Sprintf("%d-", txIndex)}
if traceTypeTrace {
ot.traceAddr = []int{}
}
if traceTypeVmTrace {
traceResult.VmTrace = &VmTrace{Ops: []*VmTraceOp{}}
}
vmConfig.Debug = true
vmConfig.Tracer = &ot
}

if useParent {
blockCtx.GasLimit = math.MaxUint64
blockCtx.MaxGasLimit = true
}

// Clone the state cache before applying the changes for diff after transaction execution, clone is discarded
var cloneReader state.StateReader
var sd *StateDiff
if traceTypeStateDiff {
cloneCache := stateCache.Clone()
cloneReader = state.NewCachedReader(stateReader, cloneCache)
//cloneReader = stateReader
if isHistoricalStateReader {
historicalStateReader.SetTxNum(baseTxNum + uint64(txIndex))
}
sdMap := make(map[libcommon.Address]*StateDiffAccount)
traceResult.StateDiff = sdMap
sd = &StateDiff{sdMap: sdMap}
}

ibs.Reset()
var finalizeTxStateWriter state.StateWriter
if sd != nil {
finalizeTxStateWriter = sd
} else {
finalizeTxStateWriter = noop
}

var txFinalized bool
var execResult *evmtypes.ExecutionResult
if args.isBorStateSyncTxn {
txFinalized = true
var stateSyncEvents []*types.Message
stateSyncEvents, err = api.stateSyncEvents(ctx, dbtx, header.Hash(), blockNumber, chainConfig)
if err != nil {
return nil, err
}

execResult, err = tracer.TraceBorStateSyncTxnTraceAPI(
ctx,
&vmConfig,
chainConfig,
ibs,
finalizeTxStateWriter,
blockCtx,
header.Hash(),
header.Number.Uint64(),
header.Time,
stateSyncEvents,
)
} else {
ibs.SetTxContext(txIndex)
txCtx := core.NewEVMTxContext(msg)
evm := vm.NewEVM(blockCtx, txCtx, ibs, chainConfig, vmConfig)
gp := new(core.GasPool).AddGas(msg.Gas()).AddBlobGas(msg.BlobGas())

execResult, err = core.ApplyMessage(evm, msg, gp, true /* refunds */, gasBailout /*gasBailout*/)
}
if err != nil {
return nil, fmt.Errorf("first run for txIndex %d error: %w", txIndex, err)
}

chainRules := chainConfig.Rules(blockCtx.BlockNumber, blockCtx.Time)
traceResult.Output = libcommon.CopyBytes(execResult.ReturnData)
if traceTypeStateDiff {
initialIbs := state.New(cloneReader)
if !txFinalized {
if err = ibs.FinalizeTx(chainRules, sd); err != nil {
return nil, err
}
}

if sd != nil {
if err = sd.CompareStates(initialIbs, ibs); err != nil {
return nil, err
}
}

if err = ibs.CommitBlock(chainRules, cachedWriter); err != nil {
return nil, err
}
} else {
if !txFinalized {
if err = ibs.FinalizeTx(chainRules, noop); err != nil {
return nil, err
}
}
if err = ibs.CommitBlock(chainRules, cachedWriter); err != nil {
return nil, err
}
}
if !traceTypeTrace {
traceResult.Trace = []*ParityTrace{}
}

return traceResult, nil
}

// RawTransaction implements trace_rawTransaction.
func (api *TraceAPIImpl) RawTransaction(ctx context.Context, txHash libcommon.Hash, traceTypes []string) ([]interface{}, error) {
var stub []interface{}
Expand Down
Loading

0 comments on commit 19cb3b0

Please sign in to comment.