diff --git a/blob/service_test.go b/blob/service_test.go index e24c2acde8..10213900d4 100644 --- a/blob/service_test.go +++ b/blob/service_test.go @@ -145,9 +145,11 @@ func TestBlobService_Get(t *testing.T) { shareOffset := 0 for i := range blobs { row, col := calculateIndex(len(h.DAH.RowRoots), blobs[i].index) - sh, err := service.shareGetter.GetShare(ctx, h, row, col) + idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots)) require.NoError(t, err) - require.True(t, bytes.Equal(sh.ToBytes(), resultShares[shareOffset].ToBytes()), + smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx}) + require.NoError(t, err) + require.True(t, bytes.Equal(smpls[0].Share.ToBytes(), resultShares[shareOffset].ToBytes()), fmt.Sprintf("issue on %d attempt. ROW:%d, COL: %d, blobIndex:%d", i, row, col, blobs[i].index), ) shareOffset += libshare.SparseSharesNeeded(uint32(len(blobs[i].Data()))) @@ -487,10 +489,13 @@ func TestService_GetSingleBlobWithoutPadding(t *testing.T) { h, err := service.headerGetter(ctx, 1) require.NoError(t, err) row, col := calculateIndex(len(h.DAH.RowRoots), newBlob.index) - sh, err := service.shareGetter.GetShare(ctx, h, row, col) + idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots)) + require.NoError(t, err) + + smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx}) require.NoError(t, err) - assert.Equal(t, sh, resultShares[0]) + assert.Equal(t, smpls[0].Share, resultShares[0]) } func TestService_Get(t *testing.T) { @@ -521,10 +526,13 @@ func TestService_Get(t *testing.T) { assert.Equal(t, b.Commitment, blob.Commitment) row, col := calculateIndex(len(h.DAH.RowRoots), b.index) - sh, err := service.shareGetter.GetShare(ctx, h, row, col) + idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots)) require.NoError(t, err) - assert.Equal(t, sh, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i)) + smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx}) + require.NoError(t, err) + + assert.Equal(t, smpls[0].Share, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i)) shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data()))) } } @@ -580,10 +588,13 @@ func TestService_GetAllWithoutPadding(t *testing.T) { require.True(t, blobs[i].compareCommitments(blob.Commitment)) row, col := calculateIndex(len(h.DAH.RowRoots), blob.index) - sh, err := service.shareGetter.GetShare(ctx, h, row, col) + idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots)) + require.NoError(t, err) + + smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx}) require.NoError(t, err) - assert.Equal(t, sh, resultShares[shareOffset]) + assert.Equal(t, smpls[0].Share, resultShares[shareOffset]) shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data()))) } } @@ -904,7 +915,8 @@ func createService(ctx context.Context, t testing.TB, shares []libshare.Share) * }) shareGetter.EXPECT().GetSamples(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes(). DoAndReturn(func(ctx context.Context, h *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) { - return smpls, nil + smpl, err := accessor.Sample(ctx, indices[0]) + return []shwap.Sample{smpl}, err }) // create header and put it into the store diff --git a/nodebuilder/share/share.go b/nodebuilder/share/share.go index f709bd572f..d9be323b49 100644 --- a/nodebuilder/share/share.go +++ b/nodebuilder/share/share.go @@ -8,6 +8,7 @@ import ( libshare "github.com/celestiaorg/go-square/v2/share" "github.com/celestiaorg/rsmt2d" + "github.com/celestiaorg/celestia-node/header" headerServ "github.com/celestiaorg/celestia-node/nodebuilder/header" "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/eds" @@ -45,6 +46,8 @@ type Module interface { SharesAvailable(ctx context.Context, height uint64) error // GetShare gets a Share by coordinates in EDS. GetShare(ctx context.Context, height uint64, row, col int) (libshare.Share, error) + // GetSamples gets sample for given indices. + GetSamples(ctx context.Context, header *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) // GetEDS gets the full EDS identified by the given extended header. GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) // GetSharesByNamespace gets all shares from an EDS within the given namespace. @@ -65,6 +68,11 @@ type API struct { height uint64, row, col int, ) (libshare.Share, error) `perm:"read"` + GetSamples func( + ctx context.Context, + header *header.ExtendedHeader, + indices []shwap.SampleIndex, + ) ([]shwap.Sample, error) `perm:"read"` GetEDS func( ctx context.Context, height uint64, @@ -90,6 +98,12 @@ func (api *API) GetShare(ctx context.Context, height uint64, row, col int) (libs return api.Internal.GetShare(ctx, height, row, col) } +func (api *API) GetSamples(ctx context.Context, header *header.ExtendedHeader, + indices []shwap.SampleIndex, +) ([]shwap.Sample, error) { + return api.Internal.GetSamples(ctx, header, indices) +} + func (api *API) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) { return api.Internal.GetEDS(ctx, height) } @@ -132,6 +146,12 @@ func (m module) GetShare(ctx context.Context, height uint64, row, col int) (libs return smpls[0].Share, nil } +func (m module) GetSamples(ctx context.Context, header *header.ExtendedHeader, + indices []shwap.SampleIndex, +) ([]shwap.Sample, error) { + return m.getter.GetSamples(ctx, header, indices) +} + func (m module) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) { header, err := m.hs.GetByHeight(ctx, height) if err != nil { diff --git a/share/availability/light/availability.go b/share/availability/light/availability.go index 59976bd0d0..3975aacf0c 100644 --- a/share/availability/light/availability.go +++ b/share/availability/light/availability.go @@ -113,10 +113,10 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header } smpls, err := la.getter.GetSamples(ctx, header, idxs) - if errors.Is(ctx.Err(), context.Canceled) { + if errors.Is(err, context.Canceled) { // Availability did not complete due to context cancellation, return context error instead of // share.ErrNotAvailable - return ctx.Err() + return err } if len(smpls) == 0 { return share.ErrNotAvailable diff --git a/share/availability/light/availability_test.go b/share/availability/light/availability_test.go index 3cd86ecd4f..62c8fcc4b2 100644 --- a/share/availability/light/availability_test.go +++ b/share/availability/light/availability_test.go @@ -41,12 +41,7 @@ func TestSharesAvailableCaches(t *testing.T) { acc := eds.Rsmt2D{ExtendedDataSquare: square} smpls := make([]shwap.Sample, len(indices)) for i, idx := range indices { - rowIdx, colIdx, err := idx.Coordinates(len(hdr.DAH.RowRoots)) - if err != nil { - return nil, err - } - - smpl, err := acc.Sample(ctx, rowIdx, colIdx) + smpl, err := acc.Sample(ctx, idx) if err != nil { return nil, err } @@ -187,7 +182,9 @@ func (m onceGetter) AddSamples(samples []Sample) { } } -func (m onceGetter) GetSamples(_ context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) { +func (m onceGetter) GetSamples(_ context.Context, hdr *header.ExtendedHeader, + indices []shwap.SampleIndex, +) ([]shwap.Sample, error) { m.Lock() defer m.Unlock() diff --git a/share/eds/accessor.go b/share/eds/accessor.go index 81f55fcf64..e262382790 100644 --- a/share/eds/accessor.go +++ b/share/eds/accessor.go @@ -25,8 +25,7 @@ type Accessor interface { // Sample returns share and corresponding proof for row and column indices. Implementation can // choose which axis to use for proof. Chosen axis for proof should be indicated in the returned // Sample. - // TODO(@Wondertan): change to SampleIndex - Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) + Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) // AxisHalf returns half of shares axis of the given type and index. Side is determined by // implementation. Implementations should indicate the side in the returned AxisHalf. AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) diff --git a/share/eds/close_once.go b/share/eds/close_once.go index cc217710ce..6de62419fa 100644 --- a/share/eds/close_once.go +++ b/share/eds/close_once.go @@ -57,11 +57,11 @@ func (c *closeOnce) AxisRoots(ctx context.Context) (*share.AxisRoots, error) { return c.f.AxisRoots(ctx) } -func (c *closeOnce) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { +func (c *closeOnce) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) { if c.closed.Load() { return shwap.Sample{}, errAccessorClosed } - return c.f.Sample(ctx, rowIdx, colIdx) + return c.f.Sample(ctx, idx) } func (c *closeOnce) AxisHalf( diff --git a/share/eds/close_once_test.go b/share/eds/close_once_test.go index d515ac7bda..59b4452174 100644 --- a/share/eds/close_once_test.go +++ b/share/eds/close_once_test.go @@ -20,7 +20,7 @@ func TestWithClosedOnce(t *testing.T) { stub := &stubEdsAccessorCloser{} closedOnce := WithClosedOnce(stub) - _, err := closedOnce.Sample(ctx, 0, 0) + _, err := closedOnce.Sample(ctx, 0) require.NoError(t, err) _, err = closedOnce.AxisHalf(ctx, rsmt2d.Row, 0) require.NoError(t, err) @@ -33,7 +33,7 @@ func TestWithClosedOnce(t *testing.T) { require.True(t, stub.closed) // Ensure that the underlying file is not accessible after closing - _, err = closedOnce.Sample(ctx, 0, 0) + _, err = closedOnce.Sample(ctx, 0) require.ErrorIs(t, err, errAccessorClosed) _, err = closedOnce.AxisHalf(ctx, rsmt2d.Row, 0) require.ErrorIs(t, err, errAccessorClosed) @@ -59,7 +59,7 @@ func (s *stubEdsAccessorCloser) AxisRoots(context.Context) (*share.AxisRoots, er return &share.AxisRoots{}, nil } -func (s *stubEdsAccessorCloser) Sample(context.Context, int, int) (shwap.Sample, error) { +func (s *stubEdsAccessorCloser) Sample(context.Context, shwap.SampleIndex) (shwap.Sample, error) { return shwap.Sample{}, nil } diff --git a/share/eds/proofs_cache.go b/share/eds/proofs_cache.go index e777b82962..006d1e46ac 100644 --- a/share/eds/proofs_cache.go +++ b/share/eds/proofs_cache.go @@ -112,7 +112,12 @@ func (c *proofsCache) AxisRoots(ctx context.Context) (*share.AxisRoots, error) { return roots, nil } -func (c *proofsCache) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { +func (c *proofsCache) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) { + rowIdx, colIdx, err := idx.Coordinates(c.Size(ctx)) + if err != nil { + return shwap.Sample{}, err + } + axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx ax, err := c.axisWithProofs(ctx, axisType, axisIdx) if err != nil { diff --git a/share/eds/rsmt2d.go b/share/eds/rsmt2d.go index e0e945fccb..d7ee5376a5 100644 --- a/share/eds/rsmt2d.go +++ b/share/eds/rsmt2d.go @@ -46,17 +46,22 @@ func (eds *Rsmt2D) AxisRoots(context.Context) (*share.AxisRoots, error) { // Sample returns share and corresponding proof for row and column indices. func (eds *Rsmt2D) Sample( _ context.Context, - rowIdx, colIdx int, + idx shwap.SampleIndex, ) (shwap.Sample, error) { - return eds.SampleForProofAxis(rowIdx, colIdx, rsmt2d.Row) + return eds.SampleForProofAxis(idx, rsmt2d.Row) } // SampleForProofAxis samples a share from an Extended Data Square based on the provided // row and column indices and proof axis. It returns a sample with the share and proof. func (eds *Rsmt2D) SampleForProofAxis( - rowIdx, colIdx int, + idx shwap.SampleIndex, proofType rsmt2d.Axis, ) (shwap.Sample, error) { + rowIdx, colIdx, err := idx.Coordinates(int(eds.Width())) + if err != nil { + return shwap.Sample{}, err + } + axisIdx, shrIdx := relativeIndexes(rowIdx, colIdx, proofType) shares, err := getAxis(eds.ExtendedDataSquare, proofType, axisIdx) if err != nil { diff --git a/share/eds/rsmt2d_test.go b/share/eds/rsmt2d_test.go index 96bde8c2ab..4ae001704f 100644 --- a/share/eds/rsmt2d_test.go +++ b/share/eds/rsmt2d_test.go @@ -56,7 +56,10 @@ func TestRsmt2dSampleForProofAxis(t *testing.T) { for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for colIdx := 0; colIdx < odsSize*2; colIdx++ { - sample, err := accessor.SampleForProofAxis(rowIdx, colIdx, proofType) + idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, accessor.Size(context.Background())) + require.NoError(t, err) + + sample, err := accessor.SampleForProofAxis(idx, proofType) require.NoError(t, err) want := eds.GetCell(uint(rowIdx), uint(colIdx)) diff --git a/share/eds/testing.go b/share/eds/testing.go index 388544bc00..c4373780fb 100644 --- a/share/eds/testing.go +++ b/share/eds/testing.go @@ -148,7 +148,9 @@ func testAccessorSample( // t.Parallel() this fails the test for some reason for rowIdx := 0; rowIdx < width; rowIdx++ { for colIdx := 0; colIdx < width; colIdx++ { - testSample(ctx, t, acc, roots, colIdx, rowIdx) + idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, acc.Size(ctx)) + require.NoError(t, err) + testSample(ctx, t, acc, roots, idx) } } }) @@ -162,10 +164,12 @@ func testAccessorSample( for rowIdx := 0; rowIdx < width; rowIdx++ { for colIdx := 0; colIdx < width; colIdx++ { wg.Add(1) - go func(rowIdx, colIdx int) { + idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, acc.Size(ctx)) + require.NoError(t, err) + go func(idx shwap.SampleIndex) { defer wg.Done() - testSample(ctx, t, acc, roots, rowIdx, colIdx) - }(rowIdx, colIdx) + testSample(ctx, t, acc, roots, idx) + }(idx) } } wg.Wait() @@ -182,8 +186,8 @@ func testAccessorSample( wg.Add(1) go func() { defer wg.Done() - rowIdx, colIdx := rand.IntN(width), rand.IntN(width) //nolint:gosec - testSample(ctx, t, acc, roots, rowIdx, colIdx) + idx := rand.IntN(int(eds.Width())) //nolint:gosec + testSample(ctx, t, acc, roots, shwap.SampleIndex(idx)) }() } wg.Wait() @@ -195,9 +199,12 @@ func testSample( t *testing.T, acc Accessor, roots *share.AxisRoots, - rowIdx, colIdx int, + idx shwap.SampleIndex, ) { - shr, err := acc.Sample(ctx, rowIdx, colIdx) + shr, err := acc.Sample(ctx, idx) + require.NoError(t, err) + + rowIdx, colIdx, err := idx.Coordinates(acc.Size(ctx)) require.NoError(t, err) err = shr.Verify(roots, rowIdx, colIdx) @@ -444,13 +451,16 @@ func BenchGetSampleFromAccessor( name := fmt.Sprintf("Size:%v/quadrant:%s", size, q) b.Run(name, func(b *testing.B) { rowIdx, colIdx := q.coordinates(acc.Size(ctx)) + idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, acc.Size(ctx)) + require.NoError(b, err) + // warm up cache - _, err := acc.Sample(ctx, rowIdx, colIdx) + _, err = acc.Sample(ctx, idx) require.NoError(b, err, q.String()) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := acc.Sample(ctx, rowIdx, colIdx) + _, err := acc.Sample(ctx, idx) require.NoError(b, err) } }) diff --git a/share/eds/validation.go b/share/eds/validation.go index 29113bbf10..e59a956614 100644 --- a/share/eds/validation.go +++ b/share/eds/validation.go @@ -38,17 +38,12 @@ func (f validation) Size(ctx context.Context) int { return int(size) } -func (f validation) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { - idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, f.Size(ctx)) - if err != nil { - return shwap.Sample{}, err - } - - _, err = shwap.NewSampleID(1, idx, f.Size(ctx)) +func (f validation) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) { + _, err := shwap.NewSampleID(1, idx, f.Size(ctx)) if err != nil { return shwap.Sample{}, fmt.Errorf("sample validation: %w", err) } - return f.Accessor.Sample(ctx, rowIdx, colIdx) + return f.Accessor.Sample(ctx, idx) } func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) { diff --git a/share/eds/validation_test.go b/share/eds/validation_test.go index 3e645cbfb3..fbe17868a9 100644 --- a/share/eds/validation_test.go +++ b/share/eds/validation_test.go @@ -34,7 +34,14 @@ func TestValidation_Sample(t *testing.T) { accessor := &Rsmt2D{ExtendedDataSquare: randEDS} validation := WithValidation(AccessorAndStreamer(accessor, nil)) - _, err := validation.Sample(context.Background(), tt.rowIdx, tt.colIdx) + idx, err := shwap.SampleIndexFromCoordinates(tt.rowIdx, tt.colIdx, accessor.Size(context.Background())) + if tt.expectFail { + require.ErrorIs(t, err, shwap.ErrInvalidID, tt.name) + return + } + require.NoError(t, err, tt.name) + + _, err = validation.Sample(context.Background(), idx) if tt.expectFail { require.ErrorIs(t, err, shwap.ErrInvalidID) } else { diff --git a/share/shwap/getters/cascade.go b/share/shwap/getters/cascade.go index 066ee47fc8..5d0bb26b00 100644 --- a/share/shwap/getters/cascade.go +++ b/share/shwap/getters/cascade.go @@ -42,7 +42,9 @@ func NewCascadeGetter(getters []shwap.Getter) *CascadeGetter { } // GetSamples gets samples from any of registered shwap.Getters in cascading order. -func (cg *CascadeGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) { +func (cg *CascadeGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, + indices []shwap.SampleIndex, +) ([]shwap.Sample, error) { ctx, span := tracer.Start(ctx, "cascade/get-samples", trace.WithAttributes( attribute.Int("amount", len(indices)), )) diff --git a/share/shwap/getters/testing.go b/share/shwap/getters/testing.go index d81fe3c707..b7b01695d2 100644 --- a/share/shwap/getters/testing.go +++ b/share/shwap/getters/testing.go @@ -37,7 +37,9 @@ type SingleEDSGetter struct { } // GetSamples get samples from a kept EDS if exist and if the correct root is given. -func (seg *SingleEDSGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) { +func (seg *SingleEDSGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, + indices []shwap.SampleIndex, +) ([]shwap.Sample, error) { err := seg.checkRoots(hdr.DAH) if err != nil { return nil, err @@ -45,12 +47,7 @@ func (seg *SingleEDSGetter) GetSamples(ctx context.Context, hdr *header.Extended smpls := make([]shwap.Sample, len(indices)) for i, idx := range indices { - rowIdx, colIdx, err := idx.Coordinates(len(hdr.DAH.RowRoots)) - if err != nil { - return nil, err - } - - smpl, err := seg.EDS.Sample(ctx, rowIdx, colIdx) + smpl, err := seg.EDS.Sample(ctx, idx) if err != nil { return nil, err } diff --git a/share/shwap/p2p/bitswap/sample_block.go b/share/shwap/p2p/bitswap/sample_block.go index e094f58c49..fcb5a679a1 100644 --- a/share/shwap/p2p/bitswap/sample_block.go +++ b/share/shwap/p2p/bitswap/sample_block.go @@ -86,7 +86,12 @@ func (sb *SampleBlock) Marshal() ([]byte, error) { } func (sb *SampleBlock) Populate(ctx context.Context, eds eds.Accessor) error { - smpl, err := eds.Sample(ctx, sb.ID.RowIndex, sb.ID.ShareIndex) + idx, err := shwap.SampleIndexFromCoordinates(sb.ID.RowIndex, sb.ID.ShareIndex, eds.Size(ctx)) + if err != nil { + return err + } + + smpl, err := eds.Sample(ctx, idx) if err != nil { return fmt.Errorf("accessing Sample: %w", err) } diff --git a/share/shwap/sample_test.go b/share/shwap/sample_test.go index 030eeb4677..2a25d369d8 100644 --- a/share/shwap/sample_test.go +++ b/share/shwap/sample_test.go @@ -25,10 +25,13 @@ func TestSampleValidate(t *testing.T) { for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for colIdx := 0; colIdx < odsSize*2; colIdx++ { - sample, err := inMem.SampleForProofAxis(rowIdx, colIdx, proofType) + idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, int(randEDS.Width())) require.NoError(t, err) - require.NoError(t, sample.Verify(root, rowIdx, colIdx)) + sample, err := inMem.SampleForProofAxis(idx, proofType) + require.NoError(t, err) + + require.NoError(t, sample.Verify(root, rowIdx, colIdx), "row: %d col: %d", rowIdx, colIdx) } } } @@ -42,7 +45,7 @@ func TestSampleNegativeVerifyInclusion(t *testing.T) { require.NoError(t, err) inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} - sample, err := inMem.Sample(context.Background(), 0, 0) + sample, err := inMem.Sample(context.Background(), 0) require.NoError(t, err) err = sample.Verify(root, 0, 0) require.NoError(t, err) @@ -61,14 +64,14 @@ func TestSampleNegativeVerifyInclusion(t *testing.T) { require.ErrorIs(t, err, shwap.ErrFailedVerification) // incorrect proofType - sample, err = inMem.Sample(context.Background(), 0, 0) + sample, err = inMem.Sample(context.Background(), 0) require.NoError(t, err) sample.ProofType = rsmt2d.Col err = sample.Verify(root, 0, 0) require.ErrorIs(t, err, shwap.ErrFailedVerification) // Corrupt the last root hash byte - sample, err = inMem.Sample(context.Background(), 0, 0) + sample, err = inMem.Sample(context.Background(), 0) require.NoError(t, err) root.RowRoots[0][len(root.RowRoots[0])-1] ^= 0xFF err = sample.Verify(root, 0, 0) @@ -83,7 +86,10 @@ func TestSampleProtoEncoding(t *testing.T) { for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for colIdx := 0; colIdx < odsSize*2; colIdx++ { - sample, err := inMem.SampleForProofAxis(rowIdx, colIdx, proofType) + idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, int(randEDS.Width())) + require.NoError(t, err) + + sample, err := inMem.SampleForProofAxis(idx, proofType) require.NoError(t, err) pb := sample.ToProto() @@ -103,7 +109,8 @@ func BenchmarkSampleValidate(b *testing.B) { root, err := share.NewAxisRoots(randEDS) require.NoError(b, err) inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} - sample, err := inMem.SampleForProofAxis(0, 0, rsmt2d.Row) + + sample, err := inMem.SampleForProofAxis(0, rsmt2d.Row) require.NoError(b, err) b.ResetTimer() diff --git a/store/cache/accessor_cache_test.go b/store/cache/accessor_cache_test.go index 9c7104fbe7..82073a722c 100644 --- a/store/cache/accessor_cache_test.go +++ b/store/cache/accessor_cache_test.go @@ -315,7 +315,7 @@ func (m *mockAccessor) AxisRoots(context.Context) (*share.AxisRoots, error) { panic("implement me") } -func (m *mockAccessor) Sample(context.Context, int, int) (shwap.Sample, error) { +func (m *mockAccessor) Sample(context.Context, shwap.SampleIndex) (shwap.Sample, error) { panic("implement me") } diff --git a/store/cache/noop.go b/store/cache/noop.go index d777fdb2e4..2ccc87f387 100644 --- a/store/cache/noop.go +++ b/store/cache/noop.go @@ -59,7 +59,7 @@ func (n NoopFile) AxisRoots(context.Context) (*share.AxisRoots, error) { return &share.AxisRoots{}, nil } -func (n NoopFile) Sample(context.Context, int, int) (shwap.Sample, error) { +func (n NoopFile) Sample(context.Context, shwap.SampleIndex) (shwap.Sample, error) { return shwap.Sample{}, nil } diff --git a/store/file/ods.go b/store/file/ods.go index a4f1313a6e..0e7efd88bc 100644 --- a/store/file/ods.go +++ b/store/file/ods.go @@ -228,7 +228,7 @@ func (o *ODS) Close() error { // Sample returns share and corresponding proof for row and column indices. Implementation can // choose which axis to use for proof. Chosen axis for proof should be indicated in the returned // Sample. -func (o *ODS) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { +func (o *ODS) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) { // Sample proof axis is selected to optimize read performance. // - For the first and second quadrants, we read the row axis because it is more efficient to read // single row than reading full ODS to calculate single column @@ -236,6 +236,11 @@ func (o *ODS) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, err // column than reading full ODS to calculate single row // - For the fourth quadrant, it does not matter which axis we read because we need to read full ODS // to calculate the sample + rowIdx, colIdx, err := idx.Coordinates(o.Size(ctx)) + if err != nil { + return shwap.Sample{}, err + } + axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx if colIdx < o.size()/2 && rowIdx >= o.size()/2 { axisType, axisIdx, shrIdx = rsmt2d.Col, colIdx, rowIdx @@ -246,7 +251,7 @@ func (o *ODS) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, err return shwap.Sample{}, fmt.Errorf("reading axis: %w", err) } - idx, err := shwap.SampleIndexFromCoordinates(rowIdx, shrIdx, o.Size(ctx)) + idx, err = shwap.SampleIndexFromCoordinates(axisIdx, shrIdx, o.Size(ctx)) if err != nil { return shwap.Sample{}, err } diff --git a/store/file/ods_q4.go b/store/file/ods_q4.go index cea6492e84..799a87e48f 100644 --- a/store/file/ods_q4.go +++ b/store/file/ods_q4.go @@ -122,9 +122,13 @@ func (odsq4 *ODSQ4) AxisRoots(ctx context.Context) (*share.AxisRoots, error) { return odsq4.ods.AxisRoots(ctx) } -func (odsq4 *ODSQ4) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { +func (odsq4 *ODSQ4) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) { + rowIdw, _, err := idx.Coordinates(odsq4.Size(ctx)) + if err != nil { + return shwap.Sample{}, err + } // use native AxisHalf implementation, to read axis from q4 quadrant when possible - half, err := odsq4.AxisHalf(ctx, rsmt2d.Row, rowIdx) + half, err := odsq4.AxisHalf(ctx, rsmt2d.Row, rowIdw) if err != nil { return shwap.Sample{}, fmt.Errorf("reading axis: %w", err) } @@ -133,10 +137,6 @@ func (odsq4 *ODSQ4) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sampl return shwap.Sample{}, fmt.Errorf("extending shares: %w", err) } - idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, odsq4.Size(ctx)) - if err != nil { - return shwap.Sample{}, err - } return shwap.SampleFromShares(shares, rsmt2d.Row, idx) } diff --git a/store/getter.go b/store/getter.go index 83d1441a03..76b4b69f28 100644 --- a/store/getter.go +++ b/store/getter.go @@ -24,7 +24,9 @@ func NewGetter(store *Store) *Getter { return &Getter{store: store} } -func (g *Getter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) { +func (g *Getter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, + indices []shwap.SampleIndex, +) ([]shwap.Sample, error) { acc, err := g.store.GetByHeight(ctx, hdr.Height()) if err != nil { if errors.Is(err, ErrNotFound) { @@ -36,12 +38,7 @@ func (g *Getter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, ind smpls := make([]shwap.Sample, len(indices)) for i, idx := range indices { - rowIdx, colIdx, err := idx.Coordinates(len(hdr.DAH.RowRoots)) - if err != nil { - return nil, err - } - - smpl, err := acc.Sample(ctx, rowIdx, colIdx) + smpl, err := acc.Sample(ctx, idx) if err != nil { return nil, fmt.Errorf("get sample from accessor:%w", err) } diff --git a/store/getter_test.go b/store/getter_test.go index a87e0b2c98..d2b279d991 100644 --- a/store/getter_test.go +++ b/store/getter_test.go @@ -36,14 +36,16 @@ func TestStoreGetter(t *testing.T) { squareSize := int(eds.Width()) for i := 0; i < squareSize; i++ { for j := 0; j < squareSize; j++ { - share, err := sg.GetShare(ctx, eh, i, j) + idx, err := shwap.SampleIndexFromCoordinates(i, j, len(eh.DAH.RowRoots)) require.NoError(t, err) - require.Equal(t, eds.GetCell(uint(i), uint(j)), share.ToBytes()) + smpls, err := sg.GetSamples(ctx, eh, []shwap.SampleIndex{idx}) + require.NoError(t, err) + require.Equal(t, eds.GetCell(uint(i), uint(j)), smpls[0].Share.ToBytes()) } } // doesn't panic on indexes too high - _, err = sg.GetShare(ctx, eh, squareSize, squareSize) + _, err = sg.GetSamples(ctx, eh, []shwap.SampleIndex{shwap.SampleIndex(squareSize * squareSize)}) require.ErrorIs(t, err, shwap.ErrOutOfBounds) })