diff --git a/cmd/sqlc-restruct/separate_interface.go b/cmd/sqlc-restruct/separate_interface.go index efede67..e8ceca9 100644 --- a/cmd/sqlc-restruct/separate_interface.go +++ b/cmd/sqlc-restruct/separate_interface.go @@ -12,19 +12,31 @@ var SeparateInterfaceCommand = &cli.Command{ Flags: []cli.Flag{ &cli.StringFlag{ Name: "iface-pkg-name", - Usage: "The package name where the separated models and Querier will be located.", + Usage: "The package name where the separated Querier will be located.", Required: true, }, &cli.StringFlag{ Name: "iface-pkg-url", - Usage: "The package URL where the separated models and Querier will be located (e.g. \"github.com///path/to/pkg\").", + Usage: "The package URL where the separated Querier will be located. (e.g. \"github.com///path/to/pkg\")", Required: true, }, &cli.StringFlag{ Name: "iface-dir", - Usage: "The directory path where the separated models and Querier will be located.", + Usage: "The directory path where the separated Querier will be located.", Required: true, }, + &cli.StringFlag{ + Name: "models-pkg-name", + Usage: "The package name where the separated models will be located. (default: --models-pkg-name value)", + }, + &cli.StringFlag{ + Name: "models-pkg-url", + Usage: "The package URL where the separated models will be located. (default: --models-pkg-url value)", + }, + &cli.StringFlag{ + Name: "models-dir", + Usage: "The directory path where the separated models will be located. (default: --iface-dir value)", + }, &cli.StringFlag{ Name: "impl-dir", Usage: "The original directory where the sqlc-generated code is located.", @@ -47,10 +59,30 @@ var SeparateInterfaceCommand = &cli.Command{ }, }, Action: func(c *cli.Context) error { + iPkgName := c.String("iface-pkg-name") + iPkgURL := c.String("iface-pkg-url") + iDir := c.String("iface-dir") + + mPkgName := c.String("models-pkg-name") + if mPkgName == "" { + mPkgName = iPkgName + } + mPkgURL := c.String("models-pkg-url") + if mPkgURL == "" { + mPkgURL = iPkgURL + } + mDir := c.String("models-dir") + if mDir == "" { + mDir = iDir + } + return separateinterface.Action(c.Context, separateinterface.ActionInput{ - IfacePkgName: c.String("iface-pkg-name"), - IfacePkgURL: c.String("iface-pkg-url"), - IfaceDir: c.String("iface-dir"), + IfacePkgName: iPkgName, + IfacePkgURL: iPkgURL, + IfaceDir: iDir, + ModelsPkgName: mPkgName, + ModelsPkgURL: mPkgURL, + ModelsDir: mDir, ImplDir: c.String("impl-dir"), ImplSQLSuffix: c.String("impl-sql-suffix"), ModelsFileName: c.String("models-file-name"), diff --git a/example/domain/models/.gitignore b/example/domain/models/.gitignore new file mode 100644 index 0000000..120f485 --- /dev/null +++ b/example/domain/models/.gitignore @@ -0,0 +1,2 @@ +* +!/.gitignore diff --git a/example/infra/db/db.go b/example/infra/db/db.go index 016ac7e..f2d8c81 100644 --- a/example/infra/db/db.go +++ b/example/infra/db/db.go @@ -1,4 +1,4 @@ package db //go:generate docker compose run --rm -T sqlc generate -//go:generate sqlc-restruct separate-interface --models-file-name=models.gen.go --querier-file-name=querier.gen.go --iface-dir=../../domain/repos --iface-pkg-name=repos --iface-pkg-url=github.com/mpyw/sqlc-restruct/example/domain/repos +//go:generate sqlc-restruct separate-interface --models-file-name=models.gen.go --querier-file-name=querier.gen.go --iface-dir=../../domain/repos --iface-pkg-name=repos --iface-pkg-url=github.com/mpyw/sqlc-restruct/example/domain/repos --models-dir=../../domain/models --models-pkg-name=models --models-pkg-url=github.com/mpyw/sqlc-restruct/example/domain/models diff --git a/example/migrations/20230101000000-create_users_table.sql b/example/migrations/20230101000000-create_users_table.sql index e565969..edd9212 100644 --- a/example/migrations/20230101000000-create_users_table.sql +++ b/example/migrations/20230101000000-create_users_table.sql @@ -1,8 +1,10 @@ -- +migrate Up +CREATE TYPE user_status AS ENUM ('active', 'inactive'); CREATE TABLE users( id uuid PRIMARY KEY DEFAULT gen_random_uuid(), email text NOT NULL UNIQUE, name text NOT NULL, + status user_status NOT NULL default 'active', created_at timestamptz NOT NULL DEFAULT current_timestamp, updated_at timestamptz NOT NULL DEFAULT current_timestamp ); @@ -23,3 +25,4 @@ COMMENT ON COLUMN users.name IS 'Name'; -- +migrate Down DROP TABLE users; +DROP TYPE user_statuses; diff --git a/go.mod b/go.mod index 57999c3..62c48a6 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.20 require ( github.com/jackc/pgx/v5 v5.4.1 github.com/urfave/cli/v2 v2.25.7 + golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df ) require ( diff --git a/go.sum b/go.sum index 87038c0..f19466d 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRT github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df h1:UA2aFVmmsIlefxMk29Dp2juaUSth8Pyn3Tq5Y5mJGME= +golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/actions/separate-interface/internal/astutil/astutil.go b/pkg/actions/separate-interface/internal/astutil/astutil.go index 28cd426..9e86018 100644 --- a/pkg/actions/separate-interface/internal/astutil/astutil.go +++ b/pkg/actions/separate-interface/internal/astutil/astutil.go @@ -51,6 +51,26 @@ func ExtractImportDecls(decls ...ast.Decl) []ast.Decl { }} } +func SymbolNameFromTypeOrValueDecls(decls ...ast.Decl) []string { + var symbols []string + for _, decl := range decls { + switch decl := decl.(type) { + case *ast.GenDecl: + for _, spec := range decl.Specs { + switch spec := spec.(type) { + case *ast.ValueSpec: + for _, name := range spec.Names { + symbols = append(symbols, name.Name) + } + case *ast.TypeSpec: + symbols = append(symbols, spec.Name.Name) + } + } + } + } + return symbols +} + func individualSpecs(exp bool, specs ...ast.Spec) []ast.Spec { var exported []ast.Spec for _, spec := range specs { @@ -140,6 +160,19 @@ func (r *ExportedExprIdentUpdater) Visit(n ast.Node) ast.Visitor { n.Rhs[i] = rh } } + case *ast.InterfaceType: + ast.Inspect(n, func(n ast.Node) bool { + switch n := n.(type) { + case *ast.Field: + if _, isInterfaceMethod := n.Type.(*ast.FuncType); !isInterfaceMethod { + if expr := r.resolveExpr(n.Type); expr != nil { + n.Type = expr + } + return false + } + } + return true + }) } return r } diff --git a/pkg/actions/separate-interface/runner.go b/pkg/actions/separate-interface/runner.go index 9d8c2c3..b03d749 100644 --- a/pkg/actions/separate-interface/runner.go +++ b/pkg/actions/separate-interface/runner.go @@ -13,11 +13,13 @@ import ( "strings" "github.com/mpyw/sqlc-restruct/pkg/actions/separate-interface/internal/astutil" + "golang.org/x/exp/slices" ) type runner struct { - input ActionInput - fset *token.FileSet + input ActionInput + fset *token.FileSet + exportedSymbolsInModels []string } func (r *runner) Run() error { @@ -25,6 +27,11 @@ func (r *runner) Run() error { if err != nil { return fmt.Errorf("runner.Run() failed: %w", err) } + f, err := parser.ParseFile(r.fset, path.Join(r.input.ImplDir, r.input.ModelsFileName), nil, parser.ParseComments) + if err != nil { + return fmt.Errorf("runner.Run() failed: %w", err) + } + r.exportedSymbolsInModels = astutil.SymbolNameFromTypeOrValueDecls(astutil.ExportedIndividualTypeOrValueDecls(f.Decls...)...) var newModelsContent []byte var newQuerierContent []byte @@ -58,8 +65,8 @@ func (r *runner) Run() error { } if newModelsContent != nil { - _ = os.Remove(path.Join(r.input.IfaceDir, r.input.ModelsFileName)) - if err := os.WriteFile(path.Join(r.input.IfaceDir, r.input.ModelsFileName), newModelsContent, 0644); err != nil { + _ = os.Remove(path.Join(r.input.ModelsDir, r.input.ModelsFileName)) + if err := os.WriteFile(path.Join(r.input.ModelsDir, r.input.ModelsFileName), newModelsContent, 0644); err != nil { return fmt.Errorf("runner.Run() failed: %w", err) } _ = os.Remove(path.Join(r.input.ImplDir, r.input.ModelsFileName)) @@ -96,7 +103,7 @@ func (r *runner) newModelsContent() ([]byte, error) { } // Change package name of "models" file - f.Name = ast.NewIdent(r.input.IfacePkgName) + f.Name = ast.NewIdent(r.input.ModelsPkgName) byt, err := r.intoBytes(f) if err != nil { @@ -114,6 +121,19 @@ func (r *runner) newQuerierContent() ([]byte, error) { // Change package name of "querier" file f.Name = ast.NewIdent(r.input.IfacePkgName) + // Prepend import statement of ModelsPkgURL + if r.input.ModelsPkgURL != r.input.IfacePkgURL { + f.Decls = append(append(([]ast.Decl)(nil), &ast.GenDecl{ + Tok: token.IMPORT, + Specs: []ast.Spec{&ast.ImportSpec{ + Path: &ast.BasicLit{ + Kind: token.STRING, + Value: fmt.Sprintf("%#v", r.input.ModelsPkgURL), + }, + }}, + }), f.Decls...) + } + // Remove top level constraint: var _ Querier = (*Querier)(nil) for i, decl := range f.Decls { if decl, ok := decl.(*ast.GenDecl); ok && decl.Tok == token.VAR { @@ -126,6 +146,22 @@ func (r *runner) newQuerierContent() ([]byte, error) { } } + // Qualify exported references + if r.input.ModelsPkgURL != r.input.IfacePkgURL { + ast.Walk( + astutil.NewExportedExprIdentUpdater(func(ident *ast.Ident) ast.Expr { + if slices.Contains(r.exportedSymbolsInModels, ident.Name) { + return &ast.SelectorExpr{ + X: ast.NewIdent(r.input.ModelsPkgName), + Sel: ident, + } + } + return nil + }), + f, + ) + } + dirEntries, err := os.ReadDir(r.input.ImplDir) if err != nil { return nil, fmt.Errorf("runner.newQuerierContent() failed: %w", err) @@ -172,11 +208,28 @@ func (r *runner) newQueriesContent(filename string) ([]byte, error) { }}, }), f.Decls...) + // Prepend import statement of ModelsPkgURL + if r.input.ModelsPkgURL != r.input.IfacePkgURL { + f.Decls = append(append(([]ast.Decl)(nil), &ast.GenDecl{ + Tok: token.IMPORT, + Specs: []ast.Spec{&ast.ImportSpec{ + Path: &ast.BasicLit{ + Kind: token.STRING, + Value: fmt.Sprintf("%#v", r.input.ModelsPkgURL), + }, + }}, + }), f.Decls...) + } + // Qualify exported references ast.Walk( astutil.NewExportedExprIdentUpdater(func(ident *ast.Ident) ast.Expr { + pkgName := r.input.IfacePkgName + if slices.Contains(r.exportedSymbolsInModels, ident.Name) { + pkgName = r.input.ModelsPkgName + } return &ast.SelectorExpr{ - X: ast.NewIdent(r.input.IfacePkgName), + X: ast.NewIdent(pkgName), Sel: ident, } }), diff --git a/pkg/actions/separate-interface/separate_interface.go b/pkg/actions/separate-interface/separate_interface.go index 3cb2bfb..9429f9b 100644 --- a/pkg/actions/separate-interface/separate_interface.go +++ b/pkg/actions/separate-interface/separate_interface.go @@ -8,12 +8,18 @@ import ( ) type ActionInput struct { - // IfacePkgName The package name where the separated models and `Querier` will be located. + // IfacePkgName The package name where the separated Querier will be located. IfacePkgName string - // IfacePkgURL The package URL where the separated models and `Querier` will be located (e.g. "github.com///path/to/pkg"). + // IfacePkgURL The package URL where the separated Querier will be located. IfacePkgURL string - // IfaceDir The directory path where the separated models and `Querier` will be located. + // IfaceDir The directory path where the separated Querier will be located. IfaceDir string + // ModelsPkgName The package name where the separated models will be located. + ModelsPkgName string + // ModelsPkgURL The package URL where the separated models will be located. + ModelsPkgURL string + // ModelsDir The directory path where the separated models will be located. + ModelsDir string // ImplDir The original directory where the sqlc-generated code is located. ImplDir string // ImplSQLSuffix The suffix for sqlc-generated files from SQL files.