From 2d88b0df4b3fec469ccf220a2a214032c435cbce Mon Sep 17 00:00:00 2001 From: featherchen Date: Tue, 22 Oct 2024 14:14:27 -0700 Subject: [PATCH] feat(task_repo): set default version to the latest Signed-off-by: featherchen --- .../pkg/manager/impl/validation/validation.go | 20 +++++++++++++++++++ .../pkg/repositories/gormimpl/task_repo.go | 12 ++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/validation/validation.go b/flyteadmin/pkg/manager/impl/validation/validation.go index de2927495c7..2ff5859b44e 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation.go +++ b/flyteadmin/pkg/manager/impl/validation/validation.go @@ -95,6 +95,23 @@ func ValidateIdentifierFieldsSet(id *core.Identifier) error { return nil } +// ValidateTaskIdentifierFieldsSet Validates that all required fields, except version, for a task identifier are present. +func ValidateTaskIdentifierFieldsSet(id *core.Identifier) error { + if id == nil { + return shared.GetMissingArgumentError(shared.ID) + } + if err := ValidateEmptyStringField(id.Project, shared.Project); err != nil { + return err + } + if err := ValidateEmptyStringField(id.Domain, shared.Domain); err != nil { + return err + } + if err := ValidateEmptyStringField(id.Name, shared.Name); err != nil { + return err + } + return nil +} + // ValidateIdentifier Validates that all required fields for an identifier are present. func ValidateIdentifier(id *core.Identifier, expectedType common.Entity) error { if id == nil { @@ -105,6 +122,9 @@ func ValidateIdentifier(id *core.Identifier, expectedType common.Entity) error { "unexpected resource type %s for identifier [%+v], expected %s instead", strings.ToLower(id.ResourceType.String()), id, strings.ToLower(entityToResourceType[expectedType].String())) } + if id.ResourceType == core.ResourceType_TASK { + return ValidateTaskIdentifierFieldsSet(id) + } return ValidateIdentifierFieldsSet(id) } diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index 1b42756b7a6..784bb306b78 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -49,14 +49,20 @@ func (r *TaskRepo) Create(ctx context.Context, input models.Task, descriptionEnt func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Task, error) { var task models.Task timer := r.metrics.GetDuration.Start() - tx := r.db.WithContext(ctx).Where(&models.Task{ + query := r.db.WithContext(ctx).Where(&models.Task{ TaskKey: models.TaskKey{ Project: input.Project, Domain: input.Domain, Name: input.Name, - Version: input.Version, }, - }).Take(&task) + }) + + if input.Version == "" { + query = query.Order("version DESC").Limit(1) + } else { + query = query.Where("version = ?", input.Version) + } + tx := query.Take(&task) timer.Stop() if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.Task{}, flyteAdminDbErrors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{