From a7b3ab4e1ea5e0d47f4a79b69f8f948cf16e90af Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Fri, 11 Oct 2024 22:23:48 +0200 Subject: [PATCH] Address code review comments Signed-off-by: Yacov Manevich --- snow/engine/common/timer.go | 68 +++++++++---------- snow/engine/common/timer_test.go | 53 +++++++-------- snow/engine/snowman/bootstrap/bootstrapper.go | 2 +- 3 files changed, 59 insertions(+), 64 deletions(-) diff --git a/snow/engine/common/timer.go b/snow/engine/common/timer.go index 0ca23bb9a89..440434a3444 100644 --- a/snow/engine/common/timer.go +++ b/snow/engine/common/timer.go @@ -8,7 +8,7 @@ import ( "time" ) -// PreemptionSignal signals when to preempt the pendingTimeout of the timeout handler. +// PreemptionSignal signals when to preempt the pendingTimeoutToken of the timeout handler. type PreemptionSignal struct { activateOnce sync.Once initOnce sync.Once @@ -37,38 +37,39 @@ func (ps *PreemptionSignal) Preempt() { // Only a single timeout can be pending to be scheduled at any given time. // Once a preemption signal is closed, all timeouts are immediately dispatched. type timeoutScheduler struct { - newTimer func(duration time.Duration) *time.Timer - onTimeout func() - preemptionSignal <-chan struct{} - pendingTimeout chan struct{} + newTimer func(duration time.Duration) *time.Timer + onTimeout func() + preemptionSignal <-chan struct{} + pendingTimeoutToken chan struct{} } // NewTimeoutScheduler constructs a new timeout scheduler with the given function to be invoked upon a timeout, // unless the preemptionSignal is closed and in which case it invokes the function immediately. -func NewTimeoutScheduler(onTimeout func(), preemptionSignal <-chan struct{}, newTimer func(duration time.Duration) *time.Timer) *timeoutScheduler { +func NewTimeoutScheduler(onTimeout func(), preemptionSignal <-chan struct{}) *timeoutScheduler { pendingTimout := make(chan struct{}, 1) pendingTimout <- struct{}{} return &timeoutScheduler{ - preemptionSignal: preemptionSignal, - newTimer: newTimer, - onTimeout: onTimeout, - pendingTimeout: pendingTimout, + preemptionSignal: preemptionSignal, + newTimer: time.NewTimer, + onTimeout: onTimeout, + pendingTimeoutToken: pendingTimout, } } // RegisterTimeout fires the function the timeout scheduler is initialized with no later than the given timeout. func (th *timeoutScheduler) RegisterTimeout(d time.Duration) { - acquiredToken := th.acquirePendingTimeoutToken() - preempted := th.preempted() - - if !preempted && !acquiredToken { + // There can only be a single timeout pending at any time, and once a timeout is scheduled, + // we prevent future timeouts to be scheduled until the timeout triggers by taking the pendingTimeoutToken. + // Any subsequent attempt to register a timeout would fail obtaining the pendingTimeoutToken, + // and return. + if !th.acquirePendingTimeoutToken() { return } - go th.scheduleTimeout(d, acquiredToken) + go th.scheduleTimeout(d) } -func (th *timeoutScheduler) scheduleTimeout(d time.Duration, acquiredToken bool) { +func (th *timeoutScheduler) scheduleTimeout(d time.Duration) { timer := th.newTimer(d) defer timer.Stop() @@ -79,37 +80,32 @@ func (th *timeoutScheduler) scheduleTimeout(d time.Duration, acquiredToken bool) case <-th.preemptionSignal: } - if acquiredToken { - th.relinquishPendingTimeoutToken() - } -} - -func (th *timeoutScheduler) preempted() bool { - select { - case <-th.preemptionSignal: - return true - default: - return false - } + // Relinquish the pendingTimeoutToken. + // This is needed to be done before onTimeout() is invoked, + // and that's why onTimeout() is deferred to be called at the end of the function. + // If we trigger the timeout prematurely before we relinquish the pendingTimeoutToken, + // A subsequent timeout scheduling attempt that originates from the triggering of the current timeout + // will fail, as the pendingTimeoutToken is not yet available. + th.pendingTimeoutToken <- struct{}{} } func (th *timeoutScheduler) acquirePendingTimeoutToken() bool { select { - case <-th.pendingTimeout: + case <-th.pendingTimeoutToken: return true default: return false } } -func (th *timeoutScheduler) relinquishPendingTimeoutToken() { - th.pendingTimeout <- struct{}{} -} - // TimeoutRegistrar describes the standard interface for specifying a timeout type TimeoutRegistrar interface { - // RegisterTimeout specifies how much time to delay the next timeout message - // by. If the subnet has been bootstrapped, the timeout will fire - // immediately via calling Preempt(). + // RegisterTimeout specifies how much time to delay the next timeout message by. + // + // If there is already a pending timeout message, this call is a no-op. + // However, it is guaranteed that the timeout will fire at least once after + // calling this function. + // + // If the subnet has been bootstrapped, the timeout will fire immediately via calling Preempt(). RegisterTimeout(time.Duration) } diff --git a/snow/engine/common/timer_test.go b/snow/engine/common/timer_test.go index 67f169bfaa1..cee88276d5d 100644 --- a/snow/engine/common/timer_test.go +++ b/snow/engine/common/timer_test.go @@ -19,15 +19,15 @@ func TestTimeoutScheduler(t *testing.T) { advanceTime func(chan time.Time) }{ { - desc: "multiple pendingTimeout one after the other with preemption", + desc: "multiple pendingTimeoutToken one after the other with preemption", expectedInvocationCount: 10, shouldPreempt: true, clock: make(chan time.Time, 1), - initClock: func(_ chan time.Time) {}, - advanceTime: func(_ chan time.Time) {}, + initClock: func(chan time.Time) {}, + advanceTime: func(chan time.Time) {}, }, { - desc: "multiple pendingTimeout one after the other", + desc: "multiple pendingTimeoutToken one after the other", expectedInvocationCount: 10, clock: make(chan time.Time, 1), initClock: func(clock chan time.Time) { @@ -57,14 +57,7 @@ func TestTimeoutScheduler(t *testing.T) { // in order to make the tests deterministic. order := make(chan struct{}) - newTimer := func(_ time.Duration) *time.Timer { - // We use a duration of 0 to not leave a lingering timer - // after the test finishes. - // Then we replace the time channel to have control over the timer. - timer := time.NewTimer(0) - timer.C = testCase.clock - return timer - } + newTimer := makeMockedTimer(testCase.clock) onTimeout := func() { order <- struct{}{} @@ -72,7 +65,8 @@ func TestTimeoutScheduler(t *testing.T) { testCase.advanceTime(testCase.clock) } - ts := NewTimeoutScheduler(onTimeout, ps, newTimer) + ts := NewTimeoutScheduler(onTimeout, ps) + ts.newTimer = newTimer for i := 0; i < testCase.expectedInvocationCount; i++ { ts.RegisterTimeout(time.Hour) @@ -85,26 +79,19 @@ func TestTimeoutScheduler(t *testing.T) { } func TestTimeoutSchedulerConcurrentRegister(_ *testing.T) { + // Not enough invocations means the test would stall. + // Too many invocations means a negative counter panic. + clock := make(chan time.Time, 2) - newTimer := func(_ time.Duration) *time.Timer { - // We use a duration of 0 to not leave a lingering timer - // after the test finishes. - // Then we replace the time channel to have control over the timer. - timer := time.NewTimer(0) - timer.C = clock - return timer - } + newTimer := makeMockedTimer(clock) var wg sync.WaitGroup wg.Add(1) - onTimeout := func() { - wg.Done() - } - - roChan := make(<-chan struct{}) + preemptChan := make(<-chan struct{}) - ts := NewTimeoutScheduler(onTimeout, roChan, newTimer) + ts := NewTimeoutScheduler(wg.Done, preemptChan) + ts.newTimer = newTimer ts.RegisterTimeout(time.Hour) // First timeout is registered ts.RegisterTimeout(time.Hour) // Second should not @@ -115,3 +102,15 @@ func TestTimeoutSchedulerConcurrentRegister(_ *testing.T) { wg.Wait() } + +func makeMockedTimer(clock chan time.Time) func(_ time.Duration) *time.Timer { + newTimer := func(_ time.Duration) *time.Timer { + // We use a duration of 0 to not leave a lingering timer + // after the test finishes. + // Then we replace the time channel to have control over the timer. + timer := time.NewTimer(0) + timer.C = clock + return timer + } + return newTimer +} diff --git a/snow/engine/snowman/bootstrap/bootstrapper.go b/snow/engine/snowman/bootstrap/bootstrapper.go index 82312fab3d2..024925c7928 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper.go +++ b/snow/engine/snowman/bootstrap/bootstrapper.go @@ -149,7 +149,7 @@ func New(config Config, onFinished func(ctx context.Context, lastReqID uint32) e bs.Config.Ctx.Log.Warn("Encountered error during bootstrapping: %w", zap.Error(err)) } } - bs.TimeoutRegistrar = common.NewTimeoutScheduler(timeout, config.BootstrapTracker.AllBootstrapped(), time.NewTimer) + bs.TimeoutRegistrar = common.NewTimeoutScheduler(timeout, config.BootstrapTracker.AllBootstrapped()) return bs, err }