Skip to content

Commit

Permalink
feat!: change param type of ImportExtendedDataSquare
Browse files Browse the repository at this point in the history
  • Loading branch information
rootulp committed Jan 30, 2024
1 parent d2a29f0 commit add9662
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 104 deletions.
6 changes: 3 additions & 3 deletions extendeddatacrossword_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestRepairExtendedDataSquare(t *testing.T) {
flattened[12], flattened[13] = nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree)
eds, err := ImportExtendedDataSquare(flattened, codec, DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand All @@ -67,7 +67,7 @@ func TestRepairExtendedDataSquare(t *testing.T) {
flattened[12], flattened[13], flattened[14] = nil, nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree)
eds, err := ImportExtendedDataSquare(flattened, codec, DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand Down Expand Up @@ -275,7 +275,7 @@ func BenchmarkRepair(b *testing.B) {
}

// Re-import the data square.
eds, _ = ImportExtendedDataSquare(flattened, codec, NewDefaultTree)
eds, _ = ImportExtendedDataSquare(flattened, codec, DefaultTreeName)

b.StartTimer()

Expand Down
19 changes: 7 additions & 12 deletions extendeddatasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ func (eds *ExtendedDataSquare) UnmarshalJSON(b []byte) error {
return err
}

treeConstructor, err := TreeFn(aux.Tree)
if err != nil {
return err
}

importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], treeConstructor)
importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], aux.Tree)
if err != nil {
return err
}
Expand Down Expand Up @@ -96,7 +91,7 @@ func ComputeExtendedDataSquare(
func ImportExtendedDataSquare(
data [][]byte,
codec Codec,
treeCreatorFn TreeConstructorFn,
treeName string,
) (*ExtendedDataSquare, error) {
if len(data) > 4*codec.MaxChunks() {
return nil, errors.New("number of chunks exceeds the maximum")
Expand All @@ -108,14 +103,14 @@ func ImportExtendedDataSquare(
return nil, err
}

ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize))
treeCreatorFn, err := TreeFn(treeName)
if err != nil {
return nil, err
}

treeName := getTreeNameFromConstructorFn(treeCreatorFn)
if treeName == "" {
return nil, errors.New("tree name not found")
ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize))
if err != nil {
return nil, err
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: treeName}
Expand Down Expand Up @@ -249,7 +244,7 @@ func (eds *ExtendedDataSquare) erasureExtendCol(codec Codec, i uint) error {
}

func (eds *ExtendedDataSquare) deepCopy(codec Codec) (ExtendedDataSquare, error) {
imported, err := ImportExtendedDataSquare(eds.Flattened(), codec, eds.createTreeFn)
imported, err := ImportExtendedDataSquare(eds.Flattened(), codec, eds.treeName)
return *imported, err
}

Expand Down
6 changes: 3 additions & 3 deletions extendeddatasquare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ func TestComputeExtendedDataSquare(t *testing.T) {
func TestImportExtendedDataSquare(t *testing.T) {
t.Run("is able to import an EDS", func(t *testing.T) {
eds := createExampleEds(t, shareSize)
got, err := ImportExtendedDataSquare(eds.Flattened(), NewLeoRSCodec(), NewDefaultTree)
got, err := ImportExtendedDataSquare(eds.Flattened(), NewLeoRSCodec(), DefaultTreeName)
assert.NoError(t, err)
assert.Equal(t, eds.Flattened(), got.Flattened())
})
t.Run("returns an error if chunkSize is not a multiple of 64", func(t *testing.T) {
chunk := bytes.Repeat([]byte{1}, 65)
_, err := ImportExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), NewDefaultTree)
_, err := ImportExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), DefaultTreeName)
assert.Error(t, err)
})
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func TestUnmarshalJSON(t *testing.T) {
result, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, codec, treeConstructorFn)
}, codec, treeName)
if err != nil {
panic(err)
}
Expand Down
6 changes: 3 additions & 3 deletions rsmt2d_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestEdsRepairRoundtripSimple(t *testing.T) {
flattened[12], flattened[13] = nil, nil

// Re-import the data square.
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.NewDefaultTree)
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestEdsRepairTwice(t *testing.T) {
flattened[12], flattened[13] = nil, nil

// Re-import the data square.
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.NewDefaultTree)
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand All @@ -139,7 +139,7 @@ func TestEdsRepairTwice(t *testing.T) {
copy(flattened[1], missing)

// Re-import the data square.
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.NewDefaultTree)
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand Down
6 changes: 6 additions & 0 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ func TreeFn(treeName string) (TreeConstructorFn, error) {

return treeFn, nil
}

// removeTreeFn removes a treeConstructorFn by treeName.
// Only use for test cleanup. Proceed with caution.
func removeTreeFn(treeName string) {
treeFns.Delete(treeName)
}
83 changes: 0 additions & 83 deletions tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,89 +107,6 @@ func TestTreeFn(t *testing.T) {
}
}

// TestGetTreeNameFromConstructorFn tests the GetTreeNameFromConstructorFn
// function which fetches tree name by it corresponding tree constructor function.
//
// TODO: When we handle all the breaking changes track in this PR: https://github.com/celestiaorg/rsmt2d/pull/278, should remove this test
func TestGetTreeNameFromConstructorFn(t *testing.T) {
treeName := "testing_get_tree_name_tree"
treeConstructorFn := sudoConstructorFn
invalidTreeName := struct{}{}
invalidCaseTreeName := "invalid_case_tree"
invalidTreeConstructorFn := "invalid constructor fn"

tests := []struct {
name string
treeName string
treeFn TreeConstructorFn
malleate func()
expectGetKey bool
}{
// The tree name is successfully fetched.
{
"get successfully",
treeName,
treeConstructorFn,
func() {
err := RegisterTree(treeName, treeConstructorFn)
require.NoError(t, err)
},
true,
},
// Unable to fetch an unregistered tree name.
{
"get unregisted tree name",
"unregisted_tree_name",
nil,
func() {},
false,
},
// Value (tree constructor function) from the global map iteration is an invalid
// value that cannot be type asserted into TreeConstructorFn type.
{
"get invalid interface value",
"",
nil,
func() {
// Seems like this case has low probability of happening
// since all register has been done through RegisterTree func
// which have strict type check as argument.
treeFns.Store(invalidCaseTreeName, invalidTreeConstructorFn)
},
false,
},
// Key (tree name) from the global map iteration is an invalid value that cannot
// be type asserted into string type.
{
"get invalid interface key",
"",
nil,
func() {
// Seems like this case has low probability of happening
// since all register has been done through RegisterTree func
// which have strict type check as argument.
treeFns.Store(invalidTreeName, treeConstructorFn)
},
false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.malleate()

key := getTreeNameFromConstructorFn(test.treeFn)
if !test.expectGetKey {
require.Equal(t, key, "")
} else {
require.Equal(t, test.treeName, key)
}
})

cleanUp(test.treeName)
}
}

// Avoid duplicate with default_tree treeConstructorFn
// registered during init.
func sudoConstructorFn(_ Axis, _ uint) Tree {
Expand Down

0 comments on commit add9662

Please sign in to comment.