fix(availability): prevent parallel availability calls (#3883)
walldiss authored Oct 31, 2024
1 parent 816f46e commit 105ab1b
Showing 4 changed files with 239 additions and 20 deletions.
45 changes: 45 additions & 0 deletions libs/utils/sessions.go
@@ -0,0 +1,45 @@
package utils

import (
	"context"
	"sync"
)

// Sessions manages concurrent sessions for the specified key.
// It ensures only one session can proceed for each key at a time, avoiding duplicate effort.
// If a session is already active for the given key, StartSession waits until that session
// completes or the context is canceled.
type Sessions struct {
	active sync.Map
}

func NewSessions() *Sessions {
	return &Sessions{}
}

// StartSession attempts to start a new session for the given key. It returns a release function
// that cleans up the session lock for this key once the session is complete.
func (s *Sessions) StartSession(
	ctx context.Context,
	key any,
) (endSession func(), err error) {
	// Attempt to load or initialize a channel that tracks the active session for this key
	lockChan, alreadyActive := s.active.LoadOrStore(key, make(chan struct{}))
	if alreadyActive {
		// If a session is already active, wait for it to complete
		select {
		case <-lockChan.(chan struct{}):
		case <-ctx.Done():
			return func() {}, ctx.Err()
		}
		// The previous session has completed; retry to obtain the lock for this session
		return s.StartSession(ctx, key)
	}

	// Provide a function that releases the lock once the session is complete
	releaseLock := func() {
		close(lockChan.(chan struct{}))
		s.active.Delete(key)
	}
	return releaseLock, nil
}
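
For context, a minimal, self-contained usage sketch (hand-written for illustration; only the package and its API come from this commit). Callers contending on the same key are serialized, so each body below runs alone:

package main

import (
	"context"
	"fmt"
	"sync"

	"github.com/celestiaorg/celestia-node/libs/utils"
)

func main() {
	sessions := utils.NewSessions()
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			// Same key for every goroutine, so the sections below never overlap.
			release, err := sessions.StartSession(context.Background(), uint64(42))
			if err != nil {
				return // only possible if the context ends while waiting
			}
			defer release()
			fmt.Printf("goroutine %d holds the session for key 42\n", id)
		}(i)
	}
	wg.Wait()
}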
127 changes: 127 additions & 0 deletions libs/utils/sessions_test.go
@@ -0,0 +1,127 @@
package utils

import (
	"context"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// TestSessionsSerialExecution verifies that multiple sessions for the same key are executed
// sequentially.
func TestSessionsSerialExecution(t *testing.T) {
	// 20 serialized sessions of 50ms each need at least 1s, so allow 3s headroom
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	t.Cleanup(cancel)

	sessions := NewSessions()
	key := "testKey"
	activeCount := atomic.Int32{}
	var wg sync.WaitGroup

	numSessions := 20

	for i := 0; i < numSessions; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			endSession, err := sessions.StartSession(ctx, key)
			require.NoError(t, err)
			active := activeCount.Add(1)
			require.Equal(t, int32(1), active)
			// Simulate some work
			time.Sleep(50 * time.Millisecond)
			active = activeCount.Add(-1)
			require.Equal(t, int32(0), active)
			// Release the session
			endSession()
		}()
	}

	wg.Wait()
}

func TestSessionsContextCancellation(t *testing.T) {
	sessions := NewSessions()
	key := "testCancelKey"

	// Start the first session which will hold the lock for a while
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()

		release, err := sessions.StartSession(ctx, key)
		if err != nil {
			t.Errorf("First session: failed to start: %v", err)
			return
		}

		// Hold the session for 1 second
		time.Sleep(1 * time.Second)
		release()
	}()

	// Give the first goroutine a moment to acquire the session
	time.Sleep(100 * time.Millisecond)

	// Attempt to start a second session with a context that times out before the first session releases
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	t.Cleanup(cancel)

	_, err := sessions.StartSession(ctx, key)
	require.ErrorIs(t, err, context.DeadlineExceeded)

	// Attempt to start a second session with a context that is canceled before the first session
	// releases
	ctx, cancel = context.WithCancel(context.Background())
	cancel()

	_, err = sessions.StartSession(ctx, key)
	require.ErrorIs(t, err, context.Canceled)

	wg.Wait()
}

// TestSessions_ConcurrentDifferentKeys ensures that sessions with different keys run concurrently.
func TestSessions_ConcurrentDifferentKeys(t *testing.T) {
	sessions := NewSessions()
	numKeys := 20
	var wg sync.WaitGroup
	startCh := make(chan struct{})
	activeSessions := atomic.Int32{}
	maxActive := atomic.Int32{}

	for i := 0; i < numKeys; i++ {
		wg.Add(1)
		go func(key int) {
			defer wg.Done()
			// Wait for the start signal so all sessions begin together
			<-startCh
			ctx := context.Background()
			endSession, err := sessions.StartSession(ctx, key)
			require.NoError(t, err)

			// Record the peak number of concurrently active sessions
			active := activeSessions.Add(1)
			for {
				prev := maxActive.Load()
				if active <= prev || maxActive.CompareAndSwap(prev, active) {
					break
				}
			}

			// Wait to simulate work
			time.Sleep(100 * time.Millisecond)

			activeSessions.Add(-1)
			endSession()
		}(i)
	}

	// Start all goroutines
	close(startCh)
	wg.Wait()

	if maxActive.Load() <= 1 {
		t.Errorf("expected sessions with different keys to run concurrently, but peak active count was %d", maxActive.Load())
	}
}
20 changes: 15 additions & 5 deletions share/availability/light/availability.go
@@ -13,6 +13,7 @@ import (
 	logging "github.com/ipfs/go-log/v2"
 
 	"github.com/celestiaorg/celestia-node/header"
+	"github.com/celestiaorg/celestia-node/libs/utils"
 	"github.com/celestiaorg/celestia-node/share"
 	"github.com/celestiaorg/celestia-node/share/shwap"
 )
@@ -31,8 +32,9 @@ type ShareAvailability struct {
 	getter shwap.Getter
 	params Parameters
 
-	dsLk sync.RWMutex
-	ds   *autobatch.Datastore
+	activeHeights *utils.Sessions
+	dsLk          sync.RWMutex
+	ds            *autobatch.Datastore
 }
 
 // NewShareAvailability creates a new light Availability.
@@ -50,9 +52,10 @@ func NewShareAvailability(
 	}
 
 	return &ShareAvailability{
-		getter: getter,
-		params: params,
-		ds:     autoDS,
+		getter:        getter,
+		params:        params,
+		activeHeights: utils.NewSessions(),
+		ds:            autoDS,
 	}
 }

@@ -65,6 +68,13 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 		return nil
 	}
 
+	// Prevent multiple sampling sessions for the same header height
+	release, err := la.activeHeights.StartSession(ctx, header.Height())
+	if err != nil {
+		return err
+	}
+	defer release()
+
 	// load snapshot of the last sampling errors from disk
 	key := datastoreKeyForRoot(dah)
 	la.dsLk.RLock()
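
With this change in place, concurrent SharesAvailable calls for the same height collapse into a single sampling session. A hand-written fragment for illustration (not code from this commit; it mirrors TestParallelAvailability below, and `ctx` and `eh` are assumed from the surrounding test setup):

	ds := datastore.NewMapDatastore()
	getter := newOnceGetter()
	avail := NewShareAvailability(getter, ds)

	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// All calls race on the same height; only the first runs a full
			// sampling session, and the waiters reuse its persisted result.
			_ = avail.SharesAvailable(ctx, eh)
		}()
	}
	wg.Wait()

Keying sessions by header.Height() means waiters do not duplicate fetches: once the winning session releases, each retry reloads the stored sampling snapshot instead of re-requesting every sample, which is what the test's final assertion on the sample count checks.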
67 changes: 52 additions & 15 deletions share/availability/light/availability_test.go
@@ -147,47 +147,84 @@ func TestSharesAvailableFailed(t *testing.T) {
 	require.Len(t, failed, int(avail.params.SampleAmount))
 
 	// Simulate a getter that now returns shares successfully
-	successfulGetter := newOnceGetter()
-	successfulGetter.AddSamples(failed)
-	avail.getter = successfulGetter
+	onceGetter := newOnceGetter()
+	avail.getter = onceGetter
 
 	// should be able to retrieve all the failed samples now
 	err = avail.SharesAvailable(ctx, eh)
 	require.NoError(t, err)
 
-	// onceGetter should have no more samples stored after the call
-	require.Empty(t, successfulGetter.available)
+	onceGetter.checkOnce(t)
+	require.ElementsMatch(t, failed, onceGetter.sampledList())
 }

+func TestParallelAvailability(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ds := datastore.NewMapDatastore()
+	// Simulate a getter that returns shares successfully
+	successfulGetter := newOnceGetter()
+	avail := NewShareAvailability(successfulGetter, ds)
+
+	// create a new EDS that is not available from the getter
+	eds := edstest.RandEDS(t, 16)
+	roots, err := share.NewAxisRoots(eds)
+	require.NoError(t, err)
+	eh := headertest.RandExtendedHeaderWithRoot(t, roots)
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			err := avail.SharesAvailable(ctx, eh)
+			require.NoError(t, err)
+		}()
+	}
+	wg.Wait()
+	require.Len(t, successfulGetter.sampledList(), int(avail.params.SampleAmount))
+}

 type onceGetter struct {
 	*sync.Mutex
-	available map[Sample]struct{}
+	sampled map[Sample]int
 }
 
 func newOnceGetter() onceGetter {
 	return onceGetter{
-		Mutex:     &sync.Mutex{},
-		available: make(map[Sample]struct{}),
+		Mutex:   &sync.Mutex{},
+		sampled: make(map[Sample]int),
 	}
 }
 
+// checkOnce fails the test if any sample was requested more than once.
+func (m onceGetter) checkOnce(t *testing.T) {
+	m.Lock()
+	defer m.Unlock()
+	for s, count := range m.sampled {
+		if count > 1 {
+			t.Errorf("sample %v was called more than once", s)
+		}
+	}
+}
+
-func (m onceGetter) AddSamples(samples []Sample) {
+// sampledList returns the set of samples requested so far.
+func (m onceGetter) sampledList() []Sample {
 	m.Lock()
 	defer m.Unlock()
-	for _, s := range samples {
-		m.available[s] = struct{}{}
+	samples := make([]Sample, 0, len(m.sampled))
+	for s := range m.sampled {
+		samples = append(samples, s)
 	}
+	return samples
 }

 func (m onceGetter) GetShare(_ context.Context, _ *header.ExtendedHeader, row, col int) (libshare.Share, error) {
 	m.Lock()
 	defer m.Unlock()
 	s := Sample{Row: row, Col: col}
-	if _, ok := m.available[s]; ok {
-		delete(m.available, s)
-		return libshare.Share{}, nil
-	}
-	return libshare.Share{}, share.ErrNotAvailable
+	m.sampled[s]++
+	return libshare.Share{}, nil
 }
 
 func (m onceGetter) GetEDS(_ context.Context, _ *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) {
