Skip to content

Commit

Permalink
Merge pull request #60 from NoGambiNoBugs/feature/generics
Browse files Browse the repository at this point in the history
Feature Generics
  • Loading branch information
hexdigest authored Feb 19, 2023
2 parents 3c184ed + 3bd0e55 commit 67a1cd8
Show file tree
Hide file tree
Showing 5 changed files with 581 additions and 136 deletions.
232 changes: 177 additions & 55 deletions generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type Generator struct {
dstPackage *packages.Package
methods methodsList
interfaceType string
genericTypes string
genericParams string
localPrefix string
}

Expand Down Expand Up @@ -85,12 +87,12 @@ type TemplateInputInterface struct {
Name string
// Type of the interface, with package name qualifier (e.g. sort.Interface)
Type string
// Generics of the interface when using generics
Generics TemplateInputGenerics
// Methods name keyed map of method information
Methods map[string]Method
}

type methodsList map[string]Method

// Options of the NewGenerator constructor
type Options struct {
//InterfaceName is a name of interface type
Expand Down Expand Up @@ -128,6 +130,30 @@ type Options struct {
LocalPrefix string
}

type methodsList map[string]Method

type processInput struct {
fileSet *token.FileSet
currentPackage *packages.Package
astPackage *ast.Package
targetName string
genericParams genericParams
}

type targetProcessInput struct {
processInput
types []*ast.TypeSpec
typesPrefix string
imports []*ast.ImportSpec
genericTypes genericTypes
}

type processOutput struct {
genericTypes genericTypes
methods methodsList
imports []*ast.ImportSpec
}

var errEmptyInterface = errors.New("interface has no methods")
var errUnexportedMethod = errors.New("unexported method")

Expand Down Expand Up @@ -185,22 +211,29 @@ func NewGenerator(options Options) (*Generator, error) {
options.Imports = append(options.Imports, `"`+srcPackage.PkgPath+`"`)
}

methods, imports, err := findInterface(fs, srcPackage, srcPackageAST, options.InterfaceName)
output, err := findTarget(processInput{
fileSet: fs,
currentPackage: srcPackage,
astPackage: srcPackageAST,
targetName: options.InterfaceName,
})
if err != nil {
return nil, errors.Wrap(err, "failed to parse interface declaration")
}

if len(methods) == 0 {
if len(output.methods) == 0 {
return nil, errEmptyInterface
}

for _, m := range methods {
for _, m := range output.methods {
if srcPackageAST.Name != "" && []rune(m.Name)[0] == []rune(strings.ToLower(m.Name))[0] {
return nil, errors.Wrap(errUnexportedMethod, m.Name)
}
}

options.Imports = append(options.Imports, makeImports(imports)...)
options.Imports = append(options.Imports, makeImports(output.imports)...)

genericTypes, genericParams := output.genericTypes.buildVars()

return &Generator{
Options: options,
Expand All @@ -209,7 +242,9 @@ func NewGenerator(options Options) (*Generator, error) {
srcPackage: srcPackage,
dstPackage: dstPackage,
interfaceType: interfaceType,
methods: methods,
genericTypes: genericTypes,
genericParams: genericParams,
methods: output.methods,
localPrefix: options.LocalPrefix,
}, nil
}
Expand Down Expand Up @@ -266,7 +301,11 @@ func (g Generator) Generate(w io.Writer) error {

err = g.bodyTemplate.Execute(buf, TemplateInputs{
Interface: TemplateInputInterface{
Name: g.Options.InterfaceName,
Name: g.Options.InterfaceName,
Generics: TemplateInputGenerics{
Types: g.genericTypes,
Params: g.genericParams,
},
Type: g.interfaceType,
Methods: g.methods,
},
Expand All @@ -287,43 +326,47 @@ func (g Generator) Generate(w io.Writer) error {
return err
}

var errInterfaceNotFound = errors.New("interface type declaration not found")
var errTargetNotFound = errors.New("target declaration not found")

// findInterface looks for the interface declaration in the given directory
// and returns a list of the interface's methods and a list of imports from the file
// where interface type declaration was found
func findInterface(fs *token.FileSet, currentPackage *packages.Package, p *ast.Package, interfaceName string) (methods methodsList, imports []*ast.ImportSpec, err error) {
var found bool
var types []*ast.TypeSpec
var it *ast.InterfaceType
func findTarget(input processInput) (output processOutput, err error) {
ts, imports, types := iterateFiles(input.astPackage, input.targetName)
if ts == nil {
return processOutput{}, errors.Wrap(errTargetNotFound, input.targetName)
}

//looking for the source interface declaration in all files in the dir
//while doing this we also store all found type declarations to check if some of the
//interface methods use unexported types
for _, f := range p.Files {
for _, ts := range typeSpecs(f) {
types = append(types, ts)
output.imports = imports
output.genericTypes = buildGenericTypesFromSpec(ts)

if i, ok := ts.Type.(*ast.InterfaceType); ok {
if ts.Name.Name == interfaceName && !found {
imports = f.Imports
it = i
found = true
}
}
if it, ok := ts.Type.(*ast.InterfaceType); ok {
output.methods, err = processInterface(it, targetProcessInput{
processInput: input,
types: types,
typesPrefix: input.astPackage.Name,
imports: output.imports,
genericTypes: output.genericTypes,
})
if err != nil {
return processOutput{}, err
}
}

if !found {
return nil, nil, errors.Wrap(errInterfaceNotFound, interfaceName)
}
return
}

methods, err = processInterface(fs, currentPackage, it, types, p.Name, imports)
if err != nil {
return nil, nil, err
func iterateFiles(p *ast.Package, name string) (selectedType *ast.TypeSpec, imports []*ast.ImportSpec, types []*ast.TypeSpec) {
for _, f := range p.Files {
if f != nil {
for _, ts := range typeSpecs(f) {
types = append(types, ts)
if ts.Name.Name == name {
selectedType = ts
imports = f.Imports
return
}
}
}
}

return methods, imports, err
return
}

func typeSpecs(f *ast.File) []*ast.TypeSpec {
Expand All @@ -342,29 +385,99 @@ func typeSpecs(f *ast.File) []*ast.TypeSpec {
return result
}

func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methods methodsList, err error) {
func getEmbeddedMethods(t ast.Expr, pr typePrinter, input targetProcessInput) (param genericParam, methods methodsList, err error) {
param.Name, err = pr.PrintType(t)
if err != nil {
return
}

switch v := t.(type) {
case *ast.SelectorExpr:
methods, err = processSelector(v, input)
return

case *ast.Ident:
methods, err = processIdent(v, input)
return
}
return
}

func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (genericParam genericParam, embeddedMethods methodsList, err error) {
var x ast.Expr
var hasGenericsParams bool
var genericParams genericParams

switch v := t.(type) {
case *ast.IndexExpr:
x = v.X
hasGenericsParams = true

genericParam, _, err = processEmbedded(v.Index, pr, input)
if err != nil {
return
}
if genericParam.Name != "" {
genericParams = append(genericParams, genericParam)
}

case *ast.IndexListExpr:
x = v.X
hasGenericsParams = true

if v.Indices != nil {
for _, index := range v.Indices {
genericParam, _, err = processEmbedded(index, pr, input)
if err != nil {
return
}
if genericParam.Name != "" {
genericParams = append(genericParams, genericParam)
}
}
}
default:
x = v
}

input.genericParams = genericParams
genericParam, embeddedMethods, err = getEmbeddedMethods(x, pr, input)
if err != nil {
return
}

if hasGenericsParams {
genericParam.Params = genericParams
}

return
}

func processInterface(it *ast.InterfaceType, targetInput targetProcessInput) (methods methodsList, err error) {
if it.Methods == nil {
return nil, nil
}

methods = make(methodsList, len(it.Methods.List))

pr := printer.New(targetInput.fileSet, targetInput.types, targetInput.typesPrefix)

for _, field := range it.Methods.List {
var embeddedMethods methodsList
var err error

switch v := field.Type.(type) {
case *ast.FuncType:
var method *Method
method, err = NewMethod(field.Names[0].Name, field, printer.New(fs, types, typesPrefix))

method, err = NewMethod(field.Names[0].Name, field, pr, targetInput.genericTypes, targetInput.genericParams)
if err == nil {
methods[field.Names[0].Name] = *method
continue
}
case *ast.SelectorExpr:
embeddedMethods, err = processSelector(fs, currentPackage, v, imports)
case *ast.Ident:
embeddedMethods, err = processIdent(fs, currentPackage, v, types, typesPrefix, imports)

default:
_, embeddedMethods, err = processEmbedded(v, pr, targetInput)
}

if err != nil {
Expand All @@ -380,28 +493,34 @@ func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *a
return methods, nil
}

func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *ast.SelectorExpr, imports []*ast.ImportSpec) (methodsList, error) {
interfaceName := se.Sel.Name
func processSelector(se *ast.SelectorExpr, input targetProcessInput) (methodsList, error) {
selectedName := se.Sel.Name
packageSelector := se.X.(*ast.Ident).Name

importPath, err := findImportPathForName(packageSelector, imports, currentPackage)
importPath, err := findImportPathForName(packageSelector, input.imports, input.currentPackage)
if err != nil {
return nil, errors.Wrapf(err, "unable to find package %s", packageSelector)
}

p, ok := currentPackage.Imports[importPath]
p, ok := input.currentPackage.Imports[importPath]
if !ok {
return nil, fmt.Errorf("unable to find package %s", packageSelector)
}

astPkg, err := pkg.AST(fs, p)
astPkg, err := pkg.AST(input.fileSet, p)
if err != nil {
return nil, errors.Wrap(err, "failed to import package")
}

methods, _, err := findInterface(fs, p, astPkg, interfaceName)
output, err := findTarget(processInput{
fileSet: input.fileSet,
currentPackage: p,
astPackage: astPkg,
targetName: selectedName,
genericParams: input.genericParams,
})

return methods, err
return output.methods, err
}

// mergeMethods merges two methods list. Retains overlapping methods from the
Expand All @@ -423,27 +542,30 @@ func mergeMethods(methods, embeddedMethods methodsList) (methodsList, error) {
return result, nil
}

var errEmbeddedInterfaceNotFound = errors.New("embedded interface not found")
var errNotAnInterface = errors.New("embedded type is not an interface")

func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methodsList, error) {
func processIdent(i *ast.Ident, input targetProcessInput) (methodsList, error) {
var embeddedInterface *ast.InterfaceType
for _, t := range types {
var genericsTypes genericTypes
for _, t := range input.types {
if t.Name.Name == i.Name {
var ok bool
embeddedInterface, ok = t.Type.(*ast.InterfaceType)
if !ok {
return nil, errors.Wrap(errNotAnInterface, t.Name.Name)
}

genericsTypes = buildGenericTypesFromSpec(t)
break
}
}

if embeddedInterface == nil {
return nil, errors.Wrap(errEmbeddedInterfaceNotFound, i.Name)
return nil, nil
}

return processInterface(fs, currentPackage, embeddedInterface, types, typesPrefix, imports)
input.genericTypes = genericsTypes
return processInterface(embeddedInterface, input)
}

var errUnknownSelector = errors.New("unknown selector")
Expand Down
Loading

0 comments on commit 67a1cd8

Please sign in to comment.