From da28faa6c5dd83e182be385e789685e1a9205066 Mon Sep 17 00:00:00 2001 From: lwfreitas Date: Thu, 16 Feb 2023 15:33:28 +0000 Subject: [PATCH 1/9] Draft for generics. --- generator/generator.go | 178 +++++++++++++++++++++++++++--------- generator/generator_test.go | 63 ++++++++----- generator/generics.go | 115 +++++++++++++++++++++++ generator/types.go | 16 ++-- 4 files changed, 299 insertions(+), 73 deletions(-) create mode 100644 generator/generics.go diff --git a/generator/generator.go b/generator/generator.go index b60958fa..3d49cf62 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -20,7 +20,7 @@ import ( "github.com/hexdigest/gowrap/printer" ) -//Generator generates decorators for the interface types +// Generator generates decorators for the interface types type Generator struct { Options @@ -30,6 +30,8 @@ type Generator struct { dstPackage *packages.Package methods methodsList interfaceType string + genericsTypes string + genericsParams string localPrefix string } @@ -85,13 +87,15 @@ type TemplateInputInterface struct { Name string // Type of the interface, with package name qualifier (e.g. sort.Interface) Type string + // Generics of the interface when using generics + Generics TemplateInputGenerics // Methods name keyed map of method information Methods map[string]Method } type methodsList map[string]Method -//Options of the NewGenerator constructor +// Options of the NewGenerator constructor type Options struct { //InterfaceName is a name of interface type InterfaceName string @@ -131,7 +135,7 @@ type Options struct { var errEmptyInterface = errors.New("interface has no methods") var errUnexportedMethod = errors.New("unexported method") -//NewGenerator returns Generator initialized with options +// NewGenerator returns Generator initialized with options func NewGenerator(options Options) (*Generator, error) { if options.Funcs == nil { options.Funcs = make(template.FuncMap) @@ -185,10 +189,11 @@ func NewGenerator(options Options) (*Generator, error) { options.Imports = append(options.Imports, `"`+srcPackage.PkgPath+`"`) } - methods, imports, err := findInterface(fs, srcPackage, srcPackageAST, options.InterfaceName) + types, methods, imports, err := findInterface(fs, srcPackage, srcPackageAST, options.InterfaceName, nil) if err != nil { return nil, errors.Wrap(err, "failed to parse interface declaration") } + genericsTypes, genericsParams := types.buildVars() if len(methods) == 0 { return nil, errEmptyInterface @@ -209,6 +214,8 @@ func NewGenerator(options Options) (*Generator, error) { srcPackage: srcPackage, dstPackage: dstPackage, interfaceType: interfaceType, + genericsTypes: genericsTypes, + genericsParams: genericsParams, methods: methods, localPrefix: options.LocalPrefix, }, nil @@ -250,7 +257,7 @@ func makePackage(path string) (*packages.Package, error) { }, nil } -//Generate generates code using header and body templates +// Generate generates code using header and body templates func (g Generator) Generate(w io.Writer) error { buf := bytes.NewBuffer([]byte{}) @@ -266,7 +273,11 @@ func (g Generator) Generate(w io.Writer) error { err = g.bodyTemplate.Execute(buf, TemplateInputs{ Interface: TemplateInputInterface{ - Name: g.Options.InterfaceName, + Name: g.Options.InterfaceName, + Generics: TemplateInputGenerics{ + Types: g.genericsTypes, + Params: g.genericsParams, + }, Type: g.interfaceType, Methods: g.methods, }, @@ -290,40 +301,43 @@ func (g Generator) Generate(w io.Writer) error { var errInterfaceNotFound = errors.New("interface type declaration not found") // findInterface looks for the interface declaration in the given directory -// and returns a list of the interface's methods and a list of imports from the file +// and returns the generic params if exists, a list of the interface's methods, and a list of imports from the file // where interface type declaration was found -func findInterface(fs *token.FileSet, currentPackage *packages.Package, p *ast.Package, interfaceName string) (methods methodsList, imports []*ast.ImportSpec, err error) { - var found bool - var types []*ast.TypeSpec - var it *ast.InterfaceType - +func findInterface(fs *token.FileSet, currentPackage *packages.Package, p *ast.Package, interfaceName string, genericParams genericsParams) (genericsTypes genericsTypes, methods methodsList, imports []*ast.ImportSpec, err error) { //looking for the source interface declaration in all files in the dir //while doing this we also store all found type declarations to check if some of the //interface methods use unexported types - for _, f := range p.Files { - for _, ts := range typeSpecs(f) { - types = append(types, ts) + ts, imports, types := iterateFiles(p, interfaceName) + if ts == nil { + return nil, nil, nil, errors.Wrap(errInterfaceNotFound, interfaceName) + } - if i, ok := ts.Type.(*ast.InterfaceType); ok { - if ts.Name.Name == interfaceName && !found { - imports = f.Imports - it = i - found = true - } - } + genericsTypes = genericsTypesBuild(ts) + + if it, ok := ts.Type.(*ast.InterfaceType); ok { + methods, err = processInterface(fs, currentPackage, it, types, p.Name, imports, genericsTypes, genericParams) + if err != nil { + return nil, nil, nil, err } } - if !found { - return nil, nil, errors.Wrap(errInterfaceNotFound, interfaceName) - } + return genericsTypes, methods, imports, err +} - methods, err = processInterface(fs, currentPackage, it, types, p.Name, imports) - if err != nil { - return nil, nil, err +func iterateFiles(p *ast.Package, name string) (selectedType *ast.TypeSpec, imports []*ast.ImportSpec, types []*ast.TypeSpec) { + for _, f := range p.Files { + if f != nil { + for _, ts := range typeSpecs(f) { + types = append(types, ts) + if ts.Name.Name == name { + selectedType = ts + imports = f.Imports + return + } + } + } } - - return methods, imports, err + return } func typeSpecs(f *ast.File) []*ast.TypeSpec { @@ -342,13 +356,86 @@ func typeSpecs(f *ast.File) []*ast.TypeSpec { return result } -func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methods methodsList, err error) { +func getEmbeddedMethods(t ast.Expr, fs *token.FileSet, currentPackage *packages.Package, types []*ast.TypeSpec, pr typePrinter, typesPrefix string, imports []*ast.ImportSpec, params genericsParams) (param genericsParam, methods methodsList, err error) { + switch v := t.(type) { + case *ast.SelectorExpr: + if x, ok := v.X.(*ast.Ident); ok && x != nil { + param.Name, err = pr.PrintType(x) + if err != nil { + return + } + } + + methods, err = processSelector(fs, currentPackage, v, imports, params) + return + + case *ast.Ident: + param.Name, err = pr.PrintType(v) + if err != nil { + return + } + methods, err = processIdent(fs, currentPackage, v, types, typesPrefix, imports, params) + return + } + return +} + +func processEmbedded(t ast.Expr, fs *token.FileSet, currentPackage *packages.Package, types []*ast.TypeSpec, pr typePrinter, typesPrefix string, imports []*ast.ImportSpec, params genericsParams) (param genericsParam, embeddedMethods methodsList, err error) { + var x ast.Expr + var genericsParam bool + + switch v := t.(type) { + case *ast.IndexExpr: + x = v.X + genericsParam = true + + param, _, err = processEmbedded(v.Index, fs, currentPackage, types, pr, typesPrefix, imports, params) + if err != nil { + return + } + if param.Name != "" { + params = append(params, param) + } + + case *ast.IndexListExpr: + x = v.X + genericsParam = true + + if v.Indices != nil { + for _, index := range v.Indices { + param, _, err = processEmbedded(index, fs, currentPackage, types, pr, typesPrefix, imports, params) + if err != nil { + return + } + if param.Name != "" { + params = append(params, param) + } + } + } + default: + x = v + } + + param, embeddedMethods, err = getEmbeddedMethods(x, fs, currentPackage, types, pr, typesPrefix, imports, params) + if err != nil { + return + } + + if genericsParam { + param.Params = params + } + return +} + +func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec, genericsTypes genericsTypes, params genericsParams) (methods methodsList, err error) { if it.Methods == nil { return nil, nil } methods = make(methodsList, len(it.Methods.List)) + pr := printer.New(fs, types, typesPrefix) + for _, field := range it.Methods.List { var embeddedMethods methodsList var err error @@ -356,15 +443,15 @@ func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *a switch v := field.Type.(type) { case *ast.FuncType: var method *Method - method, err = NewMethod(field.Names[0].Name, field, printer.New(fs, types, typesPrefix)) + + method, err = NewMethod(field.Names[0].Name, field, pr, genericsTypes, params) if err == nil { methods[field.Names[0].Name] = *method continue } - case *ast.SelectorExpr: - embeddedMethods, err = processSelector(fs, currentPackage, v, imports) - case *ast.Ident: - embeddedMethods, err = processIdent(fs, currentPackage, v, types, typesPrefix, imports) + + default: + _, embeddedMethods, err = processEmbedded(v, fs, currentPackage, types, pr, typesPrefix, imports, params) } if err != nil { @@ -380,8 +467,8 @@ func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *a return methods, nil } -func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *ast.SelectorExpr, imports []*ast.ImportSpec) (methodsList, error) { - interfaceName := se.Sel.Name +func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *ast.SelectorExpr, imports []*ast.ImportSpec, params genericsParams) (methodsList, error) { + selectedName := se.Sel.Name packageSelector := se.X.(*ast.Ident).Name importPath, err := findImportPathForName(packageSelector, imports, currentPackage) @@ -399,13 +486,13 @@ func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *as return nil, errors.Wrap(err, "failed to import package") } - methods, _, err := findInterface(fs, p, astPkg, interfaceName) + _, methods, _, err := findInterface(fs, p, astPkg, selectedName, params) return methods, err } -//mergeMethods merges two methods list. Retains overlapping methods from the -//parent list +// mergeMethods merges two methods list. Retains overlapping methods from the +// parent list func mergeMethods(methods, embeddedMethods methodsList) (methodsList, error) { if methods == nil || embeddedMethods == nil { return methods, nil @@ -426,8 +513,9 @@ func mergeMethods(methods, embeddedMethods methodsList) (methodsList, error) { var errEmbeddedInterfaceNotFound = errors.New("embedded interface not found") var errNotAnInterface = errors.New("embedded type is not an interface") -func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methodsList, error) { +func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec, params genericsParams) (methodsList, error) { var embeddedInterface *ast.InterfaceType + var genericsTypes genericsTypes for _, t := range types { if t.Name.Name == i.Name { var ok bool @@ -435,15 +523,17 @@ func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Id if !ok { return nil, errors.Wrap(errNotAnInterface, t.Name.Name) } + + genericsTypes = genericsTypesBuild(t) break } } if embeddedInterface == nil { - return nil, errors.Wrap(errEmbeddedInterfaceNotFound, i.Name) + return nil, nil } - return processInterface(fs, currentPackage, embeddedInterface, types, typesPrefix, imports) + return processInterface(fs, currentPackage, embeddedInterface, types, typesPrefix, imports, genericsTypes, params) } var errUnknownSelector = errors.New("unknown selector") diff --git a/generator/generator_test.go b/generator/generator_test.go index d3d02686..a7538001 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -111,11 +111,12 @@ func Test_findImportPathForName(t *testing.T) { func Test_processIdent(t *testing.T) { type args struct { - fs *token.FileSet - i *ast.Ident - types []*ast.TypeSpec - typesPrefix string - imports []*ast.ImportSpec + fs *token.FileSet + i *ast.Ident + types []*ast.TypeSpec + typesPrefix string + imports []*ast.ImportSpec + genericsParams genericsParams } tests := []struct { name string @@ -162,7 +163,7 @@ func Test_processIdent(t *testing.T) { mc := minimock.NewController(t) defer mc.Wait(time.Second) - got1, err := processIdent(tt.args.fs, nil, tt.args.i, tt.args.types, tt.args.typesPrefix, tt.args.imports) + got1, err := processIdent(tt.args.fs, nil, tt.args.i, tt.args.types, tt.args.typesPrefix, tt.args.imports, tt.args.genericsParams) assert.Equal(t, tt.want1, got1, "processIdent returned unexpected result") @@ -254,10 +255,11 @@ func Test_mergeMethods(t *testing.T) { func Test_processSelector(t *testing.T) { type args struct { - fs *token.FileSet - cp *packages.Package - se *ast.SelectorExpr - imports []*ast.ImportSpec + fs *token.FileSet + cp *packages.Package + se *ast.SelectorExpr + imports []*ast.ImportSpec + genericsParams genericsParams } tests := []struct { name string @@ -300,7 +302,7 @@ func Test_processSelector(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got1, err := processSelector(tt.args.fs, tt.args.cp, tt.args.se, tt.args.imports) + got1, err := processSelector(tt.args.fs, tt.args.cp, tt.args.se, tt.args.imports, tt.args.genericsParams) assert.Equal(t, tt.want1, got1, "processSelector returned unexpected result") @@ -318,12 +320,14 @@ func Test_processSelector(t *testing.T) { func Test_processInterface(t *testing.T) { type args struct { - fs *token.FileSet - cp *packages.Package - it *ast.InterfaceType - types []*ast.TypeSpec - typesPrefix string - imports []*ast.ImportSpec + fs *token.FileSet + cp *packages.Package + it *ast.InterfaceType + types []*ast.TypeSpec + typesPrefix string + imports []*ast.ImportSpec + genericsTypes genericsTypes + genericsParams genericsParams } tests := []struct { name string @@ -383,11 +387,25 @@ func Test_processInterface(t *testing.T) { want1: methodsList{}, wantErr: false, }, + { + name: "index list expression with identifier", + args: args{ + fs: token.NewFileSet(), + it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{ + { + Type: &ast.IndexListExpr{X: &ast.Ident{Name: "Embedded"}}, + }, + }}}, + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Embedded"}, Type: &ast.InterfaceType{}}}, + }, + want1: methodsList{}, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got1, err := processInterface(tt.args.fs, tt.args.cp, tt.args.it, tt.args.types, tt.args.typesPrefix, tt.args.imports) + got1, err := processInterface(tt.args.fs, tt.args.cp, tt.args.it, tt.args.types, tt.args.typesPrefix, tt.args.imports, tt.args.genericsTypes, tt.args.genericsParams) assert.Equal(t, tt.want1, got1, "processInterface returned unexpected result") @@ -418,9 +436,10 @@ func Test_typeSpecs(t *testing.T) { func Test_findInterface(t *testing.T) { type args struct { - fs *token.FileSet - p *ast.Package - interfaceName string + fs *token.FileSet + p *ast.Package + interfaceName string + genericsParams genericsParams } tests := []struct { name string @@ -459,7 +478,7 @@ func Test_findInterface(t *testing.T) { mc := minimock.NewController(t) defer mc.Wait(time.Second) - got1, _, err := findInterface(tt.args.fs, nil, tt.args.p, tt.args.interfaceName) + _, got1, _, err := findInterface(tt.args.fs, nil, tt.args.p, tt.args.interfaceName, tt.args.genericsParams) assert.Equal(t, tt.want1, got1, "findInterface returned unexpected result") diff --git a/generator/generics.go b/generator/generics.go new file mode 100644 index 00000000..882dde14 --- /dev/null +++ b/generator/generics.go @@ -0,0 +1,115 @@ +package generator + +import ( + "go/ast" + "strings" +) + +const ( + genericsSeparator = ", " + + genericsSquareBracketStart = "[" + genericsSquareBracketEnd = "]" +) + +// TemplateInputGenerics subset of generics interface information used for template generation +type TemplateInputGenerics struct { + // Types of the interface when using generics (e.g. [I, O any]) + Types string + + // Params of the interface when using generics (e.g. [I, O]) + Params string +} + +type genericsParams []genericsParam + +type genericsParam struct { + Name string + Params genericsParams +} + +func (g genericsParam) String() string { + name := g.Name + var subParamNames []string + for _, subParam := range g.Params { + subParamNames = append(subParamNames, subParam.String()) + } + if len(g.Params) > 0 { + name += genericsSquareBracketStart + strings.Join(subParamNames, genericsSeparator) + genericsSquareBracketEnd + } + return name +} + +type genericsTypes []genericsType + +type genericsType struct { + Type string + Names []string +} + +func genericsWithBracketsBuild(t string) string { + if t != "" { + t = genericsSquareBracketStart + t + genericsSquareBracketEnd + } + return t +} + +func (g genericsTypes) buildVars() (string, string) { + var types, typesSep string + var params, paramsSep string + + for _, genType := range g { + var paramsByType, paramsByTypeSep string + + for _, name := range genType.Names { + paramsByType += paramsByTypeSep + name + params += paramsSep + name + paramsSep = genericsSeparator + paramsByTypeSep = genericsSeparator + } + + if paramsByType != "" { + types += typesSep + paramsByType + " " + genType.Type + typesSep = genericsSeparator + } + } + + return genericsWithBracketsBuild(types), genericsWithBracketsBuild(params) +} + +func genericsTypesBuild(ts *ast.TypeSpec) (types genericsTypes) { + if ts.TypeParams != nil { + for _, param := range ts.TypeParams.List { + if param != nil { + if gpt, ok := param.Type.(*ast.Ident); ok { + var paramNames []string + for _, name := range param.Names { + if name != nil { + paramNames = append(paramNames, name.Name) + } + } + types = append(types, genericsType{ + Type: gpt.Name, + Names: paramNames, + }) + } + } + } + } + return +} + +func genericsBuildParamString(typeStr string, genericsTypes genericsTypes, genericsParams genericsParams) string { + c := 0 + for _, genType := range genericsTypes { + for _, name := range genType.Names { + if name == typeStr { + if len(genericsParams) > c { + return genericsParams[c].String() + } + } + c++ + } + } + return typeStr +} diff --git a/generator/types.go b/generator/types.go index 4e7b6da0..83db5a20 100644 --- a/generator/types.go +++ b/generator/types.go @@ -66,7 +66,7 @@ func (p Param) Pass() string { } // NewMethod returns pointer to Signature struct or error -func NewMethod(name string, fi *ast.Field, printer typePrinter) (*Method, error) { +func NewMethod(name string, fi *ast.Field, printer typePrinter, genericsSpec genericsTypes, genericsParams genericsParams) (*Method, error) { f, ok := fi.Type.(*ast.FuncType) if !ok { return nil, fmt.Errorf("%q is not a method", name) @@ -105,11 +105,11 @@ func NewMethod(name string, fi *ast.Field, printer typePrinter) (*Method, error) var err error - m.Params, err = makeParams(f.Params, usedNames, printer) + m.Params, err = makeParams(f.Params, usedNames, printer, genericsSpec, genericsParams) if err != nil { return nil, err } - m.Results, err = makeParams(f.Results, usedNames, printer) + m.Results, err = makeParams(f.Results, usedNames, printer, genericsSpec, genericsParams) if err != nil { return nil, err } @@ -126,7 +126,7 @@ func NewMethod(name string, fi *ast.Field, printer typePrinter) (*Method, error) } // NewParam returns Param struct -func NewParam(name string, fi *ast.Field, usedNames map[string]bool, printer typePrinter) (*Param, error) { +func NewParam(name string, fi *ast.Field, usedNames map[string]bool, printer typePrinter, genericsSpec genericsTypes, genericsParams genericsParams) (*Param, error) { typ := fi.Type if name == "" || usedNames[name] { name = genName(typePrefix(typ), 1, usedNames) @@ -139,6 +139,8 @@ func NewParam(name string, fi *ast.Field, usedNames map[string]bool, printer typ return nil, err } + typeStr = genericsBuildParamString(typeStr, genericsSpec, genericsParams) + _, variadic := typ.(*ast.Ellipsis) p := &Param{ Name: name, @@ -162,7 +164,7 @@ func NewParam(name string, fi *ast.Field, usedNames map[string]bool, printer typ return p, nil } -func makeParams(params *ast.FieldList, usedNames map[string]bool, printer typePrinter) (ParamsSlice, error) { +func makeParams(params *ast.FieldList, usedNames map[string]bool, printer typePrinter, genericsSpec genericsTypes, genericsParams genericsParams) (ParamsSlice, error) { if params == nil { return nil, nil } @@ -172,14 +174,14 @@ func makeParams(params *ast.FieldList, usedNames map[string]bool, printer typePr //for anonymous parameters we generate params and results names //based on their type if p.Names == nil { - param, err := NewParam("", p, usedNames, printer) + param, err := NewParam("", p, usedNames, printer, genericsSpec, genericsParams) if err != nil { return nil, err } result = append(result, *param) } else { for _, ident := range p.Names { - param, err := NewParam(ident.Name, p, usedNames, printer) + param, err := NewParam(ident.Name, p, usedNames, printer, genericsSpec, genericsParams) if err != nil { return nil, err } From 21f988cbeb0278f85184e873f764d798fe9193b3 Mon Sep 17 00:00:00 2001 From: lwfreitas Date: Thu, 16 Feb 2023 23:11:38 +0000 Subject: [PATCH 2/9] Generics feature. --- generator/generator.go | 1 - generator/generator_test.go | 24 ------------------------ 2 files changed, 25 deletions(-) diff --git a/generator/generator.go b/generator/generator.go index 3d49cf62..c4a618f1 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -510,7 +510,6 @@ func mergeMethods(methods, embeddedMethods methodsList) (methodsList, error) { return result, nil } -var errEmbeddedInterfaceNotFound = errors.New("embedded interface not found") var errNotAnInterface = errors.New("embedded type is not an interface") func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec, params genericsParams) (methodsList, error) { diff --git a/generator/generator_test.go b/generator/generator_test.go index a7538001..750cc08f 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -137,17 +137,6 @@ func Test_processIdent(t *testing.T) { assert.Equal(t, errNotAnInterface, errors.Cause(err)) }, }, - { - name: "embedded interface not found", - args: args{ - i: &ast.Ident{Name: "name"}, - types: []*ast.TypeSpec{}, - }, - wantErr: true, - inspectErr: func(err error, t *testing.T) { - assert.Equal(t, errEmbeddedInterfaceNotFound, errors.Cause(err)) - }, - }, { name: "embedded interface found", args: args{ @@ -360,19 +349,6 @@ func Test_processInterface(t *testing.T) { }, wantErr: true, }, - { - name: "identifier", - args: args{ - fs: token.NewFileSet(), - it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{ - { - Names: []*ast.Ident{{Name: "methodName"}}, - Type: &ast.Ident{Name: "unknown"}, - }, - }}}, - }, - wantErr: true, - }, { name: "identifier with embedded methods", args: args{ From 3a35407718de5281814c624fd6f48313053bcb48 Mon Sep 17 00:00:00 2001 From: lwfreitas Date: Fri, 17 Feb 2023 04:31:04 +0000 Subject: [PATCH 3/9] Fix nomenclature of genericParams and genericTypes --- generator/generator.go | 54 ++++++++++++++++++------------------- generator/generator_test.go | 2 +- generator/generics.go | 30 ++++++++++----------- generator/types.go | 16 +++++------ 4 files changed, 51 insertions(+), 51 deletions(-) diff --git a/generator/generator.go b/generator/generator.go index c4a618f1..14ae6107 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -303,7 +303,7 @@ var errInterfaceNotFound = errors.New("interface type declaration not found") // findInterface looks for the interface declaration in the given directory // and returns the generic params if exists, a list of the interface's methods, and a list of imports from the file // where interface type declaration was found -func findInterface(fs *token.FileSet, currentPackage *packages.Package, p *ast.Package, interfaceName string, genericParams genericsParams) (genericsTypes genericsTypes, methods methodsList, imports []*ast.ImportSpec, err error) { +func findInterface(fs *token.FileSet, currentPackage *packages.Package, p *ast.Package, interfaceName string, genericParams genericParams) (genericTypes genericTypes, methods methodsList, imports []*ast.ImportSpec, err error) { //looking for the source interface declaration in all files in the dir //while doing this we also store all found type declarations to check if some of the //interface methods use unexported types @@ -312,16 +312,16 @@ func findInterface(fs *token.FileSet, currentPackage *packages.Package, p *ast.P return nil, nil, nil, errors.Wrap(errInterfaceNotFound, interfaceName) } - genericsTypes = genericsTypesBuild(ts) + genericTypes = genericTypesBuild(ts) if it, ok := ts.Type.(*ast.InterfaceType); ok { - methods, err = processInterface(fs, currentPackage, it, types, p.Name, imports, genericsTypes, genericParams) + methods, err = processInterface(fs, currentPackage, it, types, p.Name, imports, genericTypes, genericParams) if err != nil { return nil, nil, nil, err } } - return genericsTypes, methods, imports, err + return genericTypes, methods, imports, err } func iterateFiles(p *ast.Package, name string) (selectedType *ast.TypeSpec, imports []*ast.ImportSpec, types []*ast.TypeSpec) { @@ -356,7 +356,7 @@ func typeSpecs(f *ast.File) []*ast.TypeSpec { return result } -func getEmbeddedMethods(t ast.Expr, fs *token.FileSet, currentPackage *packages.Package, types []*ast.TypeSpec, pr typePrinter, typesPrefix string, imports []*ast.ImportSpec, params genericsParams) (param genericsParam, methods methodsList, err error) { +func getEmbeddedMethods(t ast.Expr, fs *token.FileSet, currentPackage *packages.Package, types []*ast.TypeSpec, pr typePrinter, typesPrefix string, imports []*ast.ImportSpec, params genericParams) (param genericParam, methods methodsList, err error) { switch v := t.(type) { case *ast.SelectorExpr: if x, ok := v.X.(*ast.Ident); ok && x != nil { @@ -380,35 +380,35 @@ func getEmbeddedMethods(t ast.Expr, fs *token.FileSet, currentPackage *packages. return } -func processEmbedded(t ast.Expr, fs *token.FileSet, currentPackage *packages.Package, types []*ast.TypeSpec, pr typePrinter, typesPrefix string, imports []*ast.ImportSpec, params genericsParams) (param genericsParam, embeddedMethods methodsList, err error) { +func processEmbedded(t ast.Expr, fs *token.FileSet, currentPackage *packages.Package, types []*ast.TypeSpec, pr typePrinter, typesPrefix string, imports []*ast.ImportSpec, genericParams genericParams) (genericParam genericParam, embeddedMethods methodsList, err error) { var x ast.Expr - var genericsParam bool + var hasGenericsParams bool switch v := t.(type) { case *ast.IndexExpr: x = v.X - genericsParam = true + hasGenericsParams = true - param, _, err = processEmbedded(v.Index, fs, currentPackage, types, pr, typesPrefix, imports, params) + genericParam, _, err = processEmbedded(v.Index, fs, currentPackage, types, pr, typesPrefix, imports, genericParams) if err != nil { return } - if param.Name != "" { - params = append(params, param) + if genericParam.Name != "" { + genericParams = append(genericParams, genericParam) } case *ast.IndexListExpr: x = v.X - genericsParam = true + hasGenericsParams = true if v.Indices != nil { for _, index := range v.Indices { - param, _, err = processEmbedded(index, fs, currentPackage, types, pr, typesPrefix, imports, params) + genericParam, _, err = processEmbedded(index, fs, currentPackage, types, pr, typesPrefix, imports, genericParams) if err != nil { return } - if param.Name != "" { - params = append(params, param) + if genericParam.Name != "" { + genericParams = append(genericParams, genericParam) } } } @@ -416,18 +416,18 @@ func processEmbedded(t ast.Expr, fs *token.FileSet, currentPackage *packages.Pac x = v } - param, embeddedMethods, err = getEmbeddedMethods(x, fs, currentPackage, types, pr, typesPrefix, imports, params) + genericParam, embeddedMethods, err = getEmbeddedMethods(x, fs, currentPackage, types, pr, typesPrefix, imports, genericParam.Params) if err != nil { return } - if genericsParam { - param.Params = params + if hasGenericsParams { + genericParam.Params = genericParam.Params } return } -func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec, genericsTypes genericsTypes, params genericsParams) (methods methodsList, err error) { +func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec, genericsTypes genericTypes, genericParams genericParams) (methods methodsList, err error) { if it.Methods == nil { return nil, nil } @@ -444,14 +444,14 @@ func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *a case *ast.FuncType: var method *Method - method, err = NewMethod(field.Names[0].Name, field, pr, genericsTypes, params) + method, err = NewMethod(field.Names[0].Name, field, pr, genericsTypes, genericParams) if err == nil { methods[field.Names[0].Name] = *method continue } default: - _, embeddedMethods, err = processEmbedded(v, fs, currentPackage, types, pr, typesPrefix, imports, params) + _, embeddedMethods, err = processEmbedded(v, fs, currentPackage, types, pr, typesPrefix, imports, genericParams) } if err != nil { @@ -467,7 +467,7 @@ func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *a return methods, nil } -func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *ast.SelectorExpr, imports []*ast.ImportSpec, params genericsParams) (methodsList, error) { +func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *ast.SelectorExpr, imports []*ast.ImportSpec, genericParams genericParams) (methodsList, error) { selectedName := se.Sel.Name packageSelector := se.X.(*ast.Ident).Name @@ -486,7 +486,7 @@ func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *as return nil, errors.Wrap(err, "failed to import package") } - _, methods, _, err := findInterface(fs, p, astPkg, selectedName, params) + _, methods, _, err := findInterface(fs, p, astPkg, selectedName, genericParams) return methods, err } @@ -512,9 +512,9 @@ func mergeMethods(methods, embeddedMethods methodsList) (methodsList, error) { var errNotAnInterface = errors.New("embedded type is not an interface") -func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec, params genericsParams) (methodsList, error) { +func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec, genericParams genericParams) (methodsList, error) { var embeddedInterface *ast.InterfaceType - var genericsTypes genericsTypes + var genericsTypes genericTypes for _, t := range types { if t.Name.Name == i.Name { var ok bool @@ -523,7 +523,7 @@ func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Id return nil, errors.Wrap(errNotAnInterface, t.Name.Name) } - genericsTypes = genericsTypesBuild(t) + genericsTypes = genericTypesBuild(t) break } } @@ -532,7 +532,7 @@ func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Id return nil, nil } - return processInterface(fs, currentPackage, embeddedInterface, types, typesPrefix, imports, genericsTypes, params) + return processInterface(fs, currentPackage, embeddedInterface, types, typesPrefix, imports, genericsTypes, genericParams) } var errUnknownSelector = errors.New("unknown selector") diff --git a/generator/generator_test.go b/generator/generator_test.go index 750cc08f..c12446c8 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -315,7 +315,7 @@ func Test_processInterface(t *testing.T) { types []*ast.TypeSpec typesPrefix string imports []*ast.ImportSpec - genericsTypes genericsTypes + genericsTypes genericTypes genericsParams genericsParams } tests := []struct { diff --git a/generator/generics.go b/generator/generics.go index 882dde14..512e10ee 100644 --- a/generator/generics.go +++ b/generator/generics.go @@ -21,14 +21,14 @@ type TemplateInputGenerics struct { Params string } -type genericsParams []genericsParam +type genericParams []genericParam -type genericsParam struct { +type genericParam struct { Name string - Params genericsParams + Params genericParams } -func (g genericsParam) String() string { +func (g genericParam) String() string { name := g.Name var subParamNames []string for _, subParam := range g.Params { @@ -40,9 +40,9 @@ func (g genericsParam) String() string { return name } -type genericsTypes []genericsType +type genericTypes []genericType -type genericsType struct { +type genericType struct { Type string Names []string } @@ -54,7 +54,7 @@ func genericsWithBracketsBuild(t string) string { return t } -func (g genericsTypes) buildVars() (string, string) { +func (g genericTypes) buildVars() (string, string) { var types, typesSep string var params, paramsSep string @@ -77,7 +77,7 @@ func (g genericsTypes) buildVars() (string, string) { return genericsWithBracketsBuild(types), genericsWithBracketsBuild(params) } -func genericsTypesBuild(ts *ast.TypeSpec) (types genericsTypes) { +func genericTypesBuild(ts *ast.TypeSpec) (types genericTypes) { if ts.TypeParams != nil { for _, param := range ts.TypeParams.List { if param != nil { @@ -88,7 +88,7 @@ func genericsTypesBuild(ts *ast.TypeSpec) (types genericsTypes) { paramNames = append(paramNames, name.Name) } } - types = append(types, genericsType{ + types = append(types, genericType{ Type: gpt.Name, Names: paramNames, }) @@ -99,16 +99,16 @@ func genericsTypesBuild(ts *ast.TypeSpec) (types genericsTypes) { return } -func genericsBuildParamString(typeStr string, genericsTypes genericsTypes, genericsParams genericsParams) string { - c := 0 - for _, genType := range genericsTypes { +func genericBuildParamString(typeStr string, genericTypes genericTypes, genericParams genericParams) string { + i := 0 + for _, genType := range genericTypes { for _, name := range genType.Names { if name == typeStr { - if len(genericsParams) > c { - return genericsParams[c].String() + if len(genericParams) > i { + return genericParams[i].String() } } - c++ + i++ } } return typeStr diff --git a/generator/types.go b/generator/types.go index 83db5a20..0b7acb73 100644 --- a/generator/types.go +++ b/generator/types.go @@ -66,7 +66,7 @@ func (p Param) Pass() string { } // NewMethod returns pointer to Signature struct or error -func NewMethod(name string, fi *ast.Field, printer typePrinter, genericsSpec genericsTypes, genericsParams genericsParams) (*Method, error) { +func NewMethod(name string, fi *ast.Field, printer typePrinter, genericTypes genericTypes, genericParams genericParams) (*Method, error) { f, ok := fi.Type.(*ast.FuncType) if !ok { return nil, fmt.Errorf("%q is not a method", name) @@ -105,11 +105,11 @@ func NewMethod(name string, fi *ast.Field, printer typePrinter, genericsSpec gen var err error - m.Params, err = makeParams(f.Params, usedNames, printer, genericsSpec, genericsParams) + m.Params, err = makeParams(f.Params, usedNames, printer, genericTypes, genericParams) if err != nil { return nil, err } - m.Results, err = makeParams(f.Results, usedNames, printer, genericsSpec, genericsParams) + m.Results, err = makeParams(f.Results, usedNames, printer, genericTypes, genericParams) if err != nil { return nil, err } @@ -126,7 +126,7 @@ func NewMethod(name string, fi *ast.Field, printer typePrinter, genericsSpec gen } // NewParam returns Param struct -func NewParam(name string, fi *ast.Field, usedNames map[string]bool, printer typePrinter, genericsSpec genericsTypes, genericsParams genericsParams) (*Param, error) { +func NewParam(name string, fi *ast.Field, usedNames map[string]bool, printer typePrinter, genericTypes genericTypes, genericParams genericParams) (*Param, error) { typ := fi.Type if name == "" || usedNames[name] { name = genName(typePrefix(typ), 1, usedNames) @@ -139,7 +139,7 @@ func NewParam(name string, fi *ast.Field, usedNames map[string]bool, printer typ return nil, err } - typeStr = genericsBuildParamString(typeStr, genericsSpec, genericsParams) + typeStr = genericBuildParamString(typeStr, genericTypes, genericParams) _, variadic := typ.(*ast.Ellipsis) p := &Param{ @@ -164,7 +164,7 @@ func NewParam(name string, fi *ast.Field, usedNames map[string]bool, printer typ return p, nil } -func makeParams(params *ast.FieldList, usedNames map[string]bool, printer typePrinter, genericsSpec genericsTypes, genericsParams genericsParams) (ParamsSlice, error) { +func makeParams(params *ast.FieldList, usedNames map[string]bool, printer typePrinter, genericTypes genericTypes, genericParams genericParams) (ParamsSlice, error) { if params == nil { return nil, nil } @@ -174,14 +174,14 @@ func makeParams(params *ast.FieldList, usedNames map[string]bool, printer typePr //for anonymous parameters we generate params and results names //based on their type if p.Names == nil { - param, err := NewParam("", p, usedNames, printer, genericsSpec, genericsParams) + param, err := NewParam("", p, usedNames, printer, genericTypes, genericParams) if err != nil { return nil, err } result = append(result, *param) } else { for _, ident := range p.Names { - param, err := NewParam(ident.Name, p, usedNames, printer, genericsSpec, genericsParams) + param, err := NewParam(ident.Name, p, usedNames, printer, genericTypes, genericParams) if err != nil { return nil, err } From 026a3eb6d6ba913143ec205a302e28f2e3f3cf58 Mon Sep 17 00:00:00 2001 From: lwfreitas Date: Fri, 17 Feb 2023 08:59:06 +0000 Subject: [PATCH 4/9] Fix embedded methods --- generator/generator.go | 125 +++++++++++++++++++++++------------- generator/generator_test.go | 4 +- 2 files changed, 84 insertions(+), 45 deletions(-) diff --git a/generator/generator.go b/generator/generator.go index 14ae6107..45d72b20 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -93,8 +93,6 @@ type TemplateInputInterface struct { Methods map[string]Method } -type methodsList map[string]Method - // Options of the NewGenerator constructor type Options struct { //InterfaceName is a name of interface type @@ -132,6 +130,30 @@ type Options struct { LocalPrefix string } +type methodsList map[string]Method + +type processInput struct { + fileSet *token.FileSet + currentPackage *packages.Package + pacakge *ast.Package + targetName string + genericParams genericParams +} + +type targetProcessInput struct { + processInput + types []*ast.TypeSpec + typesPrefix string + imports []*ast.ImportSpec + genericsTypes genericTypes +} + +type processOutput struct { + genericTypes genericTypes + methods methodsList + imports []*ast.ImportSpec +} + var errEmptyInterface = errors.New("interface has no methods") var errUnexportedMethod = errors.New("unexported method") @@ -189,23 +211,28 @@ func NewGenerator(options Options) (*Generator, error) { options.Imports = append(options.Imports, `"`+srcPackage.PkgPath+`"`) } - types, methods, imports, err := findInterface(fs, srcPackage, srcPackageAST, options.InterfaceName, nil) + output, err := findTarget(processInput{ + fileSet: fs, + currentPackage: srcPackage, + pacakge: srcPackageAST, + targetName: options.InterfaceName, + }) if err != nil { return nil, errors.Wrap(err, "failed to parse interface declaration") } - genericsTypes, genericsParams := types.buildVars() + genericsTypes, genericsParams := output.genericTypes.buildVars() - if len(methods) == 0 { + if len(output.methods) == 0 { return nil, errEmptyInterface } - for _, m := range methods { + for _, m := range output.methods { if srcPackageAST.Name != "" && []rune(m.Name)[0] == []rune(strings.ToLower(m.Name))[0] { return nil, errors.Wrap(errUnexportedMethod, m.Name) } } - options.Imports = append(options.Imports, makeImports(imports)...) + options.Imports = append(options.Imports, makeImports(output.imports)...) return &Generator{ Options: options, @@ -216,7 +243,7 @@ func NewGenerator(options Options) (*Generator, error) { interfaceType: interfaceType, genericsTypes: genericsTypes, genericsParams: genericsParams, - methods: methods, + methods: output.methods, localPrefix: options.LocalPrefix, }, nil } @@ -298,30 +325,31 @@ func (g Generator) Generate(w io.Writer) error { return err } -var errInterfaceNotFound = errors.New("interface type declaration not found") +var errTargetNotFound = errors.New("target declaration not found") -// findInterface looks for the interface declaration in the given directory -// and returns the generic params if exists, a list of the interface's methods, and a list of imports from the file -// where interface type declaration was found -func findInterface(fs *token.FileSet, currentPackage *packages.Package, p *ast.Package, interfaceName string, genericParams genericParams) (genericTypes genericTypes, methods methodsList, imports []*ast.ImportSpec, err error) { - //looking for the source interface declaration in all files in the dir - //while doing this we also store all found type declarations to check if some of the - //interface methods use unexported types - ts, imports, types := iterateFiles(p, interfaceName) +func findTarget(input processInput) (output processOutput, err error) { + ts, imports, types := iterateFiles(input.pacakge, input.targetName) if ts == nil { - return nil, nil, nil, errors.Wrap(errInterfaceNotFound, interfaceName) + return processOutput{}, errors.Wrap(errTargetNotFound, input.targetName) } - genericTypes = genericTypesBuild(ts) + output.imports = imports + output.genericTypes = genericTypesBuild(ts) if it, ok := ts.Type.(*ast.InterfaceType); ok { - methods, err = processInterface(fs, currentPackage, it, types, p.Name, imports, genericTypes, genericParams) + output.methods, err = processInterface(it, targetProcessInput{ + processInput: input, + types: types, + typesPrefix: input.pacakge.Name, + imports: output.imports, + genericsTypes: output.genericTypes, + }) if err != nil { - return nil, nil, nil, err + return processOutput{}, err } } - return genericTypes, methods, imports, err + return } func iterateFiles(p *ast.Package, name string) (selectedType *ast.TypeSpec, imports []*ast.ImportSpec, types []*ast.TypeSpec) { @@ -356,7 +384,7 @@ func typeSpecs(f *ast.File) []*ast.TypeSpec { return result } -func getEmbeddedMethods(t ast.Expr, fs *token.FileSet, currentPackage *packages.Package, types []*ast.TypeSpec, pr typePrinter, typesPrefix string, imports []*ast.ImportSpec, params genericParams) (param genericParam, methods methodsList, err error) { +func getEmbeddedMethods(t ast.Expr, pr typePrinter, input targetProcessInput) (param genericParam, methods methodsList, err error) { switch v := t.(type) { case *ast.SelectorExpr: if x, ok := v.X.(*ast.Ident); ok && x != nil { @@ -366,7 +394,7 @@ func getEmbeddedMethods(t ast.Expr, fs *token.FileSet, currentPackage *packages. } } - methods, err = processSelector(fs, currentPackage, v, imports, params) + methods, err = processSelector(v, input) return case *ast.Ident: @@ -374,22 +402,24 @@ func getEmbeddedMethods(t ast.Expr, fs *token.FileSet, currentPackage *packages. if err != nil { return } - methods, err = processIdent(fs, currentPackage, v, types, typesPrefix, imports, params) + + methods, err = processIdent(v, input) return } return } -func processEmbedded(t ast.Expr, fs *token.FileSet, currentPackage *packages.Package, types []*ast.TypeSpec, pr typePrinter, typesPrefix string, imports []*ast.ImportSpec, genericParams genericParams) (genericParam genericParam, embeddedMethods methodsList, err error) { +func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (genericParam genericParam, embeddedMethods methodsList, err error) { var x ast.Expr var hasGenericsParams bool + var genericParams genericParams switch v := t.(type) { case *ast.IndexExpr: x = v.X hasGenericsParams = true - genericParam, _, err = processEmbedded(v.Index, fs, currentPackage, types, pr, typesPrefix, imports, genericParams) + genericParam, _, err = processEmbedded(v.Index, pr, input) if err != nil { return } @@ -403,7 +433,7 @@ func processEmbedded(t ast.Expr, fs *token.FileSet, currentPackage *packages.Pac if v.Indices != nil { for _, index := range v.Indices { - genericParam, _, err = processEmbedded(index, fs, currentPackage, types, pr, typesPrefix, imports, genericParams) + genericParam, _, err = processEmbedded(index, pr, input) if err != nil { return } @@ -416,25 +446,27 @@ func processEmbedded(t ast.Expr, fs *token.FileSet, currentPackage *packages.Pac x = v } - genericParam, embeddedMethods, err = getEmbeddedMethods(x, fs, currentPackage, types, pr, typesPrefix, imports, genericParam.Params) + input.genericParams = genericParams + genericParam, embeddedMethods, err = getEmbeddedMethods(x, pr, input) if err != nil { return } if hasGenericsParams { - genericParam.Params = genericParam.Params + genericParam.Params = genericParams } + return } -func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec, genericsTypes genericTypes, genericParams genericParams) (methods methodsList, err error) { +func processInterface(it *ast.InterfaceType, targetInput targetProcessInput) (methods methodsList, err error) { if it.Methods == nil { return nil, nil } methods = make(methodsList, len(it.Methods.List)) - pr := printer.New(fs, types, typesPrefix) + pr := printer.New(targetInput.fileSet, targetInput.types, targetInput.typesPrefix) for _, field := range it.Methods.List { var embeddedMethods methodsList @@ -444,14 +476,14 @@ func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *a case *ast.FuncType: var method *Method - method, err = NewMethod(field.Names[0].Name, field, pr, genericsTypes, genericParams) + method, err = NewMethod(field.Names[0].Name, field, pr, targetInput.genericsTypes, targetInput.genericParams) if err == nil { methods[field.Names[0].Name] = *method continue } default: - _, embeddedMethods, err = processEmbedded(v, fs, currentPackage, types, pr, typesPrefix, imports, genericParams) + _, embeddedMethods, err = processEmbedded(v, pr, targetInput) } if err != nil { @@ -467,28 +499,34 @@ func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *a return methods, nil } -func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *ast.SelectorExpr, imports []*ast.ImportSpec, genericParams genericParams) (methodsList, error) { +func processSelector(se *ast.SelectorExpr, input targetProcessInput) (methodsList, error) { selectedName := se.Sel.Name packageSelector := se.X.(*ast.Ident).Name - importPath, err := findImportPathForName(packageSelector, imports, currentPackage) + importPath, err := findImportPathForName(packageSelector, input.imports, input.currentPackage) if err != nil { return nil, errors.Wrapf(err, "unable to find package %s", packageSelector) } - p, ok := currentPackage.Imports[importPath] + p, ok := input.currentPackage.Imports[importPath] if !ok { return nil, fmt.Errorf("unable to find package %s", packageSelector) } - astPkg, err := pkg.AST(fs, p) + astPkg, err := pkg.AST(input.fileSet, p) if err != nil { return nil, errors.Wrap(err, "failed to import package") } - _, methods, _, err := findInterface(fs, p, astPkg, selectedName, genericParams) + output, err := findTarget(processInput{ + fileSet: input.fileSet, + currentPackage: p, + pacakge: astPkg, + targetName: selectedName, + genericParams: input.genericParams, + }) - return methods, err + return output.methods, err } // mergeMethods merges two methods list. Retains overlapping methods from the @@ -512,10 +550,10 @@ func mergeMethods(methods, embeddedMethods methodsList) (methodsList, error) { var errNotAnInterface = errors.New("embedded type is not an interface") -func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec, genericParams genericParams) (methodsList, error) { +func processIdent(i *ast.Ident, input targetProcessInput) (methodsList, error) { var embeddedInterface *ast.InterfaceType var genericsTypes genericTypes - for _, t := range types { + for _, t := range input.types { if t.Name.Name == i.Name { var ok bool embeddedInterface, ok = t.Type.(*ast.InterfaceType) @@ -532,7 +570,8 @@ func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Id return nil, nil } - return processInterface(fs, currentPackage, embeddedInterface, types, typesPrefix, imports, genericsTypes, genericParams) + input.genericsTypes = genericsTypes + return processInterface(embeddedInterface, input) } var errUnknownSelector = errors.New("unknown selector") diff --git a/generator/generator_test.go b/generator/generator_test.go index c12446c8..fb65b5fd 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -430,7 +430,7 @@ func Test_findInterface(t *testing.T) { args: args{p: &ast.Package{}}, wantErr: true, inspectErr: func(err error, t *testing.T) { - assert.Equal(t, errInterfaceNotFound, errors.Cause(err)) + assert.Equal(t, errTargetNotFound, errors.Cause(err)) }, }, { @@ -454,7 +454,7 @@ func Test_findInterface(t *testing.T) { mc := minimock.NewController(t) defer mc.Wait(time.Second) - _, got1, _, err := findInterface(tt.args.fs, nil, tt.args.p, tt.args.interfaceName, tt.args.genericsParams) + _, got1, _, err := findTarget(tt.args.fs, nil, tt.args.p, tt.args.interfaceName, tt.args.genericsParams) assert.Equal(t, tt.want1, got1, "findInterface returned unexpected result") From c5e20e9fe61748a24e0c6184ce8da270a515b063 Mon Sep 17 00:00:00 2001 From: lwfreitas Date: Sat, 18 Feb 2023 01:48:27 +0000 Subject: [PATCH 5/9] More nomenclature fixes --- generator/generator.go | 49 ++++++------ generator/generator_test.go | 148 +++++++++++++++++++++--------------- generator/generics.go | 24 +++--- generator/types.go | 2 +- 4 files changed, 123 insertions(+), 100 deletions(-) diff --git a/generator/generator.go b/generator/generator.go index 45d72b20..6f369280 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -30,8 +30,8 @@ type Generator struct { dstPackage *packages.Package methods methodsList interfaceType string - genericsTypes string - genericsParams string + genericTypes string + genericParams string localPrefix string } @@ -135,17 +135,17 @@ type methodsList map[string]Method type processInput struct { fileSet *token.FileSet currentPackage *packages.Package - pacakge *ast.Package + astPackage *ast.Package targetName string genericParams genericParams } type targetProcessInput struct { processInput - types []*ast.TypeSpec - typesPrefix string - imports []*ast.ImportSpec - genericsTypes genericTypes + types []*ast.TypeSpec + typesPrefix string + imports []*ast.ImportSpec + genericTypes genericTypes } type processOutput struct { @@ -214,13 +214,12 @@ func NewGenerator(options Options) (*Generator, error) { output, err := findTarget(processInput{ fileSet: fs, currentPackage: srcPackage, - pacakge: srcPackageAST, + astPackage: srcPackageAST, targetName: options.InterfaceName, }) if err != nil { return nil, errors.Wrap(err, "failed to parse interface declaration") } - genericsTypes, genericsParams := output.genericTypes.buildVars() if len(output.methods) == 0 { return nil, errEmptyInterface @@ -234,6 +233,8 @@ func NewGenerator(options Options) (*Generator, error) { options.Imports = append(options.Imports, makeImports(output.imports)...) + genericTypes, genericParams := output.genericTypes.buildVars() + return &Generator{ Options: options, headerTemplate: headerTemplate, @@ -241,8 +242,8 @@ func NewGenerator(options Options) (*Generator, error) { srcPackage: srcPackage, dstPackage: dstPackage, interfaceType: interfaceType, - genericsTypes: genericsTypes, - genericsParams: genericsParams, + genericTypes: genericTypes, + genericParams: genericParams, methods: output.methods, localPrefix: options.LocalPrefix, }, nil @@ -302,8 +303,8 @@ func (g Generator) Generate(w io.Writer) error { Interface: TemplateInputInterface{ Name: g.Options.InterfaceName, Generics: TemplateInputGenerics{ - Types: g.genericsTypes, - Params: g.genericsParams, + Types: g.genericTypes, + Params: g.genericParams, }, Type: g.interfaceType, Methods: g.methods, @@ -328,21 +329,21 @@ func (g Generator) Generate(w io.Writer) error { var errTargetNotFound = errors.New("target declaration not found") func findTarget(input processInput) (output processOutput, err error) { - ts, imports, types := iterateFiles(input.pacakge, input.targetName) + ts, imports, types := iterateFiles(input.astPackage, input.targetName) if ts == nil { return processOutput{}, errors.Wrap(errTargetNotFound, input.targetName) } output.imports = imports - output.genericTypes = genericTypesBuild(ts) + output.genericTypes = buildGenericTypesFromSpec(ts) if it, ok := ts.Type.(*ast.InterfaceType); ok { output.methods, err = processInterface(it, targetProcessInput{ - processInput: input, - types: types, - typesPrefix: input.pacakge.Name, - imports: output.imports, - genericsTypes: output.genericTypes, + processInput: input, + types: types, + typesPrefix: input.astPackage.Name, + imports: output.imports, + genericTypes: output.genericTypes, }) if err != nil { return processOutput{}, err @@ -476,7 +477,7 @@ func processInterface(it *ast.InterfaceType, targetInput targetProcessInput) (me case *ast.FuncType: var method *Method - method, err = NewMethod(field.Names[0].Name, field, pr, targetInput.genericsTypes, targetInput.genericParams) + method, err = NewMethod(field.Names[0].Name, field, pr, targetInput.genericTypes, targetInput.genericParams) if err == nil { methods[field.Names[0].Name] = *method continue @@ -521,7 +522,7 @@ func processSelector(se *ast.SelectorExpr, input targetProcessInput) (methodsLis output, err := findTarget(processInput{ fileSet: input.fileSet, currentPackage: p, - pacakge: astPkg, + astPackage: astPkg, targetName: selectedName, genericParams: input.genericParams, }) @@ -561,7 +562,7 @@ func processIdent(i *ast.Ident, input targetProcessInput) (methodsList, error) { return nil, errors.Wrap(errNotAnInterface, t.Name.Name) } - genericsTypes = genericTypesBuild(t) + genericsTypes = buildGenericTypesFromSpec(t) break } } @@ -570,7 +571,7 @@ func processIdent(i *ast.Ident, input targetProcessInput) (methodsList, error) { return nil, nil } - input.genericsTypes = genericsTypes + input.genericTypes = genericsTypes return processInterface(embeddedInterface, input) } diff --git a/generator/generator_test.go b/generator/generator_test.go index fb65b5fd..c7310a29 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -111,12 +111,8 @@ func Test_findImportPathForName(t *testing.T) { func Test_processIdent(t *testing.T) { type args struct { - fs *token.FileSet - i *ast.Ident - types []*ast.TypeSpec - typesPrefix string - imports []*ast.ImportSpec - genericsParams genericsParams + i *ast.Ident + input targetProcessInput } tests := []struct { name string @@ -129,8 +125,10 @@ func Test_processIdent(t *testing.T) { { name: "not an interface", args: args{ - i: &ast.Ident{Name: "name"}, - types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.StructType{}}}, + i: &ast.Ident{Name: "name"}, + input: targetProcessInput{ + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.StructType{}}}, + }, }, wantErr: true, inspectErr: func(err error, t *testing.T) { @@ -140,8 +138,10 @@ func Test_processIdent(t *testing.T) { { name: "embedded interface found", args: args{ - i: &ast.Ident{Name: "name"}, - types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.InterfaceType{}}}, + i: &ast.Ident{Name: "name"}, + input: targetProcessInput{ + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.InterfaceType{}}}, + }, }, wantErr: false, }, @@ -152,7 +152,7 @@ func Test_processIdent(t *testing.T) { mc := minimock.NewController(t) defer mc.Wait(time.Second) - got1, err := processIdent(tt.args.fs, nil, tt.args.i, tt.args.types, tt.args.typesPrefix, tt.args.imports, tt.args.genericsParams) + got1, err := processIdent(tt.args.i, tt.args.input) assert.Equal(t, tt.want1, got1, "processIdent returned unexpected result") @@ -244,11 +244,8 @@ func Test_mergeMethods(t *testing.T) { func Test_processSelector(t *testing.T) { type args struct { - fs *token.FileSet - cp *packages.Package - se *ast.SelectorExpr - imports []*ast.ImportSpec - genericsParams genericsParams + se *ast.SelectorExpr + input targetProcessInput } tests := []struct { name string @@ -262,36 +259,47 @@ func Test_processSelector(t *testing.T) { name: "import with name not found", args: args{ se: &ast.SelectorExpr{X: &ast.Ident{Name: "unknown"}, Sel: &ast.Ident{Name: "unknown"}}, - cp: &packages.Package{Imports: nil}, + input: targetProcessInput{ + processInput: processInput{ + currentPackage: &packages.Package{Imports: nil}, + }, + }, }, wantErr: true, }, { name: "import not found", args: args{ - se: &ast.SelectorExpr{X: &ast.Ident{Name: "unknownpackage"}, Sel: &ast.Ident{Name: "Unknown"}}, - imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "unknown_path"}}}, - cp: &packages.Package{Imports: nil}, + se: &ast.SelectorExpr{X: &ast.Ident{Name: "unknownpackage"}, Sel: &ast.Ident{Name: "Unknown"}}, + input: targetProcessInput{ + processInput: processInput{ + currentPackage: &packages.Package{Imports: nil}, + }, + imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "unknown_path"}}}, + }, }, wantErr: true, }, { name: "import failed", args: args{ - se: &ast.SelectorExpr{X: &ast.Ident{Name: "io"}, Sel: &ast.Ident{Name: "UnknownInterface"}}, - imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "io"}}}, - fs: token.NewFileSet(), - cp: &packages.Package{Imports: map[string]*packages.Package{ - "io": {}, - }}, + se: &ast.SelectorExpr{X: &ast.Ident{Name: "io"}, Sel: &ast.Ident{Name: "UnknownInterface"}}, + input: targetProcessInput{ + processInput: processInput{ + fileSet: token.NewFileSet(), + currentPackage: &packages.Package{Imports: map[string]*packages.Package{ + "io": {}, + }}, + }, + imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "io"}}}, + }, }, wantErr: true, }, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got1, err := processSelector(tt.args.fs, tt.args.cp, tt.args.se, tt.args.imports, tt.args.genericsParams) + got1, err := processSelector(tt.args.se, tt.args.input) assert.Equal(t, tt.want1, got1, "processSelector returned unexpected result") @@ -309,14 +317,8 @@ func Test_processSelector(t *testing.T) { func Test_processInterface(t *testing.T) { type args struct { - fs *token.FileSet - cp *packages.Package - it *ast.InterfaceType - types []*ast.TypeSpec - typesPrefix string - imports []*ast.ImportSpec - genericsTypes genericTypes - genericsParams genericsParams + it *ast.InterfaceType + targetInput targetProcessInput } tests := []struct { name string @@ -329,8 +331,12 @@ func Test_processInterface(t *testing.T) { { name: "func type", args: args{ - fs: token.NewFileSet(), it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "methodName"}}, Type: &ast.FuncType{Params: &ast.FieldList{}}}}}}, + targetInput: targetProcessInput{ + processInput: processInput{ + fileSet: token.NewFileSet(), + }, + }, }, want1: methodsList{"methodName": Method{Name: "methodName", Params: []Param{}}}, wantErr: false, @@ -338,27 +344,35 @@ func Test_processInterface(t *testing.T) { { name: "selector expression", args: args{ - fs: token.NewFileSet(), - cp: &packages.Package{Imports: nil}, it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{ { Names: []*ast.Ident{{Name: "methodName"}}, Type: &ast.SelectorExpr{X: &ast.Ident{Name: "unknown"}, Sel: &ast.Ident{Name: "Interface"}}, }, }}}, + targetInput: targetProcessInput{ + processInput: processInput{ + fileSet: token.NewFileSet(), + currentPackage: &packages.Package{Imports: nil}, + }, + }, }, wantErr: true, }, { name: "identifier with embedded methods", args: args{ - fs: token.NewFileSet(), it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{ { Type: &ast.Ident{Name: "Embedded"}, }, }}}, - types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Embedded"}, Type: &ast.InterfaceType{}}}, + targetInput: targetProcessInput{ + processInput: processInput{ + fileSet: token.NewFileSet(), + }, + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Embedded"}, Type: &ast.InterfaceType{}}}, + }, }, want1: methodsList{}, wantErr: false, @@ -366,13 +380,17 @@ func Test_processInterface(t *testing.T) { { name: "index list expression with identifier", args: args{ - fs: token.NewFileSet(), it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{ { Type: &ast.IndexListExpr{X: &ast.Ident{Name: "Embedded"}}, }, }}}, - types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Embedded"}, Type: &ast.InterfaceType{}}}, + targetInput: targetProcessInput{ + processInput: processInput{ + fileSet: token.NewFileSet(), + }, + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Embedded"}, Type: &ast.InterfaceType{}}}, + }, }, want1: methodsList{}, wantErr: false, @@ -381,7 +399,7 @@ func Test_processInterface(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got1, err := processInterface(tt.args.fs, tt.args.cp, tt.args.it, tt.args.types, tt.args.typesPrefix, tt.args.imports, tt.args.genericsTypes, tt.args.genericsParams) + got1, err := processInterface(tt.args.it, tt.args.targetInput) assert.Equal(t, tt.want1, got1, "processInterface returned unexpected result") @@ -410,12 +428,9 @@ func Test_typeSpecs(t *testing.T) { assert.Equal(t, expected, specs, "typeSpecs returned unexpected result") } -func Test_findInterface(t *testing.T) { +func Test_findTarget(t *testing.T) { type args struct { - fs *token.FileSet - p *ast.Package - interfaceName string - genericsParams genericsParams + input processInput } tests := []struct { name string @@ -426,8 +441,12 @@ func Test_findInterface(t *testing.T) { inspectErr func(err error, t *testing.T) }{ { - name: "not found", - args: args{p: &ast.Package{}}, + name: "not found", + args: args{ + input: processInput{ + astPackage: &ast.Package{}, + }, + }, wantErr: true, inspectErr: func(err error, t *testing.T) { assert.Equal(t, errTargetNotFound, errors.Cause(err)) @@ -436,14 +455,16 @@ func Test_findInterface(t *testing.T) { { name: "found", args: args{ - p: &ast.Package{Files: map[string]*ast.File{ - "file.go": { - Decls: []ast.Decl{&ast.GenDecl{Tok: token.TYPE, Specs: []ast.Spec{&ast.TypeSpec{ - Name: &ast.Ident{Name: "Interface"}, - Type: &ast.InterfaceType{}, - }}}}, - }}}, - interfaceName: "Interface", + input: processInput{ + astPackage: &ast.Package{Files: map[string]*ast.File{ + "file.go": { + Decls: []ast.Decl{&ast.GenDecl{Tok: token.TYPE, Specs: []ast.Spec{&ast.TypeSpec{ + Name: &ast.Ident{Name: "Interface"}, + Type: &ast.InterfaceType{}, + }}}}, + }}}, + targetName: "Interface", + }, }, wantErr: false, }, @@ -454,17 +475,18 @@ func Test_findInterface(t *testing.T) { mc := minimock.NewController(t) defer mc.Wait(time.Second) - _, got1, _, err := findTarget(tt.args.fs, nil, tt.args.p, tt.args.interfaceName, tt.args.genericsParams) + got1, err := findTarget(tt.args.input) - assert.Equal(t, tt.want1, got1, "findInterface returned unexpected result") + assert.Equal(t, tt.want1, got1.methods, "findInterface returned unexpected result") if tt.wantErr { if assert.Error(t, err) && tt.inspectErr != nil { tt.inspectErr(err, t) } - } else { - assert.NoError(t, err) + return } + + assert.NoError(t, err) }) } } diff --git a/generator/generics.go b/generator/generics.go index 512e10ee..bc163f9a 100644 --- a/generator/generics.go +++ b/generator/generics.go @@ -6,10 +6,10 @@ import ( ) const ( - genericsSeparator = ", " + genericSeparator = ", " - genericsSquareBracketStart = "[" - genericsSquareBracketEnd = "]" + genericSquareBracketStart = "[" + genericSquareBracketEnd = "]" ) // TemplateInputGenerics subset of generics interface information used for template generation @@ -35,7 +35,7 @@ func (g genericParam) String() string { subParamNames = append(subParamNames, subParam.String()) } if len(g.Params) > 0 { - name += genericsSquareBracketStart + strings.Join(subParamNames, genericsSeparator) + genericsSquareBracketEnd + name += genericSquareBracketStart + strings.Join(subParamNames, genericSeparator) + genericSquareBracketEnd } return name } @@ -47,9 +47,9 @@ type genericType struct { Names []string } -func genericsWithBracketsBuild(t string) string { +func buildGenericsWithBrackets(t string) string { if t != "" { - t = genericsSquareBracketStart + t + genericsSquareBracketEnd + t = genericSquareBracketStart + t + genericSquareBracketEnd } return t } @@ -64,20 +64,20 @@ func (g genericTypes) buildVars() (string, string) { for _, name := range genType.Names { paramsByType += paramsByTypeSep + name params += paramsSep + name - paramsSep = genericsSeparator - paramsByTypeSep = genericsSeparator + paramsSep = genericSeparator + paramsByTypeSep = genericSeparator } if paramsByType != "" { types += typesSep + paramsByType + " " + genType.Type - typesSep = genericsSeparator + typesSep = genericSeparator } } - return genericsWithBracketsBuild(types), genericsWithBracketsBuild(params) + return buildGenericsWithBrackets(types), buildGenericsWithBrackets(params) } -func genericTypesBuild(ts *ast.TypeSpec) (types genericTypes) { +func buildGenericTypesFromSpec(ts *ast.TypeSpec) (types genericTypes) { if ts.TypeParams != nil { for _, param := range ts.TypeParams.List { if param != nil { @@ -99,7 +99,7 @@ func genericTypesBuild(ts *ast.TypeSpec) (types genericTypes) { return } -func genericBuildParamString(typeStr string, genericTypes genericTypes, genericParams genericParams) string { +func buildGenericParamsString(typeStr string, genericTypes genericTypes, genericParams genericParams) string { i := 0 for _, genType := range genericTypes { for _, name := range genType.Names { diff --git a/generator/types.go b/generator/types.go index 0b7acb73..654f6acd 100644 --- a/generator/types.go +++ b/generator/types.go @@ -139,7 +139,7 @@ func NewParam(name string, fi *ast.Field, usedNames map[string]bool, printer typ return nil, err } - typeStr = genericBuildParamString(typeStr, genericTypes, genericParams) + typeStr = buildGenericParamsString(typeStr, genericTypes, genericParams) _, variadic := typ.(*ast.Ellipsis) p := &Param{ From 314faa57e71264f784926aaeaa739a51fc493fff Mon Sep 17 00:00:00 2001 From: lwfreitas Date: Sat, 18 Feb 2023 17:36:44 +0000 Subject: [PATCH 6/9] Unit tests for generics --- generator/generics_test.go | 169 +++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 generator/generics_test.go diff --git a/generator/generics_test.go b/generator/generics_test.go new file mode 100644 index 00000000..0c8973df --- /dev/null +++ b/generator/generics_test.go @@ -0,0 +1,169 @@ +package generator + +import ( + "go/ast" + "reflect" + "testing" +) + +func Test_genericParam_String(t *testing.T) { + type fields struct { + Name string + Params genericParams + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "struct with generic params", + fields: fields{ + Name: "somepkg.SomeGenericStruct", + Params: genericParams{ + { + Name: "int", + }, + { + Name: "string", + }, + }, + }, + want: "somepkg.SomeGenericStruct[int, string]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := genericParam{ + Name: tt.fields.Name, + Params: tt.fields.Params, + } + if got := g.String(); got != tt.want { + t.Errorf("genericParam.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_genericTypes_buildVars(t *testing.T) { + tests := []struct { + name string + g genericTypes + want string + want1 string + }{ + { + name: "[T any]", + g: genericTypes{ + { + Names: []string{"T"}, + Type: "any", + }, + }, + want: "[T any]", + want1: "[T]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := tt.g.buildVars() + if got != tt.want { + t.Errorf("genericTypes.buildVars() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("genericTypes.buildVars() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func Test_buildGenericTypesFromSpec(t *testing.T) { + type args struct { + ts *ast.TypeSpec + } + tests := []struct { + name string + args args + wantTypes genericTypes + }{ + { + name: "", + args: args{ + ts: &ast.TypeSpec{ + TypeParams: &ast.FieldList{ + List: []*ast.Field{ + { + Type: &ast.Ident{ + Name: "any", + }, + Names: []*ast.Ident{ + { + Name: "I", + }, + { + Name: "O", + }, + }, + }, + }, + }, + }, + }, + wantTypes: genericTypes{ + { + Type: "any", + Names: []string{"I", "O"}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotTypes := buildGenericTypesFromSpec(tt.args.ts); !reflect.DeepEqual(gotTypes, tt.wantTypes) { + t.Errorf("buildGenericTypesFromSpec() = %v, want %v", gotTypes, tt.wantTypes) + } + }) + } +} + +func Test_buildGenericParamsString(t *testing.T) { + genTypes := genericTypes{{Names: []string{"A", "B"}, Type: "any"}} + genParams := genericParams{{Name: "string"}, {Name: "int"}} + + type args struct { + typeStr string + genericTypes genericTypes + genericParams genericParams + } + tests := []struct { + name string + args args + want string + }{ + { + name: "replace A by string", + args: args{ + typeStr: "A", + genericTypes: genTypes, + genericParams: genParams, + }, + want: "string", + }, + { + name: "replace B by int", + args: args{ + typeStr: "B", + genericTypes: genTypes, + genericParams: genParams, + }, + want: "int", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := buildGenericParamsString(tt.args.typeStr, tt.args.genericTypes, tt.args.genericParams); got != tt.want { + t.Errorf("buildGenericParamsString() = %v, want %v", got, tt.want) + } + }) + } +} From a95ed465d22b17a12665dfad6960097e7f209b70 Mon Sep 17 00:00:00 2001 From: lwfreitas Date: Sat, 18 Feb 2023 17:47:26 +0000 Subject: [PATCH 7/9] add two more sub tests in Test_processInterface --- generator/generator_test.go | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/generator/generator_test.go b/generator/generator_test.go index c7310a29..b6ab3467 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -377,12 +377,32 @@ func Test_processInterface(t *testing.T) { want1: methodsList{}, wantErr: false, }, + { + name: "index expression with identifier", + args: args{ + it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{ + { + Type: &ast.IndexExpr{X: &ast.Ident{Name: "Embedded"}, Index: &ast.Ident{Name: "Embedded_2"}}, + }, + }}}, + targetInput: targetProcessInput{ + processInput: processInput{ + fileSet: token.NewFileSet(), + }, + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Embedded"}, Type: &ast.InterfaceType{}}}, + }, + }, + want1: methodsList{}, + wantErr: false, + }, { name: "index list expression with identifier", args: args{ it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{ { - Type: &ast.IndexListExpr{X: &ast.Ident{Name: "Embedded"}}, + Type: &ast.IndexListExpr{X: &ast.Ident{Name: "Embedded"}, Indices: []ast.Expr{ + &ast.Ident{Name: "Embedded_2"}, + }}, }, }}}, targetInput: targetProcessInput{ From 5544861f1726c29b6f25cb680e61a32d18a6f12f Mon Sep 17 00:00:00 2001 From: lwfreitas Date: Sat, 18 Feb 2023 22:26:48 +0000 Subject: [PATCH 8/9] Fix error and improve getEmbeddedMethods --- generator/generator.go | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/generator/generator.go b/generator/generator.go index 6f369280..428a7dcb 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -386,24 +386,17 @@ func typeSpecs(f *ast.File) []*ast.TypeSpec { } func getEmbeddedMethods(t ast.Expr, pr typePrinter, input targetProcessInput) (param genericParam, methods methodsList, err error) { + param.Name, err = pr.PrintType(t) + if err != nil { + return + } + switch v := t.(type) { case *ast.SelectorExpr: - if x, ok := v.X.(*ast.Ident); ok && x != nil { - param.Name, err = pr.PrintType(x) - if err != nil { - return - } - } - methods, err = processSelector(v, input) return case *ast.Ident: - param.Name, err = pr.PrintType(v) - if err != nil { - return - } - methods, err = processIdent(v, input) return } From 3bd0e554c7d2bcab445c916868cafa83ff275075 Mon Sep 17 00:00:00 2001 From: lwfreitas Date: Sun, 19 Feb 2023 01:18:34 +0000 Subject: [PATCH 9/9] Fix lint error --- generator/generics_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generator/generics_test.go b/generator/generics_test.go index 0c8973df..e10c8035 100644 --- a/generator/generics_test.go +++ b/generator/generics_test.go @@ -87,7 +87,7 @@ func Test_buildGenericTypesFromSpec(t *testing.T) { wantTypes genericTypes }{ { - name: "", + name: "build generic types any from spec", args: args{ ts: &ast.TypeSpec{ TypeParams: &ast.FieldList{ @@ -111,7 +111,7 @@ func Test_buildGenericTypesFromSpec(t *testing.T) { }, wantTypes: genericTypes{ { - Type: "any", + Type: "any", Names: []string{"I", "O"}, }, },