Skip to content

Commit

Permalink
Merge pull request #160 from butuzov/refactor
Browse files Browse the repository at this point in the history
refactoring: code refactoring
  • Loading branch information
cweill authored May 23, 2021
2 parents 90387e1 + c76fad4 commit 16a93f6
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 201 deletions.
6 changes: 4 additions & 2 deletions gotests.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func generateTest(src models.Path, files []models.Path, opt *Options) (*Generate
return nil, nil
}

b, err := output.Process(h, funcs, &output.Options{
options := output.Options{
PrintInputs: opt.PrintInputs,
Subtests: opt.Subtests,
Parallel: opt.Parallel,
Expand All @@ -129,7 +129,9 @@ func generateTest(src models.Path, files []models.Path, opt *Options) (*Generate
TemplateDir: opt.TemplateDir,
TemplateParams: opt.TemplateParams,
TemplateData: opt.TemplateData,
})
}

b, err := options.Process(h, funcs)
if err != nil {
return nil, fmt.Errorf("output.Process: %v", err)
}
Expand Down
5 changes: 2 additions & 3 deletions gotests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,6 @@ func TestGenerateTests(t *testing.T) {
{
name: "Receiver is indirect imported struct",
args: args{
// only: regexp.MustCompile("^Foo037$"),
srcPath: `testdata/test037.go`,
},
want: mustReadAndFormatGoFile(t, "testdata/goldens/receiver_is_indirect_imported_struct.go"),
Expand Down Expand Up @@ -559,7 +558,7 @@ func TestGenerateTests(t *testing.T) {
},
want: mustReadAndFormatGoFile(t, "testdata/goldens/existing_test_file_with_multiple_imports.go"),
},
{
{ // WORNING: data race condition, if called with -race flag, because of structure in `internal/templates` package.

This comment has been minimized.

Copy link
@StevenACoffman

StevenACoffman Jun 9, 2021

nit: WORNING -> WARNING

name: "Entire testdata directory",
args: args{
srcPath: `testdata/`,
Expand Down Expand Up @@ -713,7 +712,7 @@ func TestGenerateTests(t *testing.T) {
},
want: mustReadAndFormatGoFile(t, "testdata/goldens/function_with_return_value_custom_template.go"),
},
{
{ // WORNING: panics on -race flag.
name: "Test interface embedding",
args: args{
srcPath: `testdata/undefinedtypes/interface_embedding.go`,
Expand Down
8 changes: 8 additions & 0 deletions internal/output/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package output

import "os"

func IsFileExist(path string) bool {
_, err := os.Stat(path)
return !os.IsNotExist(err)
}
6 changes: 6 additions & 0 deletions internal/output/imports.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package output

// we do not need support for aliases in import for now.
var importsMap = map[string]string{
"testify": "github.com/stretchr/testify/assert",
}
99 changes: 99 additions & 0 deletions internal/output/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package output

import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"os"

"github.com/cweill/gotests/internal/models"
"github.com/cweill/gotests/internal/render"
"golang.org/x/tools/imports"
)

type Options struct {
PrintInputs bool
Subtests bool
Parallel bool
Named bool
Template string
TemplateDir string
TemplateParams map[string]interface{}
TemplateData [][]byte

render *render.Render
}

func (o *Options) Process(head *models.Header, funcs []*models.Function) ([]byte, error) {
o.render = render.New()

switch {
case o.providesTemplateDir():
if err := o.render.LoadCustomTemplates(o.TemplateDir); err != nil {
return nil, fmt.Errorf("loading custom templates: %v", err)
}
case o.providesTemplate():
if err := o.render.LoadCustomTemplatesName(o.Template); err != nil {
return nil, fmt.Errorf("loading custom templates of name: %v", err)
}
case o.providesTemplateData():
o.render.LoadFromData(o.TemplateData)
}

//
tf, err := ioutil.TempFile("", "gotests_")
if err != nil {
return nil, fmt.Errorf("ioutil.TempFile: %v", err)
}
defer tf.Close()
defer os.Remove(tf.Name())

// create physical copy of test
b := &bytes.Buffer{}
if err := o.writeTests(b, head, funcs); err != nil {
return nil, err
}

// format file
out, err := imports.Process(tf.Name(), b.Bytes(), nil)
if err != nil {
return nil, fmt.Errorf("imports.Process: %v", err)
}
return out, nil
}

func (o *Options) providesTemplateData() bool {
return o != nil && len(o.TemplateData) > 0
}

func (o *Options) providesTemplateDir() bool {
return o != nil && o.TemplateDir != ""
}

func (o *Options) providesTemplate() bool {
return o != nil && o.Template != ""
}

func (o *Options) writeTests(w io.Writer, head *models.Header, funcs []*models.Function) error {
if path, ok := importsMap[o.Template]; ok {
head.Imports = append(head.Imports, &models.Import{
Path: fmt.Sprintf(`"%s"`, path),
})
}

b := bufio.NewWriter(w)
if err := o.render.Header(b, head); err != nil {
return fmt.Errorf("render.Header: %v", err)
}

for _, fun := range funcs {
err := o.render.TestFunction(b, fun, o.PrintInputs, o.Subtests, o.Named, o.Parallel, o.TemplateParams)
if err != nil {
return fmt.Errorf("render.TestFunction: %v", err)
}
}

return b.Flush()
}
File renamed without changes.
94 changes: 0 additions & 94 deletions internal/output/output.go

This file was deleted.

Empty file modified internal/render/README.md
100644 → 100755
Empty file.
73 changes: 73 additions & 0 deletions internal/render/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package render

//go:generate esc -o bindata/esc.go -pkg=bindata templates
import (
"fmt"
"strings"

"github.com/cweill/gotests/internal/models"
)

const nFile = 7 // Number of files to be read from template (package) template (directory)

func fieldName(f *models.Field) string {
var n string
if f.IsNamed() {
n = f.Name
} else {
n = f.Type.String()
}
return n
}

func receiverName(f *models.Receiver) string {
var n string
if f.IsNamed() {
n = f.Name
} else {
n = f.ShortName()
}
if n == "name" {
// Avoid conflict with test struct's "name" field.
n = "n"
} else if n == "t" {
// Avoid conflict with test argument.
// "tr" is short for t receiver.
n = "tr"
}
return n
}

func parameterName(f *models.Field) string {
var n string
if f.IsNamed() {
n = f.Name
} else {
n = fmt.Sprintf("in%v", f.Index)
}
return n
}

func wantName(f *models.Field) string {
var n string
if f.IsNamed() {
n = "want" + strings.Title(f.Name)
} else if f.Index == 0 {
n = "want"
} else {
n = fmt.Sprintf("want%v", f.Index)
}
return n
}

func gotName(f *models.Field) string {
var n string
if f.IsNamed() {
n = "got" + strings.Title(f.Name)
} else if f.Index == 0 {
n = "got"
} else {
n = fmt.Sprintf("got%v", f.Index)
}
return n
}
Loading

0 comments on commit 16a93f6

Please sign in to comment.