Skip to content

Commit

Permalink
Address code review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Yacov Manevich <[email protected]>
  • Loading branch information
yacovm committed Oct 11, 2024
1 parent bdf0063 commit a7b3ab4
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 64 deletions.
68 changes: 32 additions & 36 deletions snow/engine/common/timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"time"
)

// PreemptionSignal signals when to preempt the pendingTimeout of the timeout handler.
// PreemptionSignal signals when to preempt the pendingTimeoutToken of the timeout handler.
type PreemptionSignal struct {
activateOnce sync.Once
initOnce sync.Once
Expand Down Expand Up @@ -37,38 +37,39 @@ func (ps *PreemptionSignal) Preempt() {
// Only a single timeout can be pending to be scheduled at any given time.
// Once a preemption signal is closed, all timeouts are immediately dispatched.
type timeoutScheduler struct {
newTimer func(duration time.Duration) *time.Timer
onTimeout func()
preemptionSignal <-chan struct{}
pendingTimeout chan struct{}
newTimer func(duration time.Duration) *time.Timer
onTimeout func()
preemptionSignal <-chan struct{}
pendingTimeoutToken chan struct{}
}

// NewTimeoutScheduler constructs a new timeout scheduler with the given function to be invoked upon a timeout,
// unless the preemptionSignal is closed and in which case it invokes the function immediately.
func NewTimeoutScheduler(onTimeout func(), preemptionSignal <-chan struct{}, newTimer func(duration time.Duration) *time.Timer) *timeoutScheduler {
func NewTimeoutScheduler(onTimeout func(), preemptionSignal <-chan struct{}) *timeoutScheduler {
pendingTimout := make(chan struct{}, 1)
pendingTimout <- struct{}{}
return &timeoutScheduler{
preemptionSignal: preemptionSignal,
newTimer: newTimer,
onTimeout: onTimeout,
pendingTimeout: pendingTimout,
preemptionSignal: preemptionSignal,
newTimer: time.NewTimer,
onTimeout: onTimeout,
pendingTimeoutToken: pendingTimout,
}
}

// RegisterTimeout fires the function the timeout scheduler is initialized with no later than the given timeout.
func (th *timeoutScheduler) RegisterTimeout(d time.Duration) {
acquiredToken := th.acquirePendingTimeoutToken()
preempted := th.preempted()

if !preempted && !acquiredToken {
// There can only be a single timeout pending at any time, and once a timeout is scheduled,
// we prevent future timeouts to be scheduled until the timeout triggers by taking the pendingTimeoutToken.
// Any subsequent attempt to register a timeout would fail obtaining the pendingTimeoutToken,
// and return.
if !th.acquirePendingTimeoutToken() {
return
}

go th.scheduleTimeout(d, acquiredToken)
go th.scheduleTimeout(d)
}

func (th *timeoutScheduler) scheduleTimeout(d time.Duration, acquiredToken bool) {
func (th *timeoutScheduler) scheduleTimeout(d time.Duration) {
timer := th.newTimer(d)
defer timer.Stop()

Expand All @@ -79,37 +80,32 @@ func (th *timeoutScheduler) scheduleTimeout(d time.Duration, acquiredToken bool)
case <-th.preemptionSignal:
}

if acquiredToken {
th.relinquishPendingTimeoutToken()
}
}

func (th *timeoutScheduler) preempted() bool {
select {
case <-th.preemptionSignal:
return true
default:
return false
}
// Relinquish the pendingTimeoutToken.
// This is needed to be done before onTimeout() is invoked,
// and that's why onTimeout() is deferred to be called at the end of the function.
// If we trigger the timeout prematurely before we relinquish the pendingTimeoutToken,
// A subsequent timeout scheduling attempt that originates from the triggering of the current timeout
// will fail, as the pendingTimeoutToken is not yet available.
th.pendingTimeoutToken <- struct{}{}
}

func (th *timeoutScheduler) acquirePendingTimeoutToken() bool {
select {
case <-th.pendingTimeout:
case <-th.pendingTimeoutToken:
return true
default:
return false
}
}

func (th *timeoutScheduler) relinquishPendingTimeoutToken() {
th.pendingTimeout <- struct{}{}
}

// TimeoutRegistrar describes the standard interface for specifying a timeout
type TimeoutRegistrar interface {
// RegisterTimeout specifies how much time to delay the next timeout message
// by. If the subnet has been bootstrapped, the timeout will fire
// immediately via calling Preempt().
// RegisterTimeout specifies how much time to delay the next timeout message by.
//
// If there is already a pending timeout message, this call is a no-op.
// However, it is guaranteed that the timeout will fire at least once after
// calling this function.
//
// If the subnet has been bootstrapped, the timeout will fire immediately via calling Preempt().
RegisterTimeout(time.Duration)
}
53 changes: 26 additions & 27 deletions snow/engine/common/timer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ func TestTimeoutScheduler(t *testing.T) {
advanceTime func(chan time.Time)
}{
{
desc: "multiple pendingTimeout one after the other with preemption",
desc: "multiple pendingTimeoutToken one after the other with preemption",
expectedInvocationCount: 10,
shouldPreempt: true,
clock: make(chan time.Time, 1),
initClock: func(_ chan time.Time) {},
advanceTime: func(_ chan time.Time) {},
initClock: func(chan time.Time) {},
advanceTime: func(chan time.Time) {},
},
{
desc: "multiple pendingTimeout one after the other",
desc: "multiple pendingTimeoutToken one after the other",
expectedInvocationCount: 10,
clock: make(chan time.Time, 1),
initClock: func(clock chan time.Time) {
Expand Down Expand Up @@ -57,22 +57,16 @@ func TestTimeoutScheduler(t *testing.T) {
// in order to make the tests deterministic.
order := make(chan struct{})

newTimer := func(_ time.Duration) *time.Timer {
// We use a duration of 0 to not leave a lingering timer
// after the test finishes.
// Then we replace the time channel to have control over the timer.
timer := time.NewTimer(0)
timer.C = testCase.clock
return timer
}
newTimer := makeMockedTimer(testCase.clock)

onTimeout := func() {
order <- struct{}{}
wg.Done()
testCase.advanceTime(testCase.clock)
}

ts := NewTimeoutScheduler(onTimeout, ps, newTimer)
ts := NewTimeoutScheduler(onTimeout, ps)
ts.newTimer = newTimer

for i := 0; i < testCase.expectedInvocationCount; i++ {
ts.RegisterTimeout(time.Hour)
Expand All @@ -85,26 +79,19 @@ func TestTimeoutScheduler(t *testing.T) {
}

func TestTimeoutSchedulerConcurrentRegister(_ *testing.T) {
// Not enough invocations means the test would stall.
// Too many invocations means a negative counter panic.

clock := make(chan time.Time, 2)
newTimer := func(_ time.Duration) *time.Timer {
// We use a duration of 0 to not leave a lingering timer
// after the test finishes.
// Then we replace the time channel to have control over the timer.
timer := time.NewTimer(0)
timer.C = clock
return timer
}
newTimer := makeMockedTimer(clock)

var wg sync.WaitGroup
wg.Add(1)

onTimeout := func() {
wg.Done()
}

roChan := make(<-chan struct{})
preemptChan := make(<-chan struct{})

ts := NewTimeoutScheduler(onTimeout, roChan, newTimer)
ts := NewTimeoutScheduler(wg.Done, preemptChan)
ts.newTimer = newTimer

ts.RegisterTimeout(time.Hour) // First timeout is registered
ts.RegisterTimeout(time.Hour) // Second should not
Expand All @@ -115,3 +102,15 @@ func TestTimeoutSchedulerConcurrentRegister(_ *testing.T) {

wg.Wait()
}

func makeMockedTimer(clock chan time.Time) func(_ time.Duration) *time.Timer {
newTimer := func(_ time.Duration) *time.Timer {
// We use a duration of 0 to not leave a lingering timer
// after the test finishes.
// Then we replace the time channel to have control over the timer.
timer := time.NewTimer(0)
timer.C = clock
return timer
}
return newTimer
}
2 changes: 1 addition & 1 deletion snow/engine/snowman/bootstrap/bootstrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func New(config Config, onFinished func(ctx context.Context, lastReqID uint32) e
bs.Config.Ctx.Log.Warn("Encountered error during bootstrapping: %w", zap.Error(err))
}
}
bs.TimeoutRegistrar = common.NewTimeoutScheduler(timeout, config.BootstrapTracker.AllBootstrapped(), time.NewTimer)
bs.TimeoutRegistrar = common.NewTimeoutScheduler(timeout, config.BootstrapTracker.AllBootstrapped())

return bs, err
}
Expand Down

0 comments on commit a7b3ab4

Please sign in to comment.