From 0b2bcc52dcd3f4ce3d25418c47af7b265514e3a8 Mon Sep 17 00:00:00 2001 From: Jakub Novak Date: Thu, 29 Aug 2024 16:29:43 +0200 Subject: [PATCH] Improve status code for template access check --- .../api/internal/cache/templates/cache.go | 20 ++++++++++--------- .../api/internal/handlers/sandbox_create.go | 4 ++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/packages/api/internal/cache/templates/cache.go b/packages/api/internal/cache/templates/cache.go index 47a73045..fb2600c3 100644 --- a/packages/api/internal/cache/templates/cache.go +++ b/packages/api/internal/cache/templates/cache.go @@ -3,6 +3,7 @@ package templatecache import ( "context" "fmt" + "net/http" "time" "github.com/google/uuid" @@ -35,14 +36,14 @@ func NewAliasCache() *AliasCache { } } -func (c *AliasCache) Get(alias string) (templateID string, err error) { +func (c *AliasCache) Get(alias string) (templateID string, found bool) { item := c.cache.Get(alias) if item == nil { - return "", fmt.Errorf("alias not found") + return "", false } - return item.Value(), nil + return item.Value(), true } type TemplateCache struct { @@ -63,20 +64,21 @@ func NewTemplateCache(db *db.DB) *TemplateCache { } } -func (c *TemplateCache) Get(ctx context.Context, aliasOrEnvID string, teamID uuid.UUID, public bool) (env *api.Template, build *models.EnvBuild, err error) { +func (c *TemplateCache) Get(ctx context.Context, aliasOrEnvID string, teamID uuid.UUID, public bool) (env *api.Template, build *models.EnvBuild, apiErr *api.APIError) { var envDB *db.Template var item *ttlcache.Item[string, *TemplateInfo] var templateInfo *TemplateInfo + var err error - templateID, err := c.aliasCache.Get(aliasOrEnvID) - if err == nil { + templateID, found := c.aliasCache.Get(aliasOrEnvID) + if found == true { item = c.cache.Get(templateID) } if item == nil { envDB, build, err = c.db.GetEnv(ctx, aliasOrEnvID) if err != nil { - return nil, nil, fmt.Errorf("error when getting template: %w", err) + return nil, nil, &api.APIError{Code: http.StatusInternalServerError, ClientMsg: fmt.Sprintf("error when getting template: %v", err), Err: err} } c.aliasCache.cache.Set(envDB.TemplateID, envDB.TemplateID, templateInfoExpiration) @@ -88,7 +90,7 @@ func (c *TemplateCache) Get(ctx context.Context, aliasOrEnvID string, teamID uui // Check if the team has access to the environment if envDB.TeamID != teamID && (!public || !envDB.Public) { - return nil, nil, fmt.Errorf("team '%s' does not have access to the template '%s'", teamID, aliasOrEnvID) + return nil, nil, &api.APIError{Code: http.StatusForbidden, ClientMsg: fmt.Sprintf("Team '%s' does not have access to the template '%s'", teamID, aliasOrEnvID), Err: fmt.Errorf("team '%s' does not have access to the template '%s'", teamID, aliasOrEnvID)} } templateInfo = &TemplateInfo{template: &api.Template{ @@ -103,7 +105,7 @@ func (c *TemplateCache) Get(ctx context.Context, aliasOrEnvID string, teamID uui templateInfo = item.Value() if templateInfo.teamID != teamID && !templateInfo.template.Public { - return nil, nil, fmt.Errorf("team '%s' does not have access to the template '%s'", teamID, aliasOrEnvID) + return nil, nil, &api.APIError{Code: http.StatusForbidden, ClientMsg: fmt.Sprintf("Team '%s' does not have access to the template '%s'", teamID, aliasOrEnvID), Err: fmt.Errorf("team '%s' does not have access to the template '%s'", teamID, aliasOrEnvID)} } } diff --git a/packages/api/internal/handlers/sandbox_create.go b/packages/api/internal/handlers/sandbox_create.go index 09c77bfc..3eaac306 100644 --- a/packages/api/internal/handlers/sandbox_create.go +++ b/packages/api/internal/handlers/sandbox_create.go @@ -69,10 +69,10 @@ func (a *APIStore) PostSandboxes(c *gin.Context) { // Check if team has access to the environment env, build, checkErr := a.templateCache.Get(ctx, cleanedAliasOrEnvID, team.ID, true) if checkErr != nil { - errMsg := fmt.Errorf("error when checking team access: %w", checkErr) + errMsg := fmt.Errorf("error when checking team access: %s", checkErr.Err) telemetry.ReportCriticalError(ctx, errMsg) - a.sendAPIStoreError(c, http.StatusInternalServerError, fmt.Sprintf("Error when checking team access: %s", checkErr)) + a.sendAPIStoreError(c, checkErr.Code, fmt.Sprintf("Error when checking team access: %s", checkErr.ClientMsg)) return }