Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: implement snowpipe destination config validation #5472

Merged
merged 4 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/rudderlabs/rudder-server/utils/timeutil"
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
"github.com/rudderlabs/rudder-server/warehouse/validations"
)

var json = jsoniter.ConfigCompatibleWithStandardLibrary
Expand All @@ -57,6 +58,7 @@ func New(
now: timeutil.Now,
channelCache: sync.Map{},
polledImportInfoMap: make(map[string]*importInfo),
validator: validations.NewDestinationValidator(),
}

m.config.client.url = conf.GetString("SnowpipeStreaming.Client.URL", "http://localhost:9078")
Expand Down Expand Up @@ -138,6 +140,15 @@ func (m *Manager) retryableClient() *retryablehttp.Client {
return client
}

// validateConfig runs the destination validator against dest and returns the
// error message it reports, or nil when validation succeeds.
//
// Key pair auth is forced on for the validation run because it is currently
// the only supported auth mode. The flag is set on a shallow copy of the
// destination that carries its own Config map, so the caller's destination is
// never mutated and a nil Config map does not panic on assignment.
func (m *Manager) validateConfig(ctx context.Context, dest *backendconfig.DestinationT) error {
	if dest == nil {
		return errors.New("destination is nil")
	}

	// Copy the struct and rebuild Config so the override stays local to this call.
	validationDest := *dest
	validationDest.Config = make(map[string]interface{}, len(dest.Config)+1)
	for key, value := range dest.Config {
		validationDest.Config[key] = value
	}
	validationDest.Config["useKeyPairAuth"] = true // Since we are currently only supporting key pair auth

	response := m.validator.Validate(ctx, &validationDest)
	if response == nil {
		return errors.New("destination validator returned no response")
	}
	if response.Success {
		return nil
	}
	return errors.New(response.Error)
}

// Now reports the current time as produced by the manager's clock function.
func (m *Manager) Now() time.Time {
	current := m.now()
	return current
}
Expand Down Expand Up @@ -176,6 +187,10 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU
switch {
case errors.Is(err, errAuthz):
m.setBackOff(err)
validationError := m.validateConfig(ctx, asyncDest.Destination)
if validationError != nil {
err = fmt.Errorf("failed to validate snowpipe credentials: %s", validationError.Error())
}
return m.failedJobs(asyncDest, err.Error())
case errors.Is(err, errBackoff):
return m.failedJobs(asyncDest, err.Error())
Expand Down Expand Up @@ -225,6 +240,10 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU
if !isBackoffSet {
isBackoffSet = true
m.setBackOff(err)
validationError := m.validateConfig(ctx, asyncDest.Destination)
if validationError != nil && failedReason == "" {
failedReason = fmt.Sprintf("failed to validate snowpipe credentials: %s", validationError.Error())
}
}
case errors.Is(err, errBackoff):
shouldResetBackoff = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
"github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
"github.com/rudderlabs/rudder-server/warehouse/validations"
)

type mockAPI struct {
Expand Down Expand Up @@ -68,6 +69,22 @@ func (m *mockManager) CreateTable(context.Context, string, whutils.ModelTableSch
return nil
}

// mockValidator is a test double for the destination validator. Its Validate
// result is driven entirely by the err field.
type mockValidator struct {
	// err, when non-nil, makes Validate report a failure carrying err's message.
	err error
}

// Validate ignores its arguments and reports success when m.err is nil;
// otherwise it reports failure with m.err's message in the Error field.
func (m *mockValidator) Validate(_ context.Context, _ *backendconfig.DestinationT) *validations.DestinationValidationResponse {
	result := &validations.DestinationValidationResponse{Success: m.err == nil}
	if !result.Success {
		result.Error = m.err.Error()
	}
	return result
}

var (
usersChannelResponse = &model.ChannelResponse{
ChannelID: "test-users-channel",
Expand Down Expand Up @@ -104,6 +121,7 @@ func TestSnowpipeStreaming(t *testing.T) {
},
Config: make(map[string]interface{}),
}
validations.Init()

t.Run("Upload with invalid file path", func(t *testing.T) {
statsStore, err := memstats.New()
Expand Down Expand Up @@ -405,34 +423,99 @@ func TestSnowpipeStreaming(t *testing.T) {
require.False(t, sm.isInBackoff())
})

t.Run("Upload with discards table authorization error should mark the job as failed", func(t *testing.T) {
statsStore, err := memstats.New()
require.NoError(t, err)
t.Run("destination config validation", func(t *testing.T) {
testCases := []struct {
name string
validationError error
expectedFailedReason string
}{
{
name: "should return validation error",
validationError: fmt.Errorf("missing permissions to do xyz"),
expectedFailedReason: "missing permissions to do xyz",
},
{
name: "should not return any error",
validationError: nil,
expectedFailedReason: "failed to create schema",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
sm := New(config.New(), logger.NOP, stats.NOP, destination)
sm.channelCache.Store("RUDDER_DISCARDS", rudderDiscardsChannelResponse)
sm.api = &mockAPI{
createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){
"USERS": func() (*model.ChannelResponse, error) {
return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil
},
},
}
sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
sf := snowflake.New(config.New(), logger.NOP, stats.NOP)
mm := newMockManager(sf)
mm.createSchemaErr = fmt.Errorf("failed to create schema")
return mm, nil
}
sm.validator = &mockValidator{err: tc.validationError}
asyncDestStruct := &common.AsyncDestinationStruct{
Destination: destination,
FileName: "testdata/successful_user_records.txt",
}
output := sm.Upload(asyncDestStruct)
require.Equal(t, 2, output.FailedCount)
require.Equal(t, 0, output.AbortCount)
require.Contains(t, output.FailedReason, tc.expectedFailedReason)
})
}
})

sm := New(config.New(), logger.NOP, statsStore, destination)
sm.api = &mockAPI{
createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){
"RUDDER_DISCARDS": func() (*model.ChannelResponse, error) {
return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil
},
t.Run("Upload with discards table authorization error should mark the job as failed", func(t *testing.T) {
testCases := []struct {
name string
validationError error
expectedFailedReason string
}{
{
name: "authorization error",
validationError: fmt.Errorf("authorization error"),
expectedFailedReason: "failed to validate snowpipe credentials: authorization error",
},
{
name: "other error",
validationError: nil,
expectedFailedReason: "failed to create schema",
},
}
sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
sf := snowflake.New(config.New(), logger.NOP, stats.NOP)
mm := newMockManager(sf)
mm.createSchemaErr = fmt.Errorf("failed to create schema")
return mm, nil
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
sm := New(config.New(), logger.NOP, stats.NOP, destination)
sm.api = &mockAPI{
createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){
"RUDDER_DISCARDS": func() (*model.ChannelResponse, error) {
return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil
},
},
}
sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
sf := snowflake.New(config.New(), logger.NOP, stats.NOP)
mm := newMockManager(sf)
mm.createSchemaErr = fmt.Errorf("failed to create schema")
return mm, nil
}
sm.validator = &mockValidator{err: tc.validationError}
output := sm.Upload(&common.AsyncDestinationStruct{
ImportingJobIDs: []int64{1},
Destination: destination,
FileName: "testdata/successful_user_records.txt",
})
require.Equal(t, 1, output.FailedCount)
require.Equal(t, 0, output.AbortCount)
require.Contains(t, output.FailedReason, tc.expectedFailedReason)
require.Empty(t, output.AbortReason)
require.Equal(t, true, sm.isInBackoff())
})
}
output := sm.Upload(&common.AsyncDestinationStruct{
ImportingJobIDs: []int64{1},
Destination: destination,
FileName: "testdata/successful_user_records.txt",
})
require.Equal(t, 1, output.FailedCount)
require.Equal(t, 0, output.AbortCount)
require.NotEmpty(t, output.FailedReason)
require.Empty(t, output.AbortReason)
require.Equal(t, true, sm.isInBackoff())
})

t.Run("Upload insert error for all events", func(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
backendconfig "github.com/rudderlabs/rudder-server/backend-config"
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
"github.com/rudderlabs/rudder-server/warehouse/validations"
)

type (
Expand All @@ -31,6 +32,7 @@ type (
api api
channelCache sync.Map
polledImportInfoMap map[string]*importInfo
validator validations.DestinationValidator

config struct {
client struct {
Expand Down
6 changes: 0 additions & 6 deletions warehouse/internal/model/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,3 @@ type Step struct {
// StepsResponse wraps a list of validation steps; the list is serialized
// under the "steps" JSON key.
type StepsResponse struct {
Steps []*Step `json:"steps"`
}

// DestinationValidationResponse is the JSON-serializable result of a
// destination validation run.
type DestinationValidationResponse struct {
// Success is serialized as "success".
// NOTE(review): presumably true only when every step passed — confirm against the validator.
Success bool `json:"success"`
// Error is serialized as "error".
// NOTE(review): presumably the failure message when Success is false — confirm.
Error string `json:"error"`
// Steps is serialized as "steps" and holds the individual validation steps.
Steps []*Step `json:"steps"`
}
88 changes: 38 additions & 50 deletions warehouse/validations/steps.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"encoding/json"

"github.com/samber/lo"

backendconfig "github.com/rudderlabs/rudder-server/backend-config"
schemarepository "github.com/rudderlabs/rudder-server/warehouse/integrations/datalake/schema-repository"
"github.com/rudderlabs/rudder-server/warehouse/internal/model"
Expand All @@ -15,63 +17,49 @@ func validateStepFunc(_ context.Context, destination *backendconfig.DestinationT
}

func StepsToValidate(dest *backendconfig.DestinationT) *model.StepsResponse {
var (
destType = dest.DestinationDefinition.Name
steps []*model.Step
)
destType := dest.DestinationDefinition.Name

if destType == warehouseutils.SnowpipeStreaming {
return &model.StepsResponse{
Steps: []*model.Step{
{ID: 1, Name: model.VerifyingConnections},
{ID: 2, Name: model.VerifyingCreateSchema},
{ID: 3, Name: model.VerifyingCreateAndAlterTable},
{ID: 4, Name: model.VerifyingFetchSchema},
},
}
}

steps = []*model.Step{{
ID: len(steps) + 1,
Name: model.VerifyingObjectStorage,
}}
steps := []*model.Step{
{ID: 1, Name: model.VerifyingObjectStorage},
}

appendSteps := func(newSteps ...string) {
for _, step := range newSteps {
steps = append(steps, &model.Step{ID: len(steps) + 1, Name: step})
}
}

switch destType {
case warehouseutils.GCSDatalake, warehouseutils.AzureDatalake:
// No additional steps
case warehouseutils.S3Datalake:
wh := createDummyWarehouse(dest)
if canUseGlue := schemarepository.UseGlue(&wh); !canUseGlue {
break
if schemarepository.UseGlue(lo.ToPtr(createDummyWarehouse(dest))) {
appendSteps(
model.VerifyingCreateSchema,
model.VerifyingCreateAndAlterTable,
model.VerifyingFetchSchema,
)
}

steps = append(steps,
&model.Step{
ID: len(steps) + 1,
Name: model.VerifyingCreateSchema,
},
&model.Step{
ID: len(steps) + 2,
Name: model.VerifyingCreateAndAlterTable,
},
&model.Step{
ID: len(steps) + 3,
Name: model.VerifyingFetchSchema,
},
)
default:
steps = append(steps,
&model.Step{
ID: len(steps) + 1,
Name: model.VerifyingConnections,
},
&model.Step{
ID: len(steps) + 2,
Name: model.VerifyingCreateSchema,
},
&model.Step{
ID: len(steps) + 3,
Name: model.VerifyingCreateAndAlterTable,
},
&model.Step{
ID: len(steps) + 4,
Name: model.VerifyingFetchSchema,
},
&model.Step{
ID: len(steps) + 5,
Name: model.VerifyingLoadTable,
},
appendSteps(
model.VerifyingConnections,
model.VerifyingCreateSchema,
model.VerifyingCreateAndAlterTable,
model.VerifyingFetchSchema,
model.VerifyingLoadTable,
)
}
return &model.StepsResponse{
Steps: steps,
}

return &model.StepsResponse{Steps: steps}
}
14 changes: 14 additions & 0 deletions warehouse/validations/steps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ func TestValidationSteps(t *testing.T) {
model.VerifyingLoadTable,
},
},
{
name: "Snowpipe",
dest: backendconfig.DestinationT{
DestinationDefinition: backendconfig.DestinationDefinitionT{
Name: warehouseutils.SnowpipeStreaming,
},
},
steps: []string{
model.VerifyingConnections,
model.VerifyingCreateSchema,
model.VerifyingCreateAndAlterTable,
model.VerifyingFetchSchema,
},
},
}

for _, tc := range testCases {
Expand Down
Loading
Loading