Skip to content

Commit

Permalink
feat(task_repo): set default version to the latest
Browse files Browse the repository at this point in the history
Signed-off-by: featherchen <[email protected]>
  • Loading branch information
featherchen committed Oct 22, 2024
1 parent bdaf79f commit 2d88b0d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
20 changes: 20 additions & 0 deletions flyteadmin/pkg/manager/impl/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@ func ValidateIdentifierFieldsSet(id *core.Identifier) error {
return nil
}

// ValidateTaskIdentifierFieldsSet Validates that all required fields, except version, for a task identifier are present.
func ValidateTaskIdentifierFieldsSet(id *core.Identifier) error {
if id == nil {
return shared.GetMissingArgumentError(shared.ID)
}
if err := ValidateEmptyStringField(id.Project, shared.Project); err != nil {
return err
}
if err := ValidateEmptyStringField(id.Domain, shared.Domain); err != nil {
return err
}
if err := ValidateEmptyStringField(id.Name, shared.Name); err != nil {
return err
}
return nil
}

// ValidateIdentifier Validates that all required fields for an identifier are present.
func ValidateIdentifier(id *core.Identifier, expectedType common.Entity) error {
if id == nil {
Expand All @@ -105,6 +122,9 @@ func ValidateIdentifier(id *core.Identifier, expectedType common.Entity) error {
"unexpected resource type %s for identifier [%+v], expected %s instead",
strings.ToLower(id.ResourceType.String()), id, strings.ToLower(entityToResourceType[expectedType].String()))
}
if id.ResourceType == core.ResourceType_TASK {
return ValidateTaskIdentifierFieldsSet(id)
}
return ValidateIdentifierFieldsSet(id)
}

Expand Down
12 changes: 9 additions & 3 deletions flyteadmin/pkg/repositories/gormimpl/task_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,20 @@ func (r *TaskRepo) Create(ctx context.Context, input models.Task, descriptionEnt
func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Task, error) {
var task models.Task
timer := r.metrics.GetDuration.Start()
tx := r.db.WithContext(ctx).Where(&models.Task{
query := r.db.WithContext(ctx).Where(&models.Task{
TaskKey: models.TaskKey{
Project: input.Project,
Domain: input.Domain,
Name: input.Name,
Version: input.Version,
},
}).Take(&task)
})

if input.Version == "" {
query = query.Order("version DESC").Limit(1)
} else {
query = query.Where("version = ?", input.Version)
}
tx := query.Take(&task)
timer.Stop()
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return models.Task{}, flyteAdminDbErrors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{
Expand Down

0 comments on commit 2d88b0d

Please sign in to comment.