Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use generated mocks in flyteadmin #6197

Merged
merged 4 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions flyteadmin/dataproxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
Expand All @@ -33,8 +34,8 @@ func TestNewService(t *testing.T) {
dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{
Upload: config.DataProxyUploadConfig{},
}, nodeExecutionManager, dataStore, taskExecutionManager)
Expand All @@ -59,8 +60,8 @@ func Test_createStorageLocation(t *testing.T) {
func TestCreateUploadLocation(t *testing.T) {
dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)
t.Run("No project/domain", func(t *testing.T) {
Expand Down Expand Up @@ -113,8 +114,8 @@ func TestCreateUploadLocationMore(t *testing.T) {
}

assert.NoError(t, err)
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, &ds, taskExecutionManager)
assert.NoError(t, err)

Expand Down Expand Up @@ -171,15 +172,15 @@ func (t testMetadata) Exists() bool {

func TestCreateDownloadLink(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
nodeExecutionManager.SetGetNodeExecutionFunc(func(ctx context.Context, request *admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) {
nodeExecutionManager := &mocks.NodeExecutionInterface{}
nodeExecutionManager.EXPECT().GetNodeExecution(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) {
return &admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
DeckUri: "s3://something/something",
},
}, nil
})
taskExecutionManager := &mocks.MockTaskExecutionManager{}
taskExecutionManager := &mocks.TaskExecutionInterface{}

s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)
Expand Down Expand Up @@ -262,8 +263,8 @@ func TestCreateDownloadLink(t *testing.T) {

func TestCreateDownloadLocation(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

Expand Down Expand Up @@ -300,8 +301,8 @@ func TestCreateDownloadLocation(t *testing.T) {

func TestService_GetData(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

Expand Down Expand Up @@ -340,15 +341,15 @@ func TestService_GetData(t *testing.T) {
},
}

nodeExecutionManager.SetGetNodeExecutionDataFunc(
nodeExecutionManager.EXPECT().GetNodeExecutionData(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, request *admin.NodeExecutionGetDataRequest) (*admin.NodeExecutionGetDataResponse, error) {
return &admin.NodeExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
}, nil
},
)
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.EXPECT().ListTaskExecutions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: []*admin.TaskExecution{
{
Expand All @@ -374,7 +375,7 @@ func TestService_GetData(t *testing.T) {
},
}, nil
})
taskExecutionManager.SetGetTaskExecutionDataCallback(func(ctx context.Context, request *admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
taskExecutionManager.EXPECT().GetTaskExecutionData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
return &admin.TaskExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
Expand Down Expand Up @@ -441,13 +442,13 @@ func TestService_GetData(t *testing.T) {

func TestService_Error(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.EXPECT().ListTaskExecutions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return nil, errors.NewFlyteAdminErrorf(1, "not found")
})
nodeExecID := &core.NodeExecutionIdentifier{
Expand All @@ -463,7 +464,7 @@ func TestService_Error(t *testing.T) {
})

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.EXPECT().ListTaskExecutions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: nil,
Token: "",
Expand Down
25 changes: 13 additions & 12 deletions flyteadmin/pkg/async/schedule/aws/workflow_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/codes"

flyteAdminErrors "github.com/flyteorg/flyte/flyteadmin/pkg/errors"
Expand Down Expand Up @@ -138,8 +139,8 @@ func TestGetActiveLaunchPlanVersion(t *testing.T) {
Version: "foo",
}

launchPlanManager := mocks.NewMockLaunchPlanManager()
launchPlanManager.(*mocks.MockLaunchPlanManager).SetListLaunchPlansCallback(
launchPlanManager := mocks.LaunchPlanInterface{}
launchPlanManager.EXPECT().ListLaunchPlans(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, request *admin.ResourceListRequest) (
*admin.LaunchPlanList, error) {
assert.True(t, proto.Equal(launchPlanNamedIdentifier, request.GetId()))
Expand All @@ -153,7 +154,7 @@ func TestGetActiveLaunchPlanVersion(t *testing.T) {
},
}, nil
})
testExecutor := newWorkflowExecutorForTest(nil, nil, launchPlanManager)
testExecutor := newWorkflowExecutorForTest(nil, nil, &launchPlanManager)
launchPlan, err := testExecutor.getActiveLaunchPlanVersion(launchPlanNamedIdentifier)
assert.Nil(t, err)
assert.True(t, proto.Equal(&launchPlanIdentifier, launchPlan.GetId()))
Expand All @@ -167,13 +168,13 @@ func TestGetActiveLaunchPlanVersion_ManagerError(t *testing.T) {
}

expectedErr := errors.New("expected error")
launchPlanManager := mocks.NewMockLaunchPlanManager()
launchPlanManager.(*mocks.MockLaunchPlanManager).SetListLaunchPlansCallback(
launchPlanManager := mocks.LaunchPlanInterface{}
launchPlanManager.EXPECT().ListLaunchPlans(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, request *admin.ResourceListRequest) (
*admin.LaunchPlanList, error) {
return nil, expectedErr
})
testExecutor := newWorkflowExecutorForTest(nil, nil, launchPlanManager)
testExecutor := newWorkflowExecutorForTest(nil, nil, &launchPlanManager)
_, err := testExecutor.getActiveLaunchPlanVersion(launchPlanIdentifier)
assert.EqualError(t, err, expectedErr.Error())
}
Expand Down Expand Up @@ -229,23 +230,23 @@ func TestRun(t *testing.T) {
testSubscriber := pubsubtest.TestSubscriber{
JSONMessages: messages,
}
testExecutionManager := mocks.MockExecutionManager{}
testExecutionManager := mocks.ExecutionInterface{}
var messagesSeen int
testExecutionManager.SetCreateCallback(func(
testExecutionManager.EXPECT().CreateExecution(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(
ctx context.Context, request *admin.ExecutionCreateRequest, requestedAt time.Time) (
*admin.ExecutionCreateResponse, error) {
assert.Equal(t, "project", request.GetProject())
assert.Equal(t, "domain", request.GetDomain())
assert.Equal(t, "ar8fphnlc5wh9dksjncj", request.GetName())
if messagesSeen == 0 {
assert.Contains(t, request.GetInputs().GetLiterals(), testKickoffTime)
assert.Equal(t, testKickoffTimeProtoLiteral, request.GetInputs().GetLiterals()[testKickoffTime])
assert.True(t, proto.Equal(testKickoffTimeProtoLiteral, request.GetInputs().GetLiterals()[testKickoffTime]))
}
messagesSeen++
return &admin.ExecutionCreateResponse{}, nil
})
launchPlanManager := mocks.NewMockLaunchPlanManager()
launchPlanManager.(*mocks.MockLaunchPlanManager).SetListLaunchPlansCallback(
launchPlanManager := mocks.LaunchPlanInterface{}
launchPlanManager.EXPECT().ListLaunchPlans(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, request *admin.ResourceListRequest) (
*admin.LaunchPlanList, error) {
assert.Equal(t, "project", request.GetId().GetProject())
Expand Down Expand Up @@ -280,7 +281,7 @@ func TestRun(t *testing.T) {
},
}, nil
})
testExecutor := newWorkflowExecutorForTest(&testSubscriber, &testExecutionManager, launchPlanManager)
testExecutor := newWorkflowExecutorForTest(&testSubscriber, &testExecutionManager, &launchPlanManager)
err := testExecutor.run()
assert.Len(t, messages, messagesSeen)
assert.Nil(t, err)
Expand Down
17 changes: 10 additions & 7 deletions flyteadmin/pkg/clusterresource/impl/db_admin_data_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

Expand All @@ -28,9 +29,9 @@ func TestGetClusterResourceAttributes(t *testing.T) {
"K1": "V1",
"K2": "V2",
}
resourceManager := mocks.MockResourceManager{}
t.Run("happy case", func(t *testing.T) {
resourceManager.GetResourceFunc = func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
resourceManager := mocks.ResourceInterface{}
resourceManager.EXPECT().GetResource(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
return &interfaces.ResourceResponse{
Project: request.Project,
Domain: request.Domain,
Expand All @@ -43,7 +44,7 @@ func TestGetClusterResourceAttributes(t *testing.T) {
},
},
}, nil
}
})
provider := dbAdminProvider{
resourceManager: &resourceManager,
}
Expand All @@ -52,17 +53,19 @@ func TestGetClusterResourceAttributes(t *testing.T) {
assert.EqualValues(t, attrs.GetAttributes(), attributes)
})
t.Run("error", func(t *testing.T) {
resourceManager.GetResourceFunc = func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
resourceManager := mocks.ResourceInterface{}
resourceManager.EXPECT().GetResource(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
return nil, errFoo
}
})
provider := dbAdminProvider{
resourceManager: &resourceManager,
}
_, err := provider.GetClusterResourceAttributes(context.TODO(), project, domain)
assert.EqualError(t, err, errFoo.Error())
})
t.Run("weird db response", func(t *testing.T) {
resourceManager.GetResourceFunc = func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
resourceManager := mocks.ResourceInterface{}
resourceManager.EXPECT().GetResource(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
return &interfaces.ResourceResponse{
Project: request.Project,
Domain: request.Domain,
Expand All @@ -75,7 +78,7 @@ func TestGetClusterResourceAttributes(t *testing.T) {
},
},
}, nil
}
})
provider := dbAdminProvider{
resourceManager: &resourceManager,
}
Expand Down
Loading
Loading