Skip to content

Commit

Permalink
refactor(share): GetShare -> GetSamples (#3891)
Browse files Browse the repository at this point in the history
Co-authored-by: Oleg Kovalov <[email protected]>
  • Loading branch information
Wondertan and cristaloleg authored Nov 27, 2024
1 parent f5d9b32 commit 5d9192f
Show file tree
Hide file tree
Showing 35 changed files with 563 additions and 326 deletions.
37 changes: 25 additions & 12 deletions blob/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ func TestBlobService_Get(t *testing.T) {
shareOffset := 0
for i := range blobs {
row, col := calculateIndex(len(h.DAH.RowRoots), blobs[i].index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx := shwap.SampleCoords{Row: row, Col: col}
require.NoError(t, err)
require.True(t, bytes.Equal(sh.ToBytes(), resultShares[shareOffset].ToBytes()),
smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleCoords{idx})
require.NoError(t, err)
require.True(t, bytes.Equal(smpls[0].Share.ToBytes(), resultShares[shareOffset].ToBytes()),
fmt.Sprintf("issue on %d attempt. ROW:%d, COL: %d, blobIndex:%d", i, row, col, blobs[i].index),
)
shareOffset += libshare.SparseSharesNeeded(uint32(len(blobs[i].Data())))
Expand Down Expand Up @@ -487,10 +489,13 @@ func TestService_GetSingleBlobWithoutPadding(t *testing.T) {
h, err := service.headerGetter(ctx, 1)
require.NoError(t, err)
row, col := calculateIndex(len(h.DAH.RowRoots), newBlob.index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx := shwap.SampleCoords{Row: row, Col: col}
require.NoError(t, err)

smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleCoords{idx})
require.NoError(t, err)

assert.Equal(t, sh, resultShares[0])
assert.Equal(t, smpls[0].Share, resultShares[0])
}

func TestService_Get(t *testing.T) {
Expand Down Expand Up @@ -521,10 +526,13 @@ func TestService_Get(t *testing.T) {
assert.Equal(t, b.Commitment, blob.Commitment)

row, col := calculateIndex(len(h.DAH.RowRoots), b.index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx := shwap.SampleCoords{Row: row, Col: col}
require.NoError(t, err)

assert.Equal(t, sh, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i))
smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleCoords{idx})
require.NoError(t, err)

assert.Equal(t, smpls[0].Share, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i))
shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data())))
}
}
Expand Down Expand Up @@ -580,10 +588,13 @@ func TestService_GetAllWithoutPadding(t *testing.T) {
require.True(t, blobs[i].compareCommitments(blob.Commitment))

row, col := calculateIndex(len(h.DAH.RowRoots), blob.index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx := shwap.SampleCoords{Row: row, Col: col}
require.NoError(t, err)

smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleCoords{idx})
require.NoError(t, err)

assert.Equal(t, sh, resultShares[shareOffset])
assert.Equal(t, smpls[0].Share, resultShares[shareOffset])
shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data())))
}
}
Expand Down Expand Up @@ -902,10 +913,12 @@ func createService(ctx context.Context, t testing.TB, shares []libshare.Share) *
nd, err := eds.NamespaceData(ctx, accessor, ns)
return nd, err
})
shareGetter.EXPECT().GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().
DoAndReturn(func(ctx context.Context, h *header.ExtendedHeader, row, col int) (libshare.Share, error) {
s, err := accessor.Sample(ctx, row, col)
return s.Share, err
shareGetter.EXPECT().GetSamples(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().
DoAndReturn(func(ctx context.Context, h *header.ExtendedHeader,
indices []shwap.SampleCoords,
) ([]shwap.Sample, error) {
smpl, err := accessor.Sample(ctx, indices[0])
return []shwap.Sample{smpl}, err
})

// create header and put it into the store
Expand Down
44 changes: 30 additions & 14 deletions nodebuilder/share/mocks/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 29 additions & 1 deletion nodebuilder/share/share.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
libshare "github.com/celestiaorg/go-square/v2/share"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/header"
headerServ "github.com/celestiaorg/celestia-node/nodebuilder/header"
"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds"
Expand Down Expand Up @@ -45,6 +46,8 @@ type Module interface {
SharesAvailable(ctx context.Context, height uint64) error
// GetShare gets a Share by coordinates in EDS.
GetShare(ctx context.Context, height uint64, row, col int) (libshare.Share, error)
// GetSamples gets sample for given indices.
GetSamples(ctx context.Context, header *header.ExtendedHeader, indices []shwap.SampleCoords) ([]shwap.Sample, error)
// GetEDS gets the full EDS identified by the given extended header.
GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error)
// GetNamespaceData gets all shares from an EDS within the given namespace.
Expand All @@ -65,6 +68,11 @@ type API struct {
height uint64,
row, col int,
) (libshare.Share, error) `perm:"read"`
GetSamples func(
ctx context.Context,
header *header.ExtendedHeader,
indices []shwap.SampleCoords,
) ([]shwap.Sample, error) `perm:"read"`
GetEDS func(
ctx context.Context,
height uint64,
Expand All @@ -90,6 +98,12 @@ func (api *API) GetShare(ctx context.Context, height uint64, row, col int) (libs
return api.Internal.GetShare(ctx, height, row, col)
}

func (api *API) GetSamples(ctx context.Context, header *header.ExtendedHeader,
indices []shwap.SampleCoords,
) ([]shwap.Sample, error) {
return api.Internal.GetSamples(ctx, header, indices)
}

func (api *API) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) {
return api.Internal.GetEDS(ctx, height)
}
Expand Down Expand Up @@ -117,7 +131,21 @@ func (m module) GetShare(ctx context.Context, height uint64, row, col int) (libs
if err != nil {
return libshare.Share{}, err
}
return m.getter.GetShare(ctx, header, row, col)

idx := shwap.SampleCoords{Row: row, Col: col}

smpls, err := m.getter.GetSamples(ctx, header, []shwap.SampleCoords{idx})
if err != nil {
return libshare.Share{}, err
}

return smpls[0].Share, nil
}

func (m module) GetSamples(ctx context.Context, header *header.ExtendedHeader,
indices []shwap.SampleCoords,
) ([]shwap.Sample, error) {
return m.getter.GetSamples(ctx, header, indices)
}

func (m module) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) {
Expand Down
50 changes: 24 additions & 26 deletions share/availability/light/availability.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,12 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
return nil
}

var (
mutex sync.Mutex
failedSamples []Sample
wg sync.WaitGroup
)
log.Debugw("starting sampling session", "root", dah.String())

log.Debugw("starting sampling session", "height", header.Height())
idxs := make([]shwap.SampleCoords, len(samples.Remaining))
for i, s := range samples.Remaining {
idxs[i] = shwap.SampleCoords{Row: s.Row, Col: s.Col}
}

// remove one second from the deadline to ensure we have enough time to process the results
samplingCtx, cancel := context.WithCancel(ctx)
Expand All @@ -129,25 +128,21 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
}
defer cancel()

// Concurrently sample shares
for _, s := range samples.Remaining {
wg.Add(1)
go func(s Sample) {
defer wg.Done()
_, err := la.getter.GetShare(samplingCtx, header, s.Row, s.Col)
mutex.Lock()
defer mutex.Unlock()
if err != nil {
log.Debugw("error fetching share", "height", header.Height(), "row", s.Row, "col", s.Col)
failedSamples = append(failedSamples, s)
} else {
samples.Available = append(samples.Available, s)
}
}(s)
smpls, errGetSamples := la.getter.GetSamples(samplingCtx, header, idxs)
if len(smpls) == 0 {
return share.ErrNotAvailable
}

var failedSamples []shwap.SampleCoords

for i, smpl := range smpls {
if smpl.IsEmpty() {
failedSamples = append(failedSamples, shwap.SampleCoords{Row: idxs[i].Row, Col: idxs[i].Col})
} else {
samples.Available = append(samples.Available, shwap.SampleCoords{Row: idxs[i].Row, Col: idxs[i].Col})
}
}
wg.Wait()

// Update remaining samples with failed ones
samples.Remaining = failedSamples

// Store the updated sampling result
Expand All @@ -162,16 +157,17 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
return fmt.Errorf("store sampling result: %w", err)
}

if errors.Is(ctx.Err(), context.Canceled) {
if errors.Is(errGetSamples, context.Canceled) {
// Availability did not complete due to context cancellation, return context error instead of
// share.ErrNotAvailable
return ctx.Err()
return context.Canceled
}

// if any of the samples failed, return an error
if len(failedSamples) > 0 {
return share.ErrNotAvailable
}

return nil
}

Expand Down Expand Up @@ -210,7 +206,9 @@ func (la *ShareAvailability) Prune(ctx context.Context, h *header.ExtendedHeader

// delete stored samples
for _, sample := range result.Available {
blk, err := bitswap.NewEmptySampleBlock(h.Height(), sample.Row, sample.Col, len(h.DAH.RowRoots))
idx := shwap.SampleCoords{Row: sample.Row, Col: sample.Col}

blk, err := bitswap.NewEmptySampleBlock(h.Height(), idx, len(h.DAH.RowRoots))
if err != nil {
return fmt.Errorf("marshal sample ID: %w", err)
}
Expand Down
Loading

0 comments on commit 5d9192f

Please sign in to comment.