diff --git a/README.md b/README.md index 7e845cb..f4ada07 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,8 @@ moq [flags] source-dir interface [interface2 [interface3 [...]]] -skip-ensure suppress mock implementation check, avoid import cycle if mocks generated outside of the tested package + -with-resets + generate functions to facilitate resetting calls made to a mock Specifying an alias for the mock is also supported with the format 'interface:alias' @@ -120,6 +122,7 @@ The mocked structure implements the interface, where each method calls the assoc * Name arguments in the interface for a better experience * Use closured variables inside your test function to capture details about the calls to the methods * Use `.MethodCalls()` to track the calls +* Use `.ResetCalls()` to reset calls within an invidual mock's context * Use `go:generate` to invoke the `moq` command * If Moq fails with a `go/format` error, it indicates the generated code was not valid. You can run the same command with `-fmt noop` to print the generated source code without attempting to format it. diff --git a/internal/template/template.go b/internal/template/template.go index 8412ab9..9e2672c 100644 --- a/internal/template/template.go +++ b/internal/template/template.go @@ -182,8 +182,27 @@ func (mock *{{$mock.MockName}} mock.lock{{.Name}}.RUnlock() return calls } +{{- if $.WithResets}} +// Reset{{.Name}}Calls reset all the calls that were made to {{.Name}}. +func (mock *{{$mock.MockName}}) Reset{{.Name}}Calls() { + mock.lock{{.Name}}.Lock() + mock.calls.{{.Name}} = nil + mock.lock{{.Name}}.Unlock() +} +{{end}} +{{end -}} +{{- if $.WithResets}} +// ResetCalls reset all the calls that were made to all mocked methods. +func (mock *{{$mock.MockName}}) ResetCalls() { + {{- range .Methods}} + mock.lock{{.Name}}.Lock() + mock.calls.{{.Name}} = nil + mock.lock{{.Name}}.Unlock() + {{end -}} +} +{{end -}} {{end -}} -{{end -}}` +` // This list comes from the golint codebase. Golint will complain about any of // these being mixed-case, like "Id" instead of "ID". diff --git a/internal/template/template_data.go b/internal/template/template_data.go index 12c9117..2a3caeb 100644 --- a/internal/template/template_data.go +++ b/internal/template/template_data.go @@ -16,6 +16,7 @@ type Data struct { Mocks []MockData StubImpl bool SkipEnsure bool + WithResets bool } // MocksSomeMethod returns true of any one of the Mocks has at least 1 diff --git a/main.go b/main.go index 37c3937..89adb3d 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,7 @@ type userFlags struct { formatter string stubImpl bool skipEnsure bool + withResets bool remove bool args []string } @@ -37,6 +38,8 @@ func main() { flag.BoolVar(&flags.skipEnsure, "skip-ensure", false, "suppress mock implementation check, avoid import cycle if mocks generated outside of the tested package") flag.BoolVar(&flags.remove, "rm", false, "first remove output file, if it exists") + flag.BoolVar(&flags.withResets, "with-resets", false, + "generate functions to facilitate resetting calls made to a mock") flag.Usage = func() { fmt.Println(`moq [flags] source-dir interface [interface2 [interface3 [...]]]`) @@ -86,6 +89,7 @@ func run(flags userFlags) error { Formatter: flags.formatter, StubImpl: flags.stubImpl, SkipEnsure: flags.skipEnsure, + WithResets: flags.withResets, }) if err != nil { return err @@ -100,10 +104,10 @@ func run(flags userFlags) error { } // create the file - err = os.MkdirAll(filepath.Dir(flags.outFile), 0750) + err = os.MkdirAll(filepath.Dir(flags.outFile), 0o750) if err != nil { return err } - return ioutil.WriteFile(flags.outFile, buf.Bytes(), 0600) + return ioutil.WriteFile(flags.outFile, buf.Bytes(), 0o600) } diff --git a/pkg/moq/moq.go b/pkg/moq/moq.go index a33a4cb..e8a2975 100644 --- a/pkg/moq/moq.go +++ b/pkg/moq/moq.go @@ -28,6 +28,7 @@ type Config struct { Formatter string StubImpl bool SkipEnsure bool + WithResets bool } // New makes a new Mocker for the specified package directory. @@ -81,6 +82,7 @@ func (m *Mocker) Mock(w io.Writer, namePairs ...string) error { Mocks: mocks, StubImpl: m.cfg.StubImpl, SkipEnsure: m.cfg.SkipEnsure, + WithResets: m.cfg.WithResets, } if data.MocksSomeMethod() { diff --git a/pkg/moq/moq_test.go b/pkg/moq/moq_test.go index 758e92c..7dceb8f 100644 --- a/pkg/moq/moq_test.go +++ b/pkg/moq/moq_test.go @@ -29,7 +29,7 @@ func TestMoq(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ "package example", "type PersonStoreMock struct", "CreateFunc func(ctx context.Context, person *Person, confirm bool) error", @@ -62,7 +62,7 @@ func TestMoqWithStaticCheck(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ "package example", "var _ PersonStore = &PersonStoreMock{}", "type PersonStoreMock struct", @@ -96,7 +96,7 @@ func TestMoqWithAlias(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ "package example", "type AnotherPersonStoreMock struct", "CreateFunc func(ctx context.Context, person *Person, confirm bool) error", @@ -129,7 +129,7 @@ func TestMoqExplicitPackage(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ "package different", "type PersonStoreMock struct", "CreateFunc func(ctx context.Context, person *example.Person, confirm bool) error", @@ -156,7 +156,7 @@ func TestMoqExplicitPackageWithStaticCheck(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ "package different", "var _ example.PersonStore = &PersonStoreMock{}", "type PersonStoreMock struct", @@ -184,7 +184,7 @@ func TestMoqSkipEnsure(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ "package different", "type PersonStoreMock struct", "CreateFunc func(ctx context.Context, person *example.Person, confirm bool) error", @@ -233,7 +233,7 @@ func TestVariadicArguments(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ "package variadic", "type GreeterMock struct", "GreetFunc func(ctx context.Context, names ...string) string", @@ -261,7 +261,7 @@ func TestNothingToReturn(t *testing.T) { t.Errorf("should not have return for items that have no return arguments") } // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ "mock.ClearCacheFunc(id)", } for _, str := range strs { @@ -282,7 +282,7 @@ func TestImports(t *testing.T) { t.Errorf("m.Mock: %s", err) } s := buf.String() - var strs = []string{ + strs := []string{ ` "sync"`, ` "github.com/matryer/moq/pkg/moq/testpackages/imports/one"`, } @@ -395,6 +395,12 @@ func TestMockGolden(t *testing.T) { interfaces: []string{"Transient"}, goldenFile: filepath.Join("testpackages/transientimport", "transient_moq.golden.go"), }, + { + name: "WithResets", + cfg: Config{SrcDir: "testpackages/withresets", WithResets: true}, + interfaces: []string{"ResetStore"}, + goldenFile: filepath.Join("testpackages/withresets", "withresets_moq.golden.go"), + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { @@ -449,10 +455,10 @@ func matchGoldenFile(goldenFile string, actual []byte) error { // To update golden files, run the following: // go test -v -run '^$' github.com/matryer/moq/pkg/moq -update if *update { - if err := os.MkdirAll(filepath.Dir(goldenFile), 0750); err != nil { + if err := os.MkdirAll(filepath.Dir(goldenFile), 0o750); err != nil { return fmt.Errorf("create dir: %s", err) } - if err := ioutil.WriteFile(goldenFile, actual, 0600); err != nil { + if err := ioutil.WriteFile(goldenFile, actual, 0o600); err != nil { return fmt.Errorf("write: %s", err) } @@ -495,7 +501,7 @@ func TestVendoredPackages(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ `"github.com/sudo-suhas/moq-test-pkgs/somerepo"`, } for _, str := range strs { @@ -520,7 +526,7 @@ func TestVendoredInterface(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ `"github.com/sudo-suhas/moq-test-pkgs/somerepo"`, } for _, str := range strs { @@ -546,7 +552,7 @@ func TestVendoredBuildConstraints(t *testing.T) { } s := buf.String() // assertions of things that should be mentioned - var strs = []string{ + strs := []string{ `"github.com/sudo-suhas/moq-test-pkgs/buildconstraints"`, } for _, str := range strs { diff --git a/pkg/moq/testpackages/dotimport/service_moq_test.go b/pkg/moq/testpackages/dotimport/service_moq_test.go index d2855c5..b047039 100755 --- a/pkg/moq/testpackages/dotimport/service_moq_test.go +++ b/pkg/moq/testpackages/dotimport/service_moq_test.go @@ -14,19 +14,19 @@ var _ dotimport.Service = &ServiceMock{} // ServiceMock is a mock implementation of dotimport.Service. // -// func TestSomethingThatUsesService(t *testing.T) { +// func TestSomethingThatUsesService(t *testing.T) { // -// // make and configure a mocked dotimport.Service -// mockedService := &ServiceMock{ -// UserFunc: func(ID string) (dotimport.User, error) { -// panic("mock out the User method") -// }, -// } +// // make and configure a mocked dotimport.Service +// mockedService := &ServiceMock{ +// UserFunc: func(ID string) (dotimport.User, error) { +// panic("mock out the User method") +// }, +// } // -// // use mockedService in code that requires dotimport.Service -// // and then make assertions. +// // use mockedService in code that requires dotimport.Service +// // and then make assertions. // -// } +// } type ServiceMock struct { // UserFunc mocks the User method. UserFunc func(ID string) (dotimport.User, error) @@ -60,7 +60,8 @@ func (mock *ServiceMock) User(ID string) (dotimport.User, error) { // UserCalls gets all the calls that were made to User. // Check the length with: -// len(mockedService.UserCalls()) +// +// len(mockedService.UserCalls()) func (mock *ServiceMock) UserCalls() []struct { ID string } { diff --git a/pkg/moq/testpackages/withresets/withresets.go b/pkg/moq/testpackages/withresets/withresets.go new file mode 100644 index 0000000..8a87bff --- /dev/null +++ b/pkg/moq/testpackages/withresets/withresets.go @@ -0,0 +1,18 @@ +package withresets + +import "context" + +// Reset is a reset. +type Reset struct { + ID string + Name string + Company string + Website string +} + +// ResetStore stores resets. +type ResetStore interface { + Get(ctx context.Context, id string) (*Reset, error) + Create(ctx context.Context, person *Reset, confirm bool) error + ClearCache(id string) +} diff --git a/pkg/moq/testpackages/withresets/withresets_moq.golden.go b/pkg/moq/testpackages/withresets/withresets_moq.golden.go new file mode 100644 index 0000000..db08d32 --- /dev/null +++ b/pkg/moq/testpackages/withresets/withresets_moq.golden.go @@ -0,0 +1,217 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package withresets + +import ( + "context" + "sync" +) + +// Ensure, that ResetStoreMock does implement ResetStore. +// If this is not the case, regenerate this file with moq. +var _ ResetStore = &ResetStoreMock{} + +// ResetStoreMock is a mock implementation of ResetStore. +// +// func TestSomethingThatUsesResetStore(t *testing.T) { +// +// // make and configure a mocked ResetStore +// mockedResetStore := &ResetStoreMock{ +// ClearCacheFunc: func(id string) { +// panic("mock out the ClearCache method") +// }, +// CreateFunc: func(ctx context.Context, person *Reset, confirm bool) error { +// panic("mock out the Create method") +// }, +// GetFunc: func(ctx context.Context, id string) (*Reset, error) { +// panic("mock out the Get method") +// }, +// } +// +// // use mockedResetStore in code that requires ResetStore +// // and then make assertions. +// +// } +type ResetStoreMock struct { + // ClearCacheFunc mocks the ClearCache method. + ClearCacheFunc func(id string) + + // CreateFunc mocks the Create method. + CreateFunc func(ctx context.Context, person *Reset, confirm bool) error + + // GetFunc mocks the Get method. + GetFunc func(ctx context.Context, id string) (*Reset, error) + + // calls tracks calls to the methods. + calls struct { + // ClearCache holds details about calls to the ClearCache method. + ClearCache []struct { + // ID is the id argument value. + ID string + } + // Create holds details about calls to the Create method. + Create []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Person is the person argument value. + Person *Reset + // Confirm is the confirm argument value. + Confirm bool + } + // Get holds details about calls to the Get method. + Get []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // ID is the id argument value. + ID string + } + } + lockClearCache sync.RWMutex + lockCreate sync.RWMutex + lockGet sync.RWMutex +} + +// ClearCache calls ClearCacheFunc. +func (mock *ResetStoreMock) ClearCache(id string) { + if mock.ClearCacheFunc == nil { + panic("ResetStoreMock.ClearCacheFunc: method is nil but ResetStore.ClearCache was just called") + } + callInfo := struct { + ID string + }{ + ID: id, + } + mock.lockClearCache.Lock() + mock.calls.ClearCache = append(mock.calls.ClearCache, callInfo) + mock.lockClearCache.Unlock() + mock.ClearCacheFunc(id) +} + +// ClearCacheCalls gets all the calls that were made to ClearCache. +// Check the length with: +// +// len(mockedResetStore.ClearCacheCalls()) +func (mock *ResetStoreMock) ClearCacheCalls() []struct { + ID string +} { + var calls []struct { + ID string + } + mock.lockClearCache.RLock() + calls = mock.calls.ClearCache + mock.lockClearCache.RUnlock() + return calls +} + +// ResetClearCacheCalls reset all the calls that were made to ClearCache. +func (mock *ResetStoreMock) ResetClearCacheCalls() { + mock.lockClearCache.Lock() + mock.calls.ClearCache = nil + mock.lockClearCache.Unlock() +} + +// Create calls CreateFunc. +func (mock *ResetStoreMock) Create(ctx context.Context, person *Reset, confirm bool) error { + if mock.CreateFunc == nil { + panic("ResetStoreMock.CreateFunc: method is nil but ResetStore.Create was just called") + } + callInfo := struct { + Ctx context.Context + Person *Reset + Confirm bool + }{ + Ctx: ctx, + Person: person, + Confirm: confirm, + } + mock.lockCreate.Lock() + mock.calls.Create = append(mock.calls.Create, callInfo) + mock.lockCreate.Unlock() + return mock.CreateFunc(ctx, person, confirm) +} + +// CreateCalls gets all the calls that were made to Create. +// Check the length with: +// +// len(mockedResetStore.CreateCalls()) +func (mock *ResetStoreMock) CreateCalls() []struct { + Ctx context.Context + Person *Reset + Confirm bool +} { + var calls []struct { + Ctx context.Context + Person *Reset + Confirm bool + } + mock.lockCreate.RLock() + calls = mock.calls.Create + mock.lockCreate.RUnlock() + return calls +} + +// ResetCreateCalls reset all the calls that were made to Create. +func (mock *ResetStoreMock) ResetCreateCalls() { + mock.lockCreate.Lock() + mock.calls.Create = nil + mock.lockCreate.Unlock() +} + +// Get calls GetFunc. +func (mock *ResetStoreMock) Get(ctx context.Context, id string) (*Reset, error) { + if mock.GetFunc == nil { + panic("ResetStoreMock.GetFunc: method is nil but ResetStore.Get was just called") + } + callInfo := struct { + Ctx context.Context + ID string + }{ + Ctx: ctx, + ID: id, + } + mock.lockGet.Lock() + mock.calls.Get = append(mock.calls.Get, callInfo) + mock.lockGet.Unlock() + return mock.GetFunc(ctx, id) +} + +// GetCalls gets all the calls that were made to Get. +// Check the length with: +// +// len(mockedResetStore.GetCalls()) +func (mock *ResetStoreMock) GetCalls() []struct { + Ctx context.Context + ID string +} { + var calls []struct { + Ctx context.Context + ID string + } + mock.lockGet.RLock() + calls = mock.calls.Get + mock.lockGet.RUnlock() + return calls +} + +// ResetGetCalls reset all the calls that were made to Get. +func (mock *ResetStoreMock) ResetGetCalls() { + mock.lockGet.Lock() + mock.calls.Get = nil + mock.lockGet.Unlock() +} + +// ResetCalls reset all the calls that were made to all mocked methods. +func (mock *ResetStoreMock) ResetCalls() { + mock.lockClearCache.Lock() + mock.calls.ClearCache = nil + mock.lockClearCache.Unlock() + + mock.lockCreate.Lock() + mock.calls.Create = nil + mock.lockCreate.Unlock() + + mock.lockGet.Lock() + mock.calls.Get = nil + mock.lockGet.Unlock() +}