Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(share): GetShare -> GetSamples #3905

Merged
merged 11 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions blob/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ func TestBlobService_Get(t *testing.T) {
shareOffset := 0
for i := range blobs {
row, col := calculateIndex(len(h.DAH.RowRoots), blobs[i].index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots))
require.NoError(t, err)
require.True(t, bytes.Equal(sh.ToBytes(), resultShares[shareOffset].ToBytes()),
smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx})
require.NoError(t, err)
require.True(t, bytes.Equal(smpls[0].Share.ToBytes(), resultShares[shareOffset].ToBytes()),
cristaloleg marked this conversation as resolved.
Show resolved Hide resolved
fmt.Sprintf("issue on %d attempt. ROW:%d, COL: %d, blobIndex:%d", i, row, col, blobs[i].index),
)
shareOffset += libshare.SparseSharesNeeded(uint32(len(blobs[i].Data())))
Expand Down Expand Up @@ -487,10 +489,13 @@ func TestService_GetSingleBlobWithoutPadding(t *testing.T) {
h, err := service.headerGetter(ctx, 1)
require.NoError(t, err)
row, col := calculateIndex(len(h.DAH.RowRoots), newBlob.index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots))
require.NoError(t, err)

smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx})
require.NoError(t, err)

assert.Equal(t, sh, resultShares[0])
assert.Equal(t, smpls[0].Share, resultShares[0])
cristaloleg marked this conversation as resolved.
Show resolved Hide resolved
}

func TestService_Get(t *testing.T) {
Expand Down Expand Up @@ -521,10 +526,13 @@ func TestService_Get(t *testing.T) {
assert.Equal(t, b.Commitment, blob.Commitment)

row, col := calculateIndex(len(h.DAH.RowRoots), b.index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots))
require.NoError(t, err)

assert.Equal(t, sh, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i))
smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx})
require.NoError(t, err)

assert.Equal(t, smpls[0].Share, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i))
cristaloleg marked this conversation as resolved.
Show resolved Hide resolved
shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data())))
}
}
Expand Down Expand Up @@ -580,10 +588,13 @@ func TestService_GetAllWithoutPadding(t *testing.T) {
require.True(t, blobs[i].compareCommitments(blob.Commitment))

row, col := calculateIndex(len(h.DAH.RowRoots), blob.index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots))
require.NoError(t, err)

smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx})
require.NoError(t, err)

assert.Equal(t, sh, resultShares[shareOffset])
assert.Equal(t, smpls[0].Share, resultShares[shareOffset])
cristaloleg marked this conversation as resolved.
Show resolved Hide resolved
shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data())))
}
}
Expand Down Expand Up @@ -904,7 +915,8 @@ func createService(ctx context.Context, t testing.TB, shares []libshare.Share) *
})
shareGetter.EXPECT().GetSamples(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().
DoAndReturn(func(ctx context.Context, h *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
return smpls, nil
smpl, err := accessor.Sample(ctx, indices[0])
return []shwap.Sample{smpl}, err
})

// create header and put it into the store
Expand Down
20 changes: 20 additions & 0 deletions nodebuilder/share/share.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
libshare "github.com/celestiaorg/go-square/v2/share"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/header"
headerServ "github.com/celestiaorg/celestia-node/nodebuilder/header"
"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds"
Expand Down Expand Up @@ -45,6 +46,8 @@ type Module interface {
SharesAvailable(ctx context.Context, height uint64) error
// GetShare gets a Share by coordinates in EDS.
GetShare(ctx context.Context, height uint64, row, col int) (libshare.Share, error)
// GetSamples gets sample for given indices.
GetSamples(ctx context.Context, header *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error)
cristaloleg marked this conversation as resolved.
Show resolved Hide resolved
// GetEDS gets the full EDS identified by the given extended header.
GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error)
// GetSharesByNamespace gets all shares from an EDS within the given namespace.
Expand All @@ -65,6 +68,11 @@ type API struct {
height uint64,
row, col int,
) (libshare.Share, error) `perm:"read"`
GetSamples func(
ctx context.Context,
header *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) `perm:"read"`
GetEDS func(
ctx context.Context,
height uint64,
Expand All @@ -90,6 +98,12 @@ func (api *API) GetShare(ctx context.Context, height uint64, row, col int) (libs
return api.Internal.GetShare(ctx, height, row, col)
}

func (api *API) GetSamples(ctx context.Context, header *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) {
return api.Internal.GetSamples(ctx, header, indices)
}

func (api *API) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) {
return api.Internal.GetEDS(ctx, height)
}
Expand Down Expand Up @@ -132,6 +146,12 @@ func (m module) GetShare(ctx context.Context, height uint64, row, col int) (libs
return smpls[0].Share, nil
}

func (m module) GetSamples(ctx context.Context, header *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) {
return m.getter.GetSamples(ctx, header, indices)
}

func (m module) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) {
header, err := m.hs.GetByHeight(ctx, height)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions share/availability/light/availability.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
}

smpls, err := la.getter.GetSamples(ctx, header, idxs)
if errors.Is(ctx.Err(), context.Canceled) {
if errors.Is(err, context.Canceled) {
cristaloleg marked this conversation as resolved.
Show resolved Hide resolved
// Availability did not complete due to context cancellation, return context error instead of
// share.ErrNotAvailable
return ctx.Err()
return err
}
if len(smpls) == 0 {
return share.ErrNotAvailable
Expand Down
11 changes: 4 additions & 7 deletions share/availability/light/availability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@ func TestSharesAvailableCaches(t *testing.T) {
acc := eds.Rsmt2D{ExtendedDataSquare: square}
smpls := make([]shwap.Sample, len(indices))
for i, idx := range indices {
rowIdx, colIdx, err := idx.Coordinates(len(hdr.DAH.RowRoots))
if err != nil {
return nil, err
}

smpl, err := acc.Sample(ctx, rowIdx, colIdx)
smpl, err := acc.Sample(ctx, idx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -187,7 +182,9 @@ func (m onceGetter) AddSamples(samples []Sample) {
}
}

func (m onceGetter) GetSamples(_ context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
func (m onceGetter) GetSamples(_ context.Context, hdr *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) {
m.Lock()
defer m.Unlock()

Expand Down
3 changes: 1 addition & 2 deletions share/eds/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ type Accessor interface {
// Sample returns share and corresponding proof for row and column indices. Implementation can
// choose which axis to use for proof. Chosen axis for proof should be indicated in the returned
// Sample.
// TODO(@Wondertan): change to SampleIndex
Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error)
Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error)
// AxisHalf returns half of shares axis of the given type and index. Side is determined by
// implementation. Implementations should indicate the side in the returned AxisHalf.
AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error)
Expand Down
4 changes: 2 additions & 2 deletions share/eds/close_once.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ func (c *closeOnce) AxisRoots(ctx context.Context) (*share.AxisRoots, error) {
return c.f.AxisRoots(ctx)
}

func (c *closeOnce) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) {
func (c *closeOnce) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) {
if c.closed.Load() {
return shwap.Sample{}, errAccessorClosed
}
return c.f.Sample(ctx, rowIdx, colIdx)
return c.f.Sample(ctx, idx)
}

func (c *closeOnce) AxisHalf(
Expand Down
6 changes: 3 additions & 3 deletions share/eds/close_once_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestWithClosedOnce(t *testing.T) {
stub := &stubEdsAccessorCloser{}
closedOnce := WithClosedOnce(stub)

_, err := closedOnce.Sample(ctx, 0, 0)
_, err := closedOnce.Sample(ctx, 0)
require.NoError(t, err)
_, err = closedOnce.AxisHalf(ctx, rsmt2d.Row, 0)
require.NoError(t, err)
Expand All @@ -33,7 +33,7 @@ func TestWithClosedOnce(t *testing.T) {
require.True(t, stub.closed)

// Ensure that the underlying file is not accessible after closing
_, err = closedOnce.Sample(ctx, 0, 0)
_, err = closedOnce.Sample(ctx, 0)
require.ErrorIs(t, err, errAccessorClosed)
_, err = closedOnce.AxisHalf(ctx, rsmt2d.Row, 0)
require.ErrorIs(t, err, errAccessorClosed)
Expand All @@ -59,7 +59,7 @@ func (s *stubEdsAccessorCloser) AxisRoots(context.Context) (*share.AxisRoots, er
return &share.AxisRoots{}, nil
}

func (s *stubEdsAccessorCloser) Sample(context.Context, int, int) (shwap.Sample, error) {
func (s *stubEdsAccessorCloser) Sample(context.Context, shwap.SampleIndex) (shwap.Sample, error) {
return shwap.Sample{}, nil
}

Expand Down
7 changes: 6 additions & 1 deletion share/eds/proofs_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ func (c *proofsCache) AxisRoots(ctx context.Context) (*share.AxisRoots, error) {
return roots, nil
}

func (c *proofsCache) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) {
func (c *proofsCache) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) {
rowIdx, colIdx, err := idx.Coordinates(c.Size(ctx))
if err != nil {
return shwap.Sample{}, err
}

axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx
ax, err := c.axisWithProofs(ctx, axisType, axisIdx)
if err != nil {
Expand Down
11 changes: 8 additions & 3 deletions share/eds/rsmt2d.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,22 @@ func (eds *Rsmt2D) AxisRoots(context.Context) (*share.AxisRoots, error) {
// Sample returns share and corresponding proof for row and column indices.
func (eds *Rsmt2D) Sample(
_ context.Context,
rowIdx, colIdx int,
idx shwap.SampleIndex,
) (shwap.Sample, error) {
return eds.SampleForProofAxis(rowIdx, colIdx, rsmt2d.Row)
return eds.SampleForProofAxis(idx, rsmt2d.Row)
}

// SampleForProofAxis samples a share from an Extended Data Square based on the provided
// row and column indices and proof axis. It returns a sample with the share and proof.
func (eds *Rsmt2D) SampleForProofAxis(
rowIdx, colIdx int,
idx shwap.SampleIndex,
proofType rsmt2d.Axis,
) (shwap.Sample, error) {
rowIdx, colIdx, err := idx.Coordinates(int(eds.Width()))
if err != nil {
return shwap.Sample{}, err
}

axisIdx, shrIdx := relativeIndexes(rowIdx, colIdx, proofType)
shares, err := getAxis(eds.ExtendedDataSquare, proofType, axisIdx)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion share/eds/rsmt2d_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ func TestRsmt2dSampleForProofAxis(t *testing.T) {
for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} {
for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for colIdx := 0; colIdx < odsSize*2; colIdx++ {
sample, err := accessor.SampleForProofAxis(rowIdx, colIdx, proofType)
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, accessor.Size(context.Background()))
require.NoError(t, err)

sample, err := accessor.SampleForProofAxis(idx, proofType)
require.NoError(t, err)

want := eds.GetCell(uint(rowIdx), uint(colIdx))
Expand Down
30 changes: 20 additions & 10 deletions share/eds/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ func testAccessorSample(
// t.Parallel() this fails the test for some reason
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
testSample(ctx, t, acc, roots, colIdx, rowIdx)
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, acc.Size(ctx))
require.NoError(t, err)
testSample(ctx, t, acc, roots, idx)
}
}
})
Expand All @@ -162,10 +164,12 @@ func testAccessorSample(
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
wg.Add(1)
go func(rowIdx, colIdx int) {
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, acc.Size(ctx))
require.NoError(t, err)
go func(idx shwap.SampleIndex) {
defer wg.Done()
testSample(ctx, t, acc, roots, rowIdx, colIdx)
}(rowIdx, colIdx)
testSample(ctx, t, acc, roots, idx)
}(idx)
}
}
wg.Wait()
Expand All @@ -182,8 +186,8 @@ func testAccessorSample(
wg.Add(1)
go func() {
defer wg.Done()
rowIdx, colIdx := rand.IntN(width), rand.IntN(width) //nolint:gosec
testSample(ctx, t, acc, roots, rowIdx, colIdx)
idx := rand.IntN(int(eds.Width())) //nolint:gosec
testSample(ctx, t, acc, roots, shwap.SampleIndex(idx))
}()
}
wg.Wait()
Expand All @@ -195,9 +199,12 @@ func testSample(
t *testing.T,
acc Accessor,
roots *share.AxisRoots,
rowIdx, colIdx int,
idx shwap.SampleIndex,
) {
shr, err := acc.Sample(ctx, rowIdx, colIdx)
shr, err := acc.Sample(ctx, idx)
require.NoError(t, err)

rowIdx, colIdx, err := idx.Coordinates(acc.Size(ctx))
require.NoError(t, err)

err = shr.Verify(roots, rowIdx, colIdx)
Expand Down Expand Up @@ -444,13 +451,16 @@ func BenchGetSampleFromAccessor(
name := fmt.Sprintf("Size:%v/quadrant:%s", size, q)
b.Run(name, func(b *testing.B) {
rowIdx, colIdx := q.coordinates(acc.Size(ctx))
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, acc.Size(ctx))
require.NoError(b, err)

// warm up cache
_, err := acc.Sample(ctx, rowIdx, colIdx)
_, err = acc.Sample(ctx, idx)
require.NoError(b, err, q.String())

b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := acc.Sample(ctx, rowIdx, colIdx)
_, err := acc.Sample(ctx, idx)
require.NoError(b, err)
}
})
Expand Down
11 changes: 3 additions & 8 deletions share/eds/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,12 @@ func (f validation) Size(ctx context.Context) int {
return int(size)
}

func (f validation) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) {
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, f.Size(ctx))
if err != nil {
return shwap.Sample{}, err
}

_, err = shwap.NewSampleID(1, idx, f.Size(ctx))
func (f validation) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) {
_, err := shwap.NewSampleID(1, idx, f.Size(ctx))
if err != nil {
return shwap.Sample{}, fmt.Errorf("sample validation: %w", err)
}
return f.Accessor.Sample(ctx, rowIdx, colIdx)
return f.Accessor.Sample(ctx, idx)
}

func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) {
Expand Down
9 changes: 8 additions & 1 deletion share/eds/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ func TestValidation_Sample(t *testing.T) {
accessor := &Rsmt2D{ExtendedDataSquare: randEDS}
validation := WithValidation(AccessorAndStreamer(accessor, nil))

_, err := validation.Sample(context.Background(), tt.rowIdx, tt.colIdx)
idx, err := shwap.SampleIndexFromCoordinates(tt.rowIdx, tt.colIdx, accessor.Size(context.Background()))
if tt.expectFail {
require.ErrorIs(t, err, shwap.ErrInvalidID, tt.name)
return
}
require.NoError(t, err, tt.name)

_, err = validation.Sample(context.Background(), idx)
if tt.expectFail {
require.ErrorIs(t, err, shwap.ErrInvalidID)
} else {
Expand Down
4 changes: 3 additions & 1 deletion share/shwap/getters/cascade.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ func NewCascadeGetter(getters []shwap.Getter) *CascadeGetter {
}

// GetSamples gets samples from any of registered shwap.Getters in cascading order.
func (cg *CascadeGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
func (cg *CascadeGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) {
ctx, span := tracer.Start(ctx, "cascade/get-samples", trace.WithAttributes(
attribute.Int("amount", len(indices)),
))
Expand Down
Loading
Loading