From 8d51b81989c1c05864c87bb46bb4567be92cec83 Mon Sep 17 00:00:00 2001
From: Darioush Jalali
Date: Tue, 1 Oct 2024 14:45:55 -0700
Subject: [PATCH] trie_prefetcher: alternate structure

---
 core/state/prefetcher_database.go  | 134 ++++++++
 core/state/statedb.go              |   9 +-
 core/state/trie_prefetcher.go      | 524 ++++++++---------------------
 core/state/trie_prefetcher_test.go |   9 +-
 4 files changed, 285 insertions(+), 391 deletions(-)
 create mode 100644 core/state/prefetcher_database.go

diff --git a/core/state/prefetcher_database.go b/core/state/prefetcher_database.go
new file mode 100644
index 0000000000..e4bac44e42
--- /dev/null
+++ b/core/state/prefetcher_database.go
@@ -0,0 +1,134 @@
+// (c) 2024, Ava Labs, Inc. All rights reserved.
+// See the file LICENSE for licensing terms.
+
+package state
+
+import (
+	"fmt"
+	"sync"
+
+	"github.com/ava-labs/coreth/core/types"
+	"github.com/ava-labs/coreth/utils"
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/log"
+)
+
+// ForPrefetchingOnly returns a new database that is only suitable for prefetching
+// operations. It will not be safe to use for any other operations.
+// Close must be called on the returned database when it is no longer needed
+// to wait on all spawned goroutines.
+func (*cachingDB) ForPrefetchingOnly(db Database, maxConcurrency int) Database {
+	return newPrefetcherDatabase(db, maxConcurrency)
+}
+
+type prefetcherDatabase struct {
+	Database
+
+	maxConcurrency int
+	workers        *utils.BoundedWorkers
+}
+
+func newPrefetcherDatabase(db Database, maxConcurrency int) *prefetcherDatabase {
+	return &prefetcherDatabase{
+		Database:       db,
+		maxConcurrency: maxConcurrency,
+		workers:        utils.NewBoundedWorkers(maxConcurrency),
+	}
+}
+
+func (p *prefetcherDatabase) OpenTrie(root common.Hash) (Trie, error) {
+	t, err := p.Database.OpenTrie(root)
+	return newPrefetcherTrie(p, t), err
+}
+
+func (p *prefetcherDatabase) OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash, trie Trie) (Trie, error) {
+	t, err := p.Database.OpenStorageTrie(stateRoot, address, root, trie)
+	return newPrefetcherTrie(p, t), err
+}
+
+func (p *prefetcherDatabase) CopyTrie(t Trie) Trie {
+	switch t := t.(type) {
+	case *prefetcherTrie:
+		return newPrefetcherTrie(p, t.getCopy())
+	default:
+		panic(fmt.Errorf("unknown trie type %T", t))
+	}
+}
+
+func (p *prefetcherDatabase) Close() {
+	p.workers.Wait()
+}
+
+type prefetcherTrie struct {
+	p *prefetcherDatabase
+
+	Trie
+	copyLock sync.Mutex
+
+	copies chan Trie
+	wg     sync.WaitGroup
+}
+
+func newPrefetcherTrie(p *prefetcherDatabase, t Trie) *prefetcherTrie {
+	prefetcher := &prefetcherTrie{
+		p:      p,
+		Trie:   t,
+		copies: make(chan Trie, p.maxConcurrency),
+	}
+	prefetcher.copies <- prefetcher.getCopy()
+	return prefetcher
+}
+
+func (p *prefetcherTrie) Wait() {
+	p.wg.Wait()
+}
+
+func (p *prefetcherTrie) getCopy() Trie {
+	select {
+	case copy := <-p.copies:
+		return copy
+	default:
+		p.copyLock.Lock()
+		defer p.copyLock.Unlock()
+		return p.p.Database.CopyTrie(p.Trie)
+	}
+}
+
+func (p *prefetcherTrie) putCopy(copy Trie) {
+	select {
+	case p.copies <- copy:
+	default:
+	}
+}
+
+func (p *prefetcherTrie) GetAccount(address common.Address) (*types.StateAccount, error) {
+	p.wg.Add(1)
+	f := func() {
+		defer p.wg.Done()
+
+		tr := p.getCopy()
+		_, err := tr.GetAccount(address)
+		if err != nil {
+			log.Error("GetAccount failed in prefetcher", "err", err)
+		}
+		p.putCopy(tr)
+	}
+	p.p.workers.Execute(f)
+	return nil, nil // Note this result is never used by the prefetcher
+}
+
+func (p *prefetcherTrie) GetStorage(address common.Address, key []byte) ([]byte, error) {
+	p.wg.Add(1)
+	f := func() {
+		defer p.wg.Done()
+
+		tr := p.getCopy()
+		_, err := tr.GetStorage(address, key)
+		if err != nil {
+			log.Error("GetStorage failed in prefetcher", "err", err)
+		}
+		p.putCopy(tr)
+	}
+	p.p.workers.Execute(f)
+	return nil, nil // Note this result is never used by the prefetcher
+}
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 9eda070321..70d9fa4fd8 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -210,7 +210,14 @@ func (s *StateDB) StartPrefetcher(namespace string, maxConcurrency int) {
 		s.prefetcher = nil
 	}
 	if s.snap != nil {
-		s.prefetcher = newTriePrefetcher(s.db, s.originalRoot, namespace, maxConcurrency)
+		db := s.db
+		type prefetchingDB interface {
+			ForPrefetchingOnly(db Database, maxConcurrency int) Database
+		}
+		if p, ok := db.(prefetchingDB); ok {
+			db = p.ForPrefetchingOnly(db, maxConcurrency)
+		}
+		s.prefetcher = newTriePrefetcher(db, s.originalRoot, namespace)
 	}
 }
 
diff --git a/core/state/trie_prefetcher.go b/core/state/trie_prefetcher.go
index 5b01083f59..bd491b5f67 100644
--- a/core/state/trie_prefetcher.go
+++ b/core/state/trie_prefetcher.go
@@ -28,16 +28,16 @@ package state
 
 import (
 	"sync"
-	"time"
 
-	"github.com/ava-labs/coreth/metrics"
-	"github.com/ava-labs/coreth/utils"
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
 )
 
-// triePrefetchMetricsPrefix is the prefix under which to publish the metrics.
-const triePrefetchMetricsPrefix = "trie/prefetch/"
+var (
+	// triePrefetchMetricsPrefix is the prefix under which to publish the metrics.
+	triePrefetchMetricsPrefix = "trie/prefetch/"
+)
 
 // triePrefetcher is an active prefetcher, which receives accounts or storage
 // items and does trie-loading of them. The goal is to get as much useful content
@@ -50,91 +50,65 @@ type triePrefetcher struct {
 	fetches  map[string]Trie // Partially or fully fetched tries. Only populated for inactive copies.
 	fetchers map[string]*subfetcher // Subfetchers for each trie
 
-	maxConcurrency int
-	workers        *utils.BoundedWorkers
-
-	subfetcherWorkersMeter metrics.Meter
-	subfetcherWaitTimer    metrics.Counter
-	subfetcherCopiesMeter  metrics.Meter
-
+	deliveryMissMeter metrics.Meter
 	accountLoadMeter  metrics.Meter
 	accountDupMeter   metrics.Meter
 	accountSkipMeter  metrics.Meter
 	accountWasteMeter metrics.Meter
-
-	storageFetchersMeter    metrics.Meter
-	storageLoadMeter        metrics.Meter
-	storageLargestLoadMeter metrics.Meter
-	storageDupMeter         metrics.Meter
-	storageSkipMeter        metrics.Meter
-	storageWasteMeter       metrics.Meter
+	storageLoadMeter  metrics.Meter
+	storageDupMeter   metrics.Meter
+	storageSkipMeter  metrics.Meter
+	storageWasteMeter metrics.Meter
 }
 
-func newTriePrefetcher(db Database, root common.Hash, namespace string, maxConcurrency int) *triePrefetcher {
+func newTriePrefetcher(db Database, root common.Hash, namespace string) *triePrefetcher {
 	prefix := triePrefetchMetricsPrefix + namespace
-	return &triePrefetcher{
+	p := &triePrefetcher{
 		db:       db,
 		root:     root,
 		fetchers: make(map[string]*subfetcher), // Active prefetchers use the fetchers map
-		maxConcurrency: maxConcurrency,
-		workers:        utils.NewBoundedWorkers(maxConcurrency), // Scale up as needed to [maxConcurrency]
-
-		subfetcherWorkersMeter: metrics.GetOrRegisterMeter(prefix+"/subfetcher/workers", nil),
-		subfetcherWaitTimer:    metrics.GetOrRegisterCounter(prefix+"/subfetcher/wait", nil),
-		subfetcherCopiesMeter:  metrics.GetOrRegisterMeter(prefix+"/subfetcher/copies", nil),
-
+		deliveryMissMeter: metrics.GetOrRegisterMeter(prefix+"/deliverymiss", nil),
 		accountLoadMeter:  metrics.GetOrRegisterMeter(prefix+"/account/load", nil),
 		accountDupMeter:   metrics.GetOrRegisterMeter(prefix+"/account/dup", nil),
 		accountSkipMeter:  metrics.GetOrRegisterMeter(prefix+"/account/skip", nil),
 		accountWasteMeter: metrics.GetOrRegisterMeter(prefix+"/account/waste", nil),
-
-		storageFetchersMeter:    metrics.GetOrRegisterMeter(prefix+"/storage/fetchers", nil),
-		storageLoadMeter:        metrics.GetOrRegisterMeter(prefix+"/storage/load", nil),
-		storageLargestLoadMeter: metrics.GetOrRegisterMeter(prefix+"/storage/lload", nil),
-		storageDupMeter:         metrics.GetOrRegisterMeter(prefix+"/storage/dup", nil),
-		storageSkipMeter:        metrics.GetOrRegisterMeter(prefix+"/storage/skip", nil),
-		storageWasteMeter:       metrics.GetOrRegisterMeter(prefix+"/storage/waste", nil),
+		storageLoadMeter:  metrics.GetOrRegisterMeter(prefix+"/storage/load", nil),
+		storageDupMeter:   metrics.GetOrRegisterMeter(prefix+"/storage/dup", nil),
+		storageSkipMeter:  metrics.GetOrRegisterMeter(prefix+"/storage/skip", nil),
+		storageWasteMeter: metrics.GetOrRegisterMeter(prefix+"/storage/waste", nil),
 	}
+	return p
 }
 
 // close iterates over all the subfetchers, aborts any that were left spinning
 // and reports the stats to the metrics subsystem.
 func (p *triePrefetcher) close() {
-	// If the prefetcher is an inactive one, bail out
-	if p.fetches != nil {
-		return
-	}
-
-	// Collect stats from all fetchers
-	var (
-		storageFetchers int64
-		largestLoad     int64
-	)
+	defer func() {
+		type closer interface {
+			Close()
+		}
+		if closer, ok := p.db.(closer); ok {
+			closer.Close()
+		}
+	}()
 	for _, fetcher := range p.fetchers {
-		fetcher.abort() // safe to call multiple times (should be a no-op on happy path)
+		fetcher.abort() // safe to do multiple times
 
 		if metrics.Enabled {
-			p.subfetcherCopiesMeter.Mark(int64(fetcher.copies()))
-
 			if fetcher.root == p.root {
 				p.accountLoadMeter.Mark(int64(len(fetcher.seen)))
 				p.accountDupMeter.Mark(int64(fetcher.dups))
-				p.accountSkipMeter.Mark(int64(fetcher.skips()))
+				p.accountSkipMeter.Mark(int64(len(fetcher.tasks)))
 
 				for _, key := range fetcher.used {
 					delete(fetcher.seen, string(key))
 				}
 				p.accountWasteMeter.Mark(int64(len(fetcher.seen)))
 			} else {
-				storageFetchers++
-				oseen := int64(len(fetcher.seen))
-				if oseen > largestLoad {
-					largestLoad = oseen
-				}
-				p.storageLoadMeter.Mark(oseen)
+				p.storageLoadMeter.Mark(int64(len(fetcher.seen)))
 				p.storageDupMeter.Mark(int64(fetcher.dups))
-				p.storageSkipMeter.Mark(int64(fetcher.skips()))
+				p.storageSkipMeter.Mark(int64(len(fetcher.tasks)))
 
 				for _, key := range fetcher.used {
 					delete(fetcher.seen, string(key))
@@ -143,20 +117,6 @@ func (p *triePrefetcher) close() {
 			}
 		}
 	}
-	if metrics.Enabled {
-		p.storageFetchersMeter.Mark(storageFetchers)
-		p.storageLargestLoadMeter.Mark(largestLoad)
-	}
-
-	// Stop all workers once fetchers are aborted (otherwise
-	// could stop while waiting)
-	//
-	// Record number of workers that were spawned during this run
-	workersUsed := int64(p.workers.Wait())
-	if metrics.Enabled {
-		p.subfetcherWorkersMeter.Mark(workersUsed)
-	}
-
 	// Clear out all fetchers (will crash on a second call, deliberate)
 	p.fetchers = nil
 }
@@ -169,23 +129,17 @@ func (p *triePrefetcher) copy() *triePrefetcher {
 	copy := &triePrefetcher{
 		db:      p.db,
 		root:    p.root,
-		fetches: make(map[string]Trie), // Active prefetchers use the fetchers map
-
-		subfetcherWorkersMeter: p.subfetcherWorkersMeter,
-		subfetcherWaitTimer:    p.subfetcherWaitTimer,
-		subfetcherCopiesMeter:  p.subfetcherCopiesMeter,
+		fetches: make(map[string]Trie), // Active prefetchers use the fetches map
 
+		deliveryMissMeter: p.deliveryMissMeter,
 		accountLoadMeter:  p.accountLoadMeter,
 		accountDupMeter:   p.accountDupMeter,
 		accountSkipMeter:  p.accountSkipMeter,
 		accountWasteMeter: p.accountWasteMeter,
-
-		storageFetchersMeter:    p.storageFetchersMeter,
-		storageLoadMeter:        p.storageLoadMeter,
-		storageLargestLoadMeter: p.storageLargestLoadMeter,
-		storageDupMeter:         p.storageDupMeter,
-		storageSkipMeter:        p.storageSkipMeter,
-		storageWasteMeter:       p.storageWasteMeter,
+		storageLoadMeter:  p.storageLoadMeter,
+		storageDupMeter:   p.storageDupMeter,
+		storageSkipMeter:  p.storageSkipMeter,
+		storageWasteMeter: p.storageWasteMeter,
 	}
 	// If the prefetcher is already a copy, duplicate the data
 	if p.fetches != nil {
@@ -210,12 +164,11 @@ func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, addr comm
 	if p.fetches != nil {
 		return
 	}
-
 	// Active fetcher, schedule the retrievals
 	id := p.trieID(owner, root)
 	fetcher := p.fetchers[id]
 	if fetcher == nil {
-		fetcher = newSubfetcher(p, owner, root, addr)
+		fetcher = newSubfetcher(p.db, p.root, owner, root, addr)
 		p.fetchers[id] = fetcher
 	}
 	fetcher.schedule(keys)
@@ -229,27 +182,24 @@ func (p *triePrefetcher) trie(owner common.Hash, root common.Hash) Trie {
 	if p.fetches != nil {
 		trie := p.fetches[id]
 		if trie == nil {
+			p.deliveryMissMeter.Mark(1)
 			return nil
 		}
 		return p.db.CopyTrie(trie)
 	}
-
 	// Otherwise the prefetcher is active, bail if no trie was prefetched for this root
 	fetcher := p.fetchers[id]
 	if fetcher == nil {
+		p.deliveryMissMeter.Mark(1)
 		return nil
 	}
+	// Interrupt the prefetcher if it's by any chance still running and return
+	// a copy of any pre-loaded trie.
+	fetcher.abort() // safe to do multiple times
 
-	// Wait for the fetcher to finish and shutdown orchestrator, if it exists
-	start := time.Now()
-	fetcher.wait()
-	if metrics.Enabled {
-		p.subfetcherWaitTimer.Inc(time.Since(start).Milliseconds())
-	}
-
-	// Return a copy of one of the prefetched tries
 	trie := fetcher.peek()
 	if trie == nil {
+		p.deliveryMissMeter.Mark(1)
 		return nil
 	}
 	return trie
@@ -276,15 +226,20 @@ func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string {
 // main prefetcher is paused and either all requested items are processed or if
 // the trie being worked on is retrieved from the prefetcher.
 type subfetcher struct {
-	p *triePrefetcher
-
 	db    Database       // Database to load trie nodes through
 	state common.Hash    // Root hash of the state to prefetch
 	owner common.Hash    // Owner of the trie, usually account hash
 	root  common.Hash    // Root hash of the trie to prefetch
 	addr  common.Address // Address of the account that the trie belongs to
+	trie  Trie           // Trie being populated with nodes
 
-	to *trieOrchestrator // Orchestrate concurrent fetching of a single trie
+	tasks [][]byte   // Items queued up for retrieval
+	lock  sync.Mutex // Lock protecting the task queue
+
+	wake chan struct{}  // Wake channel if a new task is scheduled
+	stop chan struct{}  // Channel to interrupt processing
+	term chan struct{}  // Channel to signal interruption
+	copy chan chan Trie // Channel to request a copy of the current trie
 
 	seen map[string]struct{} // Tracks the entries already loaded
 	dups int                 // Number of duplicate preload tasks
@@ -293,348 +248,143 @@ type subfetcher struct {
 
 // newSubfetcher creates a goroutine to prefetch state items belonging to a
 // particular root hash.
-func newSubfetcher(p *triePrefetcher, owner common.Hash, root common.Hash, addr common.Address) *subfetcher {
+func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash, addr common.Address) *subfetcher {
 	sf := &subfetcher{
-		p:     p,
-		db:    p.db,
-		state: p.root,
+		db:    db,
+		state: state,
 		owner: owner,
 		root:  root,
 		addr:  addr,
+		wake:  make(chan struct{}, 1),
+		stop:  make(chan struct{}),
+		term:  make(chan struct{}),
+		copy:  make(chan chan Trie),
 		seen:  make(map[string]struct{}),
 	}
-	sf.to = newTrieOrchestrator(sf)
-	if sf.to != nil {
-		go sf.to.processTasks()
-	}
-	// We return [sf] here to ensure we don't try to re-create if
-	// we aren't able to setup a [newTrieOrchestrator] the first time.
+	go sf.loop()
 	return sf
 }
 
 // schedule adds a batch of trie keys to the queue to prefetch.
-// This should never block, so an array is used instead of a channel.
-//
-// This is not thread-safe.
 func (sf *subfetcher) schedule(keys [][]byte) {
 	// Append the tasks to the current queue
-	tasks := make([][]byte, 0, len(keys))
-	for _, key := range keys {
-		// Check if keys already seen
-		sk := string(key)
-		if _, ok := sf.seen[sk]; ok {
-			sf.dups++
-			continue
-		}
-		sf.seen[sk] = struct{}{}
-		tasks = append(tasks, key)
-	}
+	sf.lock.Lock()
+	sf.tasks = append(sf.tasks, keys...)
+	sf.lock.Unlock()
 
-	// After counting keys, exit if they can't be prefetched
-	if sf.to == nil {
-		return
+	// Notify the prefetcher, it's fine if it's already terminated
+	select {
+	case sf.wake <- struct{}{}:
+	default:
 	}
-
-	// Add tasks to queue for prefetching
-	sf.to.enqueueTasks(tasks)
 }
 
 // peek tries to retrieve a deep copy of the fetcher's trie in whatever form it
 // is currently.
 func (sf *subfetcher) peek() Trie {
-	if sf.to == nil {
-		return nil
-	}
-	return sf.to.copyBase()
-}
+	ch := make(chan Trie)
+	select {
+	case sf.copy <- ch:
+		// Subfetcher still alive, return copy from it
+		return <-ch
 
-// wait must only be called if [triePrefetcher] has not been closed. If this happens,
-// workers will not finish.
-func (sf *subfetcher) wait() {
-	if sf.to == nil {
-		// Unable to open trie
-		return
+	case <-sf.term:
+		// Subfetcher already terminated, return a copy directly
+		if sf.trie == nil {
+			return nil
+		}
+		return sf.db.CopyTrie(sf.trie)
 	}
-	sf.to.wait()
 }
 
+// abort interrupts the subfetcher immediately. It is safe to call abort multiple
+// times but it is not thread safe.
 func (sf *subfetcher) abort() {
-	if sf.to == nil {
-		// Unable to open trie
-		return
+	select {
+	case <-sf.stop:
+	default:
+		close(sf.stop)
 	}
-	sf.to.abort()
-}
+	<-sf.term
 
-func (sf *subfetcher) skips() int {
-	if sf.to == nil {
-		// Unable to open trie
-		return 0
+	type waiter interface {
+		Wait()
 	}
-	return sf.to.skipCount()
-}
-
-func (sf *subfetcher) copies() int {
-	if sf.to == nil {
-		// Unable to open trie
-		return 0
+	if trie, ok := sf.trie.(waiter); ok {
+		trie.Wait()
 	}
-	return sf.to.copies
 }
 
-// trieOrchestrator is not thread-safe.
-type trieOrchestrator struct {
-	sf *subfetcher
-
-	// base is an unmodified Trie we keep for
-	// creating copies for each worker goroutine.
-	//
-	// We care more about quick copies than good copies
-	// because most (if not all) of the nodes that will be populated
-	// in the copy will come from the underlying triedb cache. Ones
-	// that don't come from this cache probably had to be fetched
-	// from disk anyways.
-	base     Trie
-	baseLock sync.Mutex
-
-	tasksAllowed bool
-	skips        int // number of tasks skipped
-	pendingTasks [][]byte
-	taskLock     sync.Mutex
-
-	processingTasks sync.WaitGroup
-
-	wake     chan struct{}
-	stop     chan struct{}
-	stopOnce sync.Once
-	loopTerm chan struct{}
-
-	copies      int
-	copyChan    chan Trie
-	copySpawner chan struct{}
-}
+// loop waits for new tasks to be scheduled and keeps loading them until it runs
+// out of tasks or its underlying trie is retrieved for committing.
+func (sf *subfetcher) loop() {
+	// No matter how the loop stops, signal anyone waiting that it's terminated
+	defer close(sf.term)
 
-func newTrieOrchestrator(sf *subfetcher) *trieOrchestrator {
 	// Start by opening the trie and stop processing if it fails
-	var (
-		base Trie
-		err  error
-	)
 	if sf.owner == (common.Hash{}) {
-		base, err = sf.db.OpenTrie(sf.root)
+		trie, err := sf.db.OpenTrie(sf.root)
 		if err != nil {
 			log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
-			return nil
+			return
 		}
+		sf.trie = trie
 	} else {
 		// The trie argument can be nil as verkle doesn't support prefetching
 		// yet. TODO FIX IT(rjl493456442), otherwise code will panic here.
-		base, err = sf.db.OpenStorageTrie(sf.state, sf.addr, sf.root, nil)
+		trie, err := sf.db.OpenStorageTrie(sf.state, sf.addr, sf.root, nil)
 		if err != nil {
 			log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
-			return nil
+			return
 		}
+		sf.trie = trie
 	}
-
-	// Instantiate trieOrchestrator
-	to := &trieOrchestrator{
-		sf:   sf,
-		base: base,
-
-		tasksAllowed: true,
-		wake:         make(chan struct{}, 1),
-		stop:         make(chan struct{}),
-		loopTerm:     make(chan struct{}),
-
-		copyChan:    make(chan Trie, sf.p.maxConcurrency),
-		copySpawner: make(chan struct{}, sf.p.maxConcurrency),
-	}
-
-	// Create initial trie copy
-	to.copies++
-	to.copySpawner <- struct{}{}
-	to.copyChan <- to.copyBase()
-	return to
-}
-
-func (to *trieOrchestrator) copyBase() Trie {
-	to.baseLock.Lock()
-	defer to.baseLock.Unlock()
-
-	return to.sf.db.CopyTrie(to.base)
-}
-
-func (to *trieOrchestrator) skipCount() int {
-	to.taskLock.Lock()
-	defer to.taskLock.Unlock()
-
-	return to.skips
-}
-
-func (to *trieOrchestrator) enqueueTasks(tasks [][]byte) {
-	to.taskLock.Lock()
-	defer to.taskLock.Unlock()
-
-	if len(tasks) == 0 {
-		return
-	}
-
-	// Add tasks to [pendingTasks]
-	if !to.tasksAllowed {
-		to.skips += len(tasks)
-		return
-	}
-	to.processingTasks.Add(len(tasks))
-	to.pendingTasks = append(to.pendingTasks, tasks...)
-
-	// Wake up processor
-	select {
-	case to.wake <- struct{}{}:
-	default:
-	}
-}
-
-func (to *trieOrchestrator) handleStop(remaining int) {
-	to.taskLock.Lock()
-	to.skips += remaining
-	to.taskLock.Unlock()
-	to.processingTasks.Add(-remaining)
-}
-
-func (to *trieOrchestrator) processTasks() {
-	defer close(to.loopTerm)
-
+	// Trie opened successfully, keep prefetching items
 	for {
-		// Determine if we should process or exit
 		select {
-		case <-to.wake:
-		case <-to.stop:
-			return
-		}
-
-		// Get current tasks
-		to.taskLock.Lock()
-		tasks := to.pendingTasks
-		to.pendingTasks = nil
-		to.taskLock.Unlock()
-
-		// Enqueue more work as soon as trie copies are available
-		lt := len(tasks)
-		for i := 0; i < lt; i++ {
-			// Try to stop as soon as possible, if channel is closed
-			remaining := lt - i
-			select {
-			case <-to.stop:
-				to.handleStop(remaining)
-				return
-			default:
-			}
-
-			// Try to create to get an active copy first (select is non-deterministic,
-			// so we may end up creating a new copy when we don't need to)
-			var t Trie
-			select {
-			case t = <-to.copyChan:
-			default:
-				// Wait for an available copy or create one, if we weren't
-				// able to get a previously created copy
+		case <-sf.wake:
+			// Subfetcher was woken up, retrieve any tasks to avoid spinning the lock
+			sf.lock.Lock()
+			tasks := sf.tasks
+			sf.tasks = nil
+			sf.lock.Unlock()
+
+			// Prefetch any tasks until the loop is interrupted
+			for i, task := range tasks {
 				select {
-				case <-to.stop:
-					to.handleStop(remaining)
+				case <-sf.stop:
+					// If termination is requested, add any leftover back and return
+					sf.lock.Lock()
+					sf.tasks = append(sf.tasks, tasks[i:]...)
+					sf.lock.Unlock()
 					return
-				case t = <-to.copyChan:
-				case to.copySpawner <- struct{}{}:
-					to.copies++
-					t = to.copyBase()
-				}
-			}
-			// Enqueue work, unless stopped.
-			fTask := tasks[i]
-			f := func() {
-				// Perform task
-				var err error
-				if len(fTask) == common.AddressLength {
-					_, err = t.GetAccount(common.BytesToAddress(fTask))
-				} else {
-					_, err = t.GetStorage(to.sf.addr, fTask)
+				case ch := <-sf.copy:
+					// Somebody wants a copy of the current trie, grant them
+					ch <- sf.db.CopyTrie(sf.trie)
+
+				default:
+					// No termination request yet, prefetch the next entry
+					if _, ok := sf.seen[string(task)]; ok {
+						sf.dups++
+					} else {
+						if len(task) == common.AddressLength {
+							sf.trie.GetAccount(common.BytesToAddress(task))
+						} else {
+							sf.trie.GetStorage(sf.addr, task)
+						}
+						sf.seen[string(task)] = struct{}{}
+					}
 				}
-				if err != nil {
-					log.Error("Trie prefetcher failed fetching", "root", to.sf.root, "err", err)
-				}
-				to.processingTasks.Done()
-
-				// Return copy when we are done with it, so someone else can use it
-				//
-				// channel is buffered and will not block
-				to.copyChan <- t
 			}
-
-			// Enqueue task for processing (may spawn new goroutine
-			// if not at [maxConcurrency])
-			//
-			// If workers are stopped before calling [Execute], this function may
-			// panic.
-			to.sf.p.workers.Execute(f)
-		}
-	}
-}
-
-func (to *trieOrchestrator) stopAcceptingTasks() {
-	to.taskLock.Lock()
-	defer to.taskLock.Unlock()
+		case ch := <-sf.copy:
+			// Somebody wants a copy of the current trie, grant them
+			ch <- sf.db.CopyTrie(sf.trie)
 
-	if !to.tasksAllowed {
-		return
+		case <-sf.stop:
+			// Termination is requested, abort and leave remaining tasks
+			return
+		}
 	}
-	to.tasksAllowed = false
-
-	// We don't clear [to.pendingTasks] here because
-	// it will be faster to prefetch them even though we
-	// are still waiting.
-}
-
-// wait stops accepting new tasks and waits for ongoing tasks to complete. If
-// wait is called, it is not necessary to call [abort].
-//
-// It is safe to call wait multiple times.
-func (to *trieOrchestrator) wait() {
-	// Prevent more tasks from being enqueued
-	to.stopAcceptingTasks()
-
-	// Wait for processing tasks to complete
-	to.processingTasks.Wait()
-
-	// Stop orchestrator loop
-	to.stopOnce.Do(func() {
-		close(to.stop)
-	})
-	<-to.loopTerm
-}
-
-// abort stops any ongoing tasks and shuts down the orchestrator loop. If abort
-// is called, it is not necessary to call [wait].
-//
-// It is safe to call abort multiple times.
-func (to *trieOrchestrator) abort() {
-	// Prevent more tasks from being enqueued
-	to.stopAcceptingTasks()
-
-	// Stop orchestrator loop
-	to.stopOnce.Do(func() {
-		close(to.stop)
-	})
-	<-to.loopTerm
-
-	// Capture any dangling pending tasks (processTasks
-	// may exit before enqueing all pendingTasks)
-	to.taskLock.Lock()
-	pendingCount := len(to.pendingTasks)
-	to.skips += pendingCount
-	to.pendingTasks = nil
-	to.taskLock.Unlock()
-	to.processingTasks.Add(-pendingCount)
-
-	// Wait for processing tasks to complete
-	to.processingTasks.Wait()
-}
diff --git a/core/state/trie_prefetcher_test.go b/core/state/trie_prefetcher_test.go
index 4f70b49179..fc5301cc24 100644
--- a/core/state/trie_prefetcher_test.go
+++ b/core/state/trie_prefetcher_test.go
@@ -59,7 +59,8 @@ func filledStateDB() *StateDB {
 
 func TestCopyAndClose(t *testing.T) {
 	db := filledStateDB()
-	prefetcher := newTriePrefetcher(db.db, db.originalRoot, "", maxConcurrency)
+	prefetchDb := newPrefetcherDatabase(db.db, maxConcurrency)
+	prefetcher := newTriePrefetcher(prefetchDb, db.originalRoot, "")
 	skey := common.HexToHash("aaa")
 	prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
 	prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
@@ -84,7 +85,8 @@ func TestCopyAndClose(t *testing.T) {
 
 func TestUseAfterClose(t *testing.T) {
 	db := filledStateDB()
-	prefetcher := newTriePrefetcher(db.db, db.originalRoot, "", maxConcurrency)
+	prefetchDb := newPrefetcherDatabase(db.db, maxConcurrency)
+	prefetcher := newTriePrefetcher(prefetchDb, db.originalRoot, "")
 	skey := common.HexToHash("aaa")
 	prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
 	a := prefetcher.trie(common.Hash{}, db.originalRoot)
@@ -100,7 +102,8 @@ func TestUseAfterClose(t *testing.T) {
 
 func TestCopyClose(t *testing.T) {
 	db := filledStateDB()
-	prefetcher := newTriePrefetcher(db.db, db.originalRoot, "", maxConcurrency)
+	prefetchDb := newPrefetcherDatabase(db.db, maxConcurrency)
+	prefetcher := newTriePrefetcher(prefetchDb, db.originalRoot, "")
 	skey := common.HexToHash("aaa")
 	prefetcher.prefetch(common.Hash{}, db.originalRoot, common.Address{}, [][]byte{skey.Bytes()})
 	cpy := prefetcher.copy()