Skip to content

Commit

Permalink
fix imports on nameds, start handle non struct types
Browse files Browse the repository at this point in the history
  • Loading branch information
racytech committed Jan 13, 2025
1 parent 9bd6dbf commit 0c9209d
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 86 deletions.
128 changes: 58 additions & 70 deletions cmd/rlpgen/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,43 @@ func decodeLenMismatch(want int) string {
return fmt.Sprintf("return fmt.Errorf(\"error decoded length mismatch, expected: %d, got: %%d\", len(b))", want)
}

// 1. add package to imports if the to-be encoded field is not in the same package
// e.g do not import "github.com/erigontech/erigon/core/types" if the field is types.BlockNonce
func addToImports(named *types.Named) (typ string) {
if named.Obj().Pkg().Name() != pkgSrc.Name() {
_imports[named.Obj().Pkg().Path()] = true
typ = named.Obj().Pkg().Name() + "." + named.Obj().Name()
} else {
typ = named.Obj().Name()
}
return
}

func uint64CastTo(kind types.BasicKind) string {
var cast string
switch kind {
case types.Int16:
cast = "int16"
case types.Int32:
cast = "int32"
case types.Int:
cast = "int"
case types.Int64:
cast = "int64"
case types.Uint16:
cast = "uint16"
case types.Uint32:
cast = "uint32"
case types.Uint:
cast = "uint"
case types.Uint64:
return "i := n"
default:
panic(fmt.Sprintf("unhandled basic kind: %d", kind))
}
return fmt.Sprintf("i := %s(n)", cast)
}

func uintHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
var kind types.BasicKind
if basic, ok := fieldType.(*types.Basic); !ok {
Expand All @@ -73,19 +110,8 @@ func uintHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string
fmt.Fprintf(b2, " }\n")

// decode
var cast string
switch kind {
case types.Int:
cast = "i := int(n)"
case types.Int64:
cast = "i := int64(n)"
case types.Uint:
cast = "i := uint(n)"
case types.Uint64:
default:
panic(fmt.Sprintf("unhandled basic kind: %d", kind))
}
if kind != types.Uint64 {
cast := uint64CastTo(kind)
fmt.Fprintf(b3, " if n, err := s.Uint(); err != nil {\n")
fmt.Fprintf(b3, " %s\n", decodeErrorMsg(fieldName))
fmt.Fprintf(b3, " } else {\n")
Expand Down Expand Up @@ -124,20 +150,7 @@ func uintPtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName str
fmt.Fprintf(b2, " }\n")

// decode
var cast string
switch kind {
case types.Int:
cast = "i := int(n)"
case types.Int64:
cast = "i := int64(n)"
case types.Uint:
cast = "i := uint(n)"
case types.Uint64:
cast = "i := n"
default:
panic(fmt.Sprintf("unhandled basic kind: %d", kind))
}

cast := uint64CastTo(kind)
fmt.Fprintf(b3, " if n, err := s.Uint(); err != nil {\n")
fmt.Fprintf(b3, " %s\n", decodeErrorMsg(fieldName))
fmt.Fprintf(b3, " } else {\n")
Expand All @@ -148,11 +161,9 @@ func uintPtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName str

func bigIntHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
if named, ok := fieldType.(*types.Named); !ok {
_exit("blockNoncePtrHandle: expected filedType to be Named")
_exit("bigIntHandle: expected filedType to be Named")
} else {
if named.Obj().Pkg().Name() != pkgSrc.Name() { // do not import the package where source type is located
_imports[named.Obj().Pkg().Path()] = true
}
_ = addToImports(named)
}
// size
fmt.Fprintf(b1, " size += rlp.BigIntLenExcludingHead(&obj.%s) + 1\n", fieldName)
Expand All @@ -172,16 +183,15 @@ func bigIntHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName stri

func bigIntPtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
if ptr, ok := fieldType.(*types.Pointer); !ok {
_exit("_shortArrayPtrHandle: expected fieldType to be Pointer")
_exit("bigIntPtrHandle: expected fieldType to be Pointer")
} else {
if named, ok := ptr.Elem().(*types.Named); !ok {
_exit("blockNoncePtrHandle: expected filedType to be Named")
_exit("bigIntPtrHandle: expected filedType to be Pointer Named")
} else {
if named.Obj().Pkg().Name() != pkgSrc.Name() { // do not import the package where source type is located
_imports[named.Obj().Pkg().Path()] = true
}
_ = addToImports(named)
}
}

// size
fmt.Fprintf(b1, " size += 1\n")
fmt.Fprintf(b1, " if obj.%s != nil {\n", fieldName)
Expand All @@ -203,12 +213,11 @@ func bigIntPtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName s

func uint256Handle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
if named, ok := fieldType.(*types.Named); !ok {
_exit("blockNoncePtrHandle: expected filedType to be Named")
_exit("uint256Handle: expected filedType to be Named")
} else {
if named.Obj().Pkg().Name() != pkgSrc.Name() { // do not import the package where source type is located
_imports[named.Obj().Pkg().Path()] = true
}
_ = addToImports(named)
}

// size
fmt.Fprintf(b1, " size += rlp.Uint256LenExcludingHead(&obj.%s) + 1\n", fieldName)

Expand All @@ -227,14 +236,12 @@ func uint256Handle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName str

func uint256PtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
if ptr, ok := fieldType.(*types.Pointer); !ok {
_exit("_shortArrayPtrHandle: expected fieldType to be Pointer")
_exit("uint256PtrHandle: expected fieldType to be Pointer")
} else {
if named, ok := ptr.Elem().(*types.Named); !ok {
_exit("blockNoncePtrHandle: expected filedType to be Named")
_exit("uint256PtrHandle: expected filedType to be Pointer Named")
} else {
if named.Obj().Pkg().Name() != pkgSrc.Name() { // do not import the package where source type is located
_imports[named.Obj().Pkg().Path()] = true
}
_ = addToImports(named)
}
}

Expand Down Expand Up @@ -287,14 +294,9 @@ func _shortArrayPtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldN
_exit("_shortArrayPtrHandle: expected fieldType to be Pointer")
} else {
if named, ok := ptr.Elem().(*types.Named); !ok {
_exit("blockNoncePtrHandle: expected filedType to be Pointer Named")
_exit("_shortArrayPtrHandle: expected filedType to be Pointer Named")
} else {
if named.Obj().Pkg().Name() != pkgSrc.Name() { // do not import the package where source type is located
_imports[named.Obj().Pkg().Path()] = true
typ = named.Obj().Pkg().Name() + "." + named.Obj().Name()
} else {
typ = named.Obj().Name()
}
typ = addToImports(named)
}
}

Expand Down Expand Up @@ -386,17 +388,12 @@ func bloomHandle(b1, b2, b3 *bytes.Buffer, _ types.Type, fieldName string) {
func bloomPtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
var typ string
if ptr, ok := fieldType.(*types.Pointer); !ok {
_exit("_shortArrayPtrHandle: expected fieldType to be Pointer")
_exit("bloomPtrHandle: expected fieldType to be Pointer")
} else {
if named, ok := ptr.Elem().(*types.Named); !ok {
_exit("blockNoncePtrHandle: expected filedType to be Pointer Named")
_exit("bloomPtrHandle: expected filedType to be Pointer Named")
} else {
if named.Obj().Pkg().Name() != pkgSrc.Name() { // do not import the package where source type is located
_imports[named.Obj().Pkg().Path()] = true
typ = named.Obj().Pkg().Name() + "." + named.Obj().Name()
} else {
typ = named.Obj().Name()
}
typ = addToImports(named)
}
}

Expand Down Expand Up @@ -536,11 +533,7 @@ func _shortArraySliceHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fiel
if named, ok := slc.Elem().(*types.Named); !ok {
_exit("_shortArraySliceHandle: expected filedType to be Slice Named")
} else {
if named.Obj().Pkg().Name() != pkgSrc.Name() {
typ = named.Obj().Pkg().Name() + "." + named.Obj().Name()
} else {
typ = named.Obj().Name()
}
typ = addToImports(named)
}
}

Expand Down Expand Up @@ -591,12 +584,7 @@ func _shortArrayPtrSliceHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, f
if named, ok := ptr.Elem().(*types.Named); !ok {
_exit("_shortArrayPtrSliceHandle: expected filedType to be Slice Pointer Named")
} else {
if named.Obj().Pkg().Name() != pkgSrc.Name() { // do not import the package where source type is located
_imports[named.Obj().Pkg().Path()] = true
typ = named.Obj().Pkg().Name() + "." + named.Obj().Name()
} else {
typ = named.Obj().Name()
}
typ = addToImports(named)
}
}
}
Expand Down
37 changes: 25 additions & 12 deletions cmd/rlpgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func main() {
result = append(result, encodingSize.Bytes()...)
result = append(result, encodeRLP.Bytes()...)
result = append(result, decodeRLP.Bytes()...)
// os.Stdout.Write(result)
os.Stdout.Write(result)
if *writefile {
outfile := fmt.Sprintf("%s/gen_%s_rlp.go", *pkgdir, strings.ToLower(typ.Obj().Name()))
fmt.Println("outfile: ", outfile)
Expand Down Expand Up @@ -173,6 +173,11 @@ func process(typ *types.Named, b1, b2, b3 *bytes.Buffer) error {
}

func findType(scope *types.Scope, typename string) (*types.Named, error) {
// fmt.Println("TYPENAME: ", typename)
// names := scope.Names()
// for _, s := range names {
// fmt.Println("obj: ", s)
// }
obj := scope.Lookup(typename)
if obj == nil {
return nil, fmt.Errorf("no such identifier: %s", typename)
Expand All @@ -181,22 +186,30 @@ func findType(scope *types.Scope, typename string) (*types.Named, error) {
if !ok {
return nil, errors.New("not a type")
}
named := typ.Type().(*types.Named)
_, ok = named.Underlying().(*types.Struct)
if !ok {
return nil, errors.New("not a struct type")
if named, ok := typ.Type().(*types.Named); ok {
return named, nil
}
return named, nil
return nil, errors.New("not a named type")
}

func addEncodeLogic(b1, b2, b3 *bytes.Buffer, namedType *types.Named) error {
_struct := namedType.Underlying().(*types.Struct)
for i := 0; i < _struct.NumFields(); i++ {
func addEncodeLogic(b1, b2, b3 *bytes.Buffer, named *types.Named) error {

if _struct, ok := named.Underlying().(*types.Struct); ok {
for i := 0; i < _struct.NumFields(); i++ {

strTyp := matchTypeToString(_struct.Field(i).Type(), "")
// fmt.Println("-+-", strTyp)
strTyp := matchTypeToString(_struct.Field(i).Type(), "")
// fmt.Println("-+-", strTyp)

matchStrTypeToFunc(strTyp)(b1, b2, b3, _struct.Field(i).Type(), _struct.Field(i).Name())
matchStrTypeToFunc(strTyp)(b1, b2, b3, _struct.Field(i).Type(), _struct.Field(i).Name())
}
} else { // user named types that are not structs, could be:
// 1. type aliases
// - type aliases for basic types, e.g type MyInt int
// - type aliases for user named types, e.g type MyHash common.Hash (could be struct as well!)
// 2. slice types
// - slice of basice types, e.g type MyInts []int
// - slice of user named types, e.g type ReceiptsForStorage []*ReceiptForStorage
}

return nil
}
4 changes: 2 additions & 2 deletions cmd/rlpgen/matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ func matchTypeToString(fieldType types.Type, in string) string {

func matchStrTypeToFunc(strType string) handle {
switch strType {
case "int", "int64", "uint", "uint64": // test/add int16, int32, uint16, uint32
case "int16", "int32", "int", "int64", "uint16", "uint32", "uint", "uint64":
return handlers["uint64"]
case "*int", "*int64", "*uint", "*uint64":
case "*int16", "*int32", "*int", "*int64", "*uint16", "*uint32", "*uint", "*uint64":
return handlers["*uint64"]
default:
if fn, ok := handlers[strType]; ok {
Expand Down
4 changes: 2 additions & 2 deletions core/types/encdec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,13 +707,13 @@ func TestTestingStruct(t *testing.T) {
// enc := randTestingStruct(tr)
// buf.Reset()

// if err := enc.EncodeRLP2(&buf); err != nil {
// if err := enc.EncodeRLP(&buf); err != nil {
// t.Errorf("error: TestingStruct.EncodeRLP(): %v", err)
// }

// s := rlp.NewStream(bytes.NewReader(buf.Bytes()), 0)
// dec := &TestingStruct{}
// if err := dec.DecodeRLP2(s); err != nil {
// if err := dec.DecodeRLP(s); err != nil {
// t.Errorf("error: TestingStruct.DecodeRLP(): %v", err)
// panic(err)
// }
Expand Down

0 comments on commit 0c9209d

Please sign in to comment.