diff --git a/api/handler/code.go b/api/handler/code.go index df6bff7f..320c84bb 100644 --- a/api/handler/code.go +++ b/api/handler/code.go @@ -71,6 +71,10 @@ func (h *CodeHandler) Create(ctx *gin.Context) { code, err := h.code.Create(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to create code", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -189,6 +193,10 @@ func (h *CodeHandler) Update(ctx *gin.Context) { code, err := h.code.Update(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to update code", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -226,6 +234,10 @@ func (h *CodeHandler) Delete(ctx *gin.Context) { } err = h.code.Delete(ctx.Request.Context(), namespace, name, currentUser) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to delete code", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -258,8 +270,8 @@ func (h *CodeHandler) Show(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) detail, err := h.code.Show(ctx.Request.Context(), namespace, name, currentUser) if err != nil { - if errors.Is(err, component.ErrUnauthorized) { - httpbase.UnauthorizedError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) return } slog.Error("Failed to get code", slog.Any("error", err)) @@ -293,8 +305,8 @@ func (h *CodeHandler) Relations(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) detail, err := h.code.Relations(ctx.Request.Context(), namespace, name, currentUser) if err != nil { - if errors.Is(err, component.ErrUnauthorized) { - httpbase.UnauthorizedError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) return } slog.Error("Failed to get code relations", slog.Any("error", err)) diff --git a/api/handler/dataset.go b/api/handler/dataset.go index 7c0e84d9..3789bcf1 100644 --- a/api/handler/dataset.go +++ b/api/handler/dataset.go @@ -84,6 +84,10 @@ func (h *DatasetHandler) Create(ctx *gin.Context) { dataset, err := h.dataset.Create(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to create dataset", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -202,6 +206,10 @@ func (h *DatasetHandler) Update(ctx *gin.Context) { dataset, err := h.dataset.Update(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to update dataset", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -239,6 +247,10 @@ func (h *DatasetHandler) Delete(ctx *gin.Context) { } err = h.dataset.Delete(ctx.Request.Context(), namespace, name, currentUser) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to delete dataset", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -271,8 +283,8 @@ func (h *DatasetHandler) Show(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) detail, err := h.dataset.Show(ctx.Request.Context(), namespace, name, currentUser) if err != nil { - if errors.Is(err, component.ErrUnauthorized) { - httpbase.UnauthorizedError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) return } slog.Error("Failed to get dataset", slog.Any("error", err)) @@ -307,8 +319,8 @@ func (h *DatasetHandler) Relations(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) detail, err := h.dataset.Relations(ctx.Request.Context(), namespace, name, currentUser) if err != nil { - if errors.Is(err, component.ErrUnauthorized) { - httpbase.UnauthorizedError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) return } slog.Error("Failed to get dataset relations", slog.Any("error", err)) @@ -357,8 +369,8 @@ func (h *DatasetHandler) AllFiles(ctx *gin.Context) { req.Ref = "" detail, err := h.repo.AllFiles(ctx.Request.Context(), req) if err != nil { - if errors.Is(err, component.ErrUnauthorized) { - httpbase.UnauthorizedError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) return } slog.Error("Failed to get dataset all files", slog.Any("error", err)) diff --git a/api/handler/dataset_test.go b/api/handler/dataset_test.go index 6e9d9c76..f71b171e 100644 --- a/api/handler/dataset_test.go +++ b/api/handler/dataset_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/component" ) type DatasetTester struct { @@ -86,78 +87,167 @@ func TestDatasetHandler_Index(t *testing.T) { } func TestDatasetHandler_Update(t *testing.T) { - tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { - return h.Update + t.Run("forbidden", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Update + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.UpdateDatasetReq{}).Return(true, nil) + tester.mocks.dataset.EXPECT().Update(tester.ctx, &types.UpdateDatasetReq{ + UpdateRepoReq: types.UpdateRepoReq{ + Username: "u", + Namespace: "u-other", + Name: "r", + }, + }).Return(nil, component.ErrForbiddenMsg("user not allowed to update dataset")) + tester.WithParam("namespace", "u-other").WithParam("name", "r"). + WithBody(t, &types.UpdateDatasetReq{ + UpdateRepoReq: types.UpdateRepoReq{Name: "r"}, + }). + WithUser(). + Execute() + + require.Equal(t, 403, tester.response.Code) + }) + + t.Run("normal", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Update + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.UpdateDatasetReq{}).Return(true, nil) + tester.mocks.dataset.EXPECT().Update(tester.ctx, &types.UpdateDatasetReq{ + UpdateRepoReq: types.UpdateRepoReq{ + Username: "u", + Namespace: "u", + Name: "r", + }, + }).Return(&types.Dataset{Name: "foo"}, nil) + tester.WithBody(t, &types.UpdateDatasetReq{ + UpdateRepoReq: types.UpdateRepoReq{Name: "r"}, + }).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Dataset{Name: "foo"}) }) - tester.RequireUser(t) - - tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.UpdateDatasetReq{}).Return(true, nil) - tester.mocks.dataset.EXPECT().Update(tester.ctx, &types.UpdateDatasetReq{ - UpdateRepoReq: types.UpdateRepoReq{ - Username: "u", - Namespace: "u", - Name: "r", - }, - }).Return(&types.Dataset{Name: "foo"}, nil) - tester.WithBody(t, &types.UpdateDatasetReq{ - UpdateRepoReq: types.UpdateRepoReq{Name: "r"}, - }).Execute() - - tester.ResponseEq(t, 200, tester.OKText, &types.Dataset{Name: "foo"}) } func TestDatasetHandler_Delete(t *testing.T) { - tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { - return h.Delete + t.Run("forbidden", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Delete + }) + tester.RequireUser(t) + + tester.mocks.dataset.EXPECT().Delete(tester.ctx, "u-other", "r", "u").Return(component.ErrForbidden) + tester.WithParam("namespace", "u-other").WithParam("name", "r") + tester.WithUser().Execute() + + require.Equal(t, 403, tester.response.Code) }) - tester.RequireUser(t) - tester.mocks.dataset.EXPECT().Delete(tester.ctx, "u", "r", "u").Return(nil) - tester.Execute() + t.Run("normal", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Delete + }) + tester.RequireUser(t) - tester.ResponseEq(t, 200, tester.OKText, nil) + tester.mocks.dataset.EXPECT().Delete(tester.ctx, "u", "r", "u").Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + }) } func TestDatasetHandler_Show(t *testing.T) { - tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { - return h.Show + t.Run("forbidden", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Show + }) + + tester.mocks.dataset.EXPECT().Show(tester.ctx, "u-other", "r", "u").Return(nil, component.ErrForbidden) + tester.WithParam("namespace", "u-other").WithParam("name", "r") + tester.WithUser().Execute() + + require.Equal(t, 403, tester.response.Code) }) - tester.mocks.dataset.EXPECT().Show(tester.ctx, "u", "r", "u").Return(&types.Dataset{ - Name: "d", - }, nil) - tester.WithUser().Execute() + t.Run("normal", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Show + }) + + tester.mocks.dataset.EXPECT().Show(tester.ctx, "u", "r", "u").Return(&types.Dataset{ + Name: "d", + }, nil) + tester.WithUser().Execute() - tester.ResponseEq(t, 200, tester.OKText, &types.Dataset{Name: "d"}) + tester.ResponseEq(t, 200, tester.OKText, &types.Dataset{Name: "d"}) + }) } func TestDatasetHandler_Relations(t *testing.T) { - tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { - return h.Relations + t.Run("forbidden", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Relations + }) + + tester.mocks.dataset.EXPECT().Relations(tester.ctx, "u-other", "r", "u").Return(nil, component.ErrForbidden) + tester.WithParam("namespace", "u-other").WithParam("name", "r") + tester.WithUser().Execute() + + require.Equal(t, 403, tester.response.Code) + }) - tester.mocks.dataset.EXPECT().Relations(tester.ctx, "u", "r", "u").Return(&types.Relations{ - Models: []*types.Model{{Name: "m"}}, - }, nil) - tester.WithUser().Execute() + t.Run("normal", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Relations + }) + + tester.mocks.dataset.EXPECT().Relations(tester.ctx, "u", "r", "u").Return(&types.Relations{ + Models: []*types.Model{{Name: "m"}}, + }, nil) + tester.WithUser().Execute() - tester.ResponseEq(t, 200, tester.OKText, &types.Relations{ - Models: []*types.Model{{Name: "m"}}, + tester.ResponseEq(t, 200, tester.OKText, &types.Relations{ + Models: []*types.Model{{Name: "m"}}, + }) }) } func TestDatasetHandler_AllFiles(t *testing.T) { - tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { - return h.AllFiles + t.Run("forbidden", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.AllFiles + }) + + tester.mocks.repo.EXPECT().AllFiles(tester.ctx, types.GetAllFilesReq{ + Namespace: "u-other", + Name: "r", + RepoType: types.DatasetRepo, + CurrentUser: "u", + }).Return(nil, component.ErrForbidden) + tester.WithParam("namespace", "u-other").WithParam("name", "r") + tester.WithUser().Execute() + + require.Equal(t, 403, tester.response.Code) }) - tester.mocks.repo.EXPECT().AllFiles(tester.ctx, types.GetAllFilesReq{ - Namespace: "u", - Name: "r", - RepoType: types.DatasetRepo, - CurrentUser: "u", - }).Return([]*types.File{{Name: "f"}}, nil) - tester.WithUser().Execute() + t.Run("normal", func(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.AllFiles + }) + + tester.mocks.repo.EXPECT().AllFiles(tester.ctx, types.GetAllFilesReq{ + Namespace: "u", + Name: "r", + RepoType: types.DatasetRepo, + CurrentUser: "u", + }).Return([]*types.File{{Name: "f"}}, nil) + tester.WithUser().Execute() - tester.ResponseEq(t, 200, tester.OKText, []*types.File{{Name: "f"}}) + tester.ResponseEq(t, 200, tester.OKText, []*types.File{{Name: "f"}}) + }) } diff --git a/api/handler/git_http.go b/api/handler/git_http.go index fabc771f..75fc60d1 100644 --- a/api/handler/git_http.go +++ b/api/handler/git_http.go @@ -111,6 +111,10 @@ func (h *GitHTTPHandler) GitUploadPack(ctx *gin.Context) { err := h.gitHttp.GitUploadPack(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } httpbase.ServerError(ctx, err) return } diff --git a/api/handler/git_http_test.go b/api/handler/git_http_test.go index 512f3a3d..8f9ca468 100644 --- a/api/handler/git_http_test.go +++ b/api/handler/git_http_test.go @@ -12,6 +12,7 @@ import ( mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/component" ) type GitHTTPTester struct { @@ -71,27 +72,49 @@ func TestGitHTTPHandler_InfoRefs(t *testing.T) { } func TestGitHTTPHandler_GitUploadPack(t *testing.T) { - tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { - return h.GitUploadPack - }) + t.Run("normal", func(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.GitUploadPack + }) - tester.mocks.gitHttp.EXPECT().GitUploadPack(tester.ctx, types.GitUploadPackReq{ - Namespace: "u", - Name: "r", - RepoType: types.ModelRepo, - GitProtocol: "ssh", - Request: tester.gctx.Request, - Writer: tester.gctx.Writer, - CurrentUser: "u", - }).Return(nil) - tester.SetPath("git").WithQuery("service", "git-upload-pack").WithHeader("Git-Protocol", "ssh") - tester.WithKV("namespace", "u").WithKV("name", "r") - tester.WithKV("repo_type", "model").WithUser().WithHeader("Accept-Encoding", "gzip").Execute() + tester.mocks.gitHttp.EXPECT().GitUploadPack(tester.ctx, types.GitUploadPackReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + GitProtocol: "ssh", + Request: tester.gctx.Request, + Writer: tester.gctx.Writer, + CurrentUser: "u", + }).Return(nil) + tester.SetPath("git").WithQuery("service", "git-upload-pack").WithHeader("Git-Protocol", "ssh") + tester.WithKV("namespace", "u").WithKV("name", "r") + tester.WithKV("repo_type", "model").WithUser().WithHeader("Accept-Encoding", "gzip").Execute() + + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "application/x-git-result", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) + }) - require.Equal(t, 200, tester.response.Code) - headers := tester.response.Header() - require.Equal(t, "application/x-git-result", headers.Get("Content-Type")) - require.Equal(t, "no-cache", headers.Get("Cache-Control")) + t.Run("no permission", func(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.GitUploadPack + }) + tester.mocks.gitHttp.EXPECT().GitUploadPack(tester.ctx, types.GitUploadPackReq{ + Namespace: "u-other", + Name: "r", + RepoType: types.ModelRepo, + GitProtocol: "ssh", + Request: tester.gctx.Request, + Writer: tester.gctx.Writer, + CurrentUser: "u", + }).Return(component.ErrForbidden) + tester.SetPath("git").WithQuery("service", "git-upload-pack").WithHeader("Git-Protocol", "ssh") + tester.WithKV("namespace", "u-other").WithKV("name", "r") + tester.WithKV("repo_type", "model").WithUser().WithHeader("Accept-Encoding", "gzip").Execute() + + require.Equal(t, 403, tester.response.Code) + }) } func TestGitHTTPHandler_GitReceivePack(t *testing.T) { diff --git a/api/handler/model.go b/api/handler/model.go index a69fe1bd..2cd70967 100644 --- a/api/handler/model.go +++ b/api/handler/model.go @@ -143,6 +143,10 @@ func (h *ModelHandler) Create(ctx *gin.Context) { model, err := h.model.Create(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to create model", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -198,6 +202,10 @@ func (h *ModelHandler) Update(ctx *gin.Context) { model, err := h.model.Update(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to update model", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -235,6 +243,10 @@ func (h *ModelHandler) Delete(ctx *gin.Context) { } err = h.model.Delete(ctx.Request.Context(), namespace, name, currentUser) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to delete model", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -268,8 +280,8 @@ func (h *ModelHandler) Show(ctx *gin.Context) { detail, err := h.model.Show(ctx.Request.Context(), namespace, name, currentUser, false) if err != nil { - if errors.Is(err, component.ErrUnauthorized) { - httpbase.UnauthorizedError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) return } slog.Error("Failed to get model detail", slog.Any("error", err)) @@ -296,8 +308,8 @@ func (h *ModelHandler) SDKModelInfo(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) modelInfo, err := h.model.SDKModelInfo(ctx.Request.Context(), namespace, name, ref, currentUser) if err != nil { - if errors.Is(err, component.ErrUnauthorized) { - httpbase.UnauthorizedError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) return } slog.Error("Failed to get sdk model info", slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) @@ -331,8 +343,8 @@ func (h *ModelHandler) Relations(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) detail, err := h.model.Relations(ctx.Request.Context(), namespace, name, currentUser) if err != nil { - if errors.Is(err, component.ErrUnauthorized) { - httpbase.UnauthorizedError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) return } slog.Error("Failed to get model relations", slog.Any("error", err)) @@ -383,6 +395,10 @@ func (h *ModelHandler) SetRelations(ctx *gin.Context) { err = h.model.SetRelationDatasets(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to set datasets for model", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -430,6 +446,10 @@ func (h *ModelHandler) AddDatasetRelation(ctx *gin.Context) { err = h.model.AddRelationDataset(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to add dataset for model", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -477,6 +497,10 @@ func (h *ModelHandler) DelDatasetRelation(ctx *gin.Context) { err = h.model.DelRelationDataset(ctx.Request.Context(), req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("Failed to delete dataset for model", slog.Any("error", err)) httpbase.ServerError(ctx, err) return @@ -579,7 +603,7 @@ func (h *ModelHandler) DeployDedicated(ctx *gin.Context) { if !allow { slog.Info("user not allowed to run model", slog.String("namespace", namespace), slog.String("name", name), slog.Any("username", currentUser)) - httpbase.UnauthorizedError(ctx, errors.New("user not allowed to run model")) + httpbase.ForbiddenError(ctx, errors.New("user not allowed to run model")) return } @@ -666,7 +690,7 @@ func (h *ModelHandler) FinetuneCreate(ctx *gin.Context) { if !allow { slog.Info("user is not allowed to run model", slog.String("namespace", namespace), slog.String("name", name), slog.Any("username", currentUser)) - httpbase.UnauthorizedError(ctx, errors.New("user not allowed to run model")) + httpbase.ForbiddenError(ctx, errors.New("user not allowed to run model")) return } @@ -759,7 +783,12 @@ func (h *ModelHandler) DeployDelete(ctx *gin.Context) { } err = h.repo.DeleteDeploy(ctx.Request.Context(), delReq) if err != nil { - slog.Error("Failed to delete deploy", slog.Any("error", err)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to delete inference", slog.Any("error", err), slog.Any("req", delReq)) + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to delete inference", slog.Any("error", err), slog.Any("req", delReq)) httpbase.ServerError(ctx, err) return } @@ -814,7 +843,12 @@ func (h *ModelHandler) FinetuneDelete(ctx *gin.Context) { } err = h.repo.DeleteDeploy(ctx.Request.Context(), delReq) if err != nil { - slog.Error("Failed to delete deploy", slog.Any("error", err)) + if errors.Is(err, component.ErrForbidden) { + slog.Error("not allowed to delete finetune", slog.Any("error", err), slog.Any("req", delReq)) + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to delete finetune", slog.Any("error", err), slog.Any("req", delReq)) httpbase.ServerError(ctx, err) return } @@ -868,7 +902,12 @@ func (h *ModelHandler) DeployStop(ctx *gin.Context) { } err = h.repo.DeployStop(ctx.Request.Context(), stopReq) if err != nil { - slog.Error("Failed to stop deploy", slog.Any("error", err)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to stop inference", slog.Any("error", err), slog.Any("req", stopReq)) + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to stop inference", slog.Any("error", err), slog.Any("req", stopReq)) httpbase.ServerError(ctx, err) return } @@ -924,7 +963,12 @@ func (h *ModelHandler) DeployStart(ctx *gin.Context) { err = h.repo.DeployStart(ctx.Request.Context(), startReq) if err != nil { - slog.Error("Failed to start deploy", slog.Any("error", err), slog.Any("repoType", types.ModelRepo), slog.String("namespace", namespace), slog.String("name", name), slog.Any("deployID", id)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to start inference", slog.Any("error", err), slog.Any("req", startReq)) + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to start inference", slog.Any("error", err), slog.Any("req", startReq)) httpbase.ServerError(ctx, err) return } @@ -1037,7 +1081,12 @@ func (h *ModelHandler) FinetuneStop(ctx *gin.Context) { } err = h.repo.DeployStop(ctx.Request.Context(), stopReq) if err != nil { - slog.Error("Failed to stop deploy", slog.Any("error", err)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to stop finetune", slog.Any("req", stopReq), slog.Any("error", err)) + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to stop finetune", slog.Any("req", stopReq), slog.Any("error", err)) httpbase.ServerError(ctx, err) return } @@ -1091,7 +1140,12 @@ func (h *ModelHandler) FinetuneStart(ctx *gin.Context) { } err = h.repo.DeployStart(ctx.Request.Context(), startReq) if err != nil { - slog.Error("Failed to start deploy", slog.Any("error", err), slog.Any("repoType", types.ModelRepo), slog.String("namespace", namespace), slog.String("name", name), slog.Any("deployID", id)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to start finetune", slog.Any("error", err), slog.Any("req", startReq)) + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to start finetune", slog.Any("error", err), slog.Any("req", startReq)) httpbase.ServerError(ctx, err) return } @@ -1319,11 +1373,12 @@ func (h *ModelHandler) AllFiles(ctx *gin.Context) { req.CurrentUser = httpbase.GetCurrentUser(ctx) detail, err := h.repo.AllFiles(ctx.Request.Context(), req) if err != nil { - if errors.Is(err, component.ErrUnauthorized) { - httpbase.UnauthorizedError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to get model all files", slog.Any("error", err), slog.Any("req", req)) + httpbase.ForbiddenError(ctx, err) return } - slog.Error("Failed to get model all files", slog.Any("error", err)) + slog.Error("Failed to get model all files", slog.Any("error", err), slog.Any("req", req)) httpbase.ServerError(ctx, err) return } @@ -1386,14 +1441,19 @@ func (h *ModelHandler) DeployServerless(ctx *gin.Context) { req.SecureLevel = 1 // public for serverless deployID, err := h.model.Deploy(ctx.Request.Context(), deployReq, req) if err != nil { - slog.Error("failed to deploy model as serverless", slog.String("namespace", namespace), - slog.String("name", name), slog.Any("currentUser", currentUser), slog.Any("req", req), slog.Any("error", err)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to deploy model as serverless", slog.Any("error", err), slog.Any("deploy_req", deployReq)) + + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("failed to deploy model as serverless", slog.Any("deploy_req", deployReq), slog.Any("run_req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) return } - slog.Debug("deploy model as serverless created", slog.String("namespace", namespace), - slog.String("name", name), slog.Int64("deploy_id", deployID)) + slog.Info("deploy model as serverless created", slog.String("namespace", namespace), + slog.String("name", name), slog.Int64("deploy_id", deployID), slog.String("current_user", currentUser)) // return deploy_id response := types.DeployRepo{DeployID: deployID} @@ -1450,7 +1510,13 @@ func (h *ModelHandler) ServerlessStart(ctx *gin.Context) { err = h.repo.DeployStart(ctx.Request.Context(), startReq) if err != nil { - slog.Error("Failed to start deploy", slog.Any("error", err), slog.Any("repoType", types.ModelRepo), slog.String("namespace", namespace), slog.String("name", name), slog.Any("deployID", id)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to start model serverless deploy", slog.Any("error", err), slog.Any("req", startReq)) + + httpbase.ForbiddenError(ctx, err) + return + } + slog.Info("failed to start model serverless deploy", slog.Any("error", err), slog.Any("req", startReq)) httpbase.ServerError(ctx, err) return } diff --git a/api/handler/repo.go b/api/handler/repo.go index f8960963..390590d6 100644 --- a/api/handler/repo.go +++ b/api/handler/repo.go @@ -1580,13 +1580,14 @@ func (h *RepoHandler) DeployDetail(ctx *gin.Context) { response, err := h.c.DeployDetail(ctx.Request.Context(), detailReq) if err != nil { - slog.Error("fail to deploy detail", slog.String("error", err.Error()), slog.Any("repotype", repoType), slog.Any("namespace", namespace), slog.Any("name", name), slog.Any("deploy id", deployID)) - var pErr *types.PermissionError - if errors.As(err, &pErr) { - httpbase.UnauthorizedError(ctx, err) - } else { - httpbase.ServerError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to get deploy detail", slog.Any("error", err), slog.Any("req", detailReq)) + httpbase.ForbiddenError(ctx, err) + return } + + slog.Error("failed to get deploy detail", slog.Any("error", err), slog.Any("req", detailReq)) + httpbase.ServerError(ctx, err) return } @@ -1656,13 +1657,14 @@ func (h *RepoHandler) DeployInstanceLogs(ctx *gin.Context) { // user http request context instead of gin context, so that server knows the life cycle of the request logReader, err := h.c.DeployInstanceLogs(ctx.Request.Context(), logReq) if err != nil { - var pErr *types.PermissionError - if errors.As(err, &pErr) { - httpbase.UnauthorizedError(ctx, err) - } else { - slog.Error("Failed to get deploy instance logs", slog.Any("logReq", logReq), slog.Any("error", err)) - httpbase.ServerError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to get instance logs", slog.Any("error", err), slog.Any("req", logReq)) + httpbase.ForbiddenError(ctx, err) + return } + + slog.Error("failed to get instance logs", slog.Any("error", err), slog.Any("req", logReq)) + httpbase.ServerError(ctx, err) return } @@ -1775,19 +1777,22 @@ func (h *RepoHandler) DeployStatus(ctx *gin.Context) { allow, err := h.c.AllowAccessDeploy(ctx.Request.Context(), statusReq) if err != nil { - slog.Error("failed to check user permission", "error", err) - var pErr *types.PermissionError - if errors.As(err, &pErr) { - httpbase.UnauthorizedError(ctx, err) - } else { - httpbase.ServerError(ctx, err) + + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to get deploy status", slog.Any("error", err), slog.Any("req", statusReq)) + httpbase.ForbiddenError(ctx, err) + return } + + slog.Error("failed to get deploy status", "error", err, "req", statusReq) + httpbase.ServerError(ctx, err) return } if !allow { - slog.Info("user not allowed to query deploy status", slog.String("namespace", namespace), - slog.String("name", name), slog.Any("username", currentUser), slog.Any("deploy_id", deployID)) + slog.Info("not allowed to query deploy status", "req", statusReq) + httpbase.ForbiddenError(ctx, err) + return } ctx.Writer.Header().Set("Content-Type", "text/event-stream") @@ -1852,6 +1857,12 @@ func (h *RepoHandler) SyncMirror(ctx *gin.Context) { } err = h.c.SyncMirror(ctx.Request.Context(), repoType, namespace, name, currentUser) if err != nil { + if errors.Is(err, component.ErrForbidden) { + slog.Info("not allowed to sync mirror", slog.Any("error", err), slog.String("repo_type", string(repoType)), slog.String("path", fmt.Sprintf("%s/%s", namespace, name))) + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to sync mirror for", slog.String("repo_type", string(repoType)), slog.String("path", fmt.Sprintf("%s/%s", namespace, name)), "error", err) httpbase.ServerError(ctx, err) return @@ -1928,7 +1939,7 @@ func (h *RepoHandler) DeployUpdate(ctx *gin.Context) { if !allow { slog.Info("user not allowed to update deploy", slog.String("namespace", namespace), slog.String("name", name), slog.Any("username", currentUser)) - httpbase.UnauthorizedError(ctx, errors.New("user not allowed to update deploy")) + httpbase.ForbiddenError(ctx, errors.New("user not allowed to update deploy")) return } @@ -1965,6 +1976,12 @@ func (h *RepoHandler) DeployUpdate(ctx *gin.Context) { } err = h.c.DeployUpdate(ctx.Request.Context(), updateReq, req) if err != nil { + if errors.Is(err, component.ErrForbidden) { + slog.Info("user not allowed to update deploy", slog.String("namespace", namespace), + slog.String("name", name), slog.Any("username", currentUser), slog.Int64("deploy_id", deployID)) + httpbase.ForbiddenError(ctx, err) + return + } slog.Error("failed to update deploy", slog.String("namespace", namespace), slog.String("name", name), slog.Any("username", currentUser), slog.Int64("deploy_id", deployID), slog.Any("error", err)) httpbase.ServerError(ctx, fmt.Errorf("failed to update deploy, %w", err)) return @@ -2068,7 +2085,14 @@ func (h *RepoHandler) ServerlessDetail(ctx *gin.Context) { response, err := h.c.DeployDetail(ctx.Request.Context(), detailReq) if err != nil { - slog.Error("fail to serverless detail", slog.String("error", err.Error()), slog.Any("namespace", namespace), slog.Any("name", name), slog.Any("deploy id", deployID)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("user not allowed to get serverless deploy detail", slog.String("namespace", namespace), + slog.String("name", name), slog.Any("username", currentUser), slog.Int64("deploy_id", deployID)) + httpbase.ForbiddenError(ctx, err) + return + } + + slog.Error("fail to get serverless deploy detail", slog.String("error", err.Error()), slog.Any("namespace", namespace), slog.Any("name", name), slog.Any("deploy id", deployID)) httpbase.ServerError(ctx, err) return } @@ -2138,13 +2162,14 @@ func (h *RepoHandler) ServerlessLogs(ctx *gin.Context) { // user http request context instead of gin context, so that server knows the life cycle of the request logReader, err := h.c.DeployInstanceLogs(ctx.Request.Context(), logReq) if err != nil { - var pErr *types.PermissionError - if errors.As(err, &pErr) { - httpbase.UnauthorizedError(ctx, err) - } else { - slog.Error("Failed to get deploy instance logs", slog.Any("logReq", logReq), slog.Any("error", err)) - httpbase.ServerError(ctx, err) + if errors.Is(err, component.ErrForbidden) { + slog.Info("user not allowed to get serverless deploy logs", slog.Any("logReq", logReq), slog.Any("error", err)) + httpbase.ForbiddenError(ctx, err) + return } + + slog.Error("Failed to get serverless deploy logs", slog.Any("logReq", logReq), slog.Any("error", err)) + httpbase.ServerError(ctx, err) return } @@ -2221,14 +2246,21 @@ func (h *RepoHandler) ServerlessStatus(ctx *gin.Context) { allow, err := h.c.AllowAccessDeploy(ctx.Request.Context(), statusReq) if err != nil { - slog.Error("failed to check user permission", slog.Any("error", err)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("user not allowed to get serverless deploy status", slog.Any("error", err), slog.Any("req", statusReq)) + httpbase.ForbiddenError(ctx, err) + return + } + + slog.Error("failed to check user permission", slog.Any("error", err), slog.Any("req", statusReq)) httpbase.ServerError(ctx, fmt.Errorf("failed to check user permission, %w", err)) return } if !allow { - slog.Info("user not allowed to query deploy status", slog.String("namespace", namespace), - slog.String("name", name), slog.Any("username", currentUser), slog.Any("deploy_id", deployID)) + slog.Info("user not allowed to query deploy status", slog.Any("req", statusReq)) + httpbase.ForbiddenError(ctx, errors.New("user not allowed to query serverless deploy status")) + return } ctx.Writer.Header().Set("Content-Type", "text/event-stream") @@ -2326,7 +2358,13 @@ func (h *RepoHandler) ServerlessUpdate(ctx *gin.Context) { } err = h.c.DeployUpdate(ctx.Request.Context(), updateReq, req) if err != nil { - slog.Error("failed to update serverless", slog.String("namespace", namespace), slog.String("name", name), slog.Any("username", currentUser), slog.Int64("deploy_id", deployID), slog.Any("error", err)) + if errors.Is(err, component.ErrForbidden) { + slog.Info("user not allowed to update serverless", slog.Any("error", err), slog.Any("req", updateReq)) + httpbase.ForbiddenError(ctx, err) + return + } + + slog.Error("failed to update serverless", slog.Any("error", err), slog.Any("req", updateReq)) httpbase.ServerError(ctx, fmt.Errorf("failed to update serverless, %w", err)) return } diff --git a/api/handler/repo_test.go b/api/handler/repo_test.go index e8970597..f33aa749 100644 --- a/api/handler/repo_test.go +++ b/api/handler/repo_test.go @@ -959,7 +959,7 @@ func TestRepoHandler_DeployUpdate(t *testing.T) { tester.mocks.repo.EXPECT().AllowAdminAccess(tester.ctx, types.ModelRepo, "u", "r", "u").Return(false, nil) tester.Execute() tester.ResponseEq( - t, 401, "user not allowed to update deploy", nil, + t, 403, "user not allowed to update deploy", nil, ) }) diff --git a/common/types/repo.go b/common/types/repo.go index ba387aac..3e6f700d 100644 --- a/common/types/repo.go +++ b/common/types/repo.go @@ -195,12 +195,3 @@ type ScanReq struct { ArchMap map[string]string Models []string } - -type PermissionError struct { - Message string -} - -// Add the Error() method to PermissionError. -func (e *PermissionError) Error() string { - return e.Message // Return the message field as the error description. -} diff --git a/component/code.go b/component/code.go index 4db155e2..18792ee3 100644 --- a/component/code.go +++ b/component/code.go @@ -295,7 +295,7 @@ func (c *codeComponentImpl) Show(ctx context.Context, namespace, name, currentUs return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrUnauthorized + return nil, ErrForbidden } ns, err := c.repoComponent.GetNameSpaceInfo(ctx, namespace) @@ -364,7 +364,7 @@ func (c *codeComponentImpl) Relations(ctx context.Context, namespace, name, curr allow, _ := c.repoComponent.AllowReadAccessRepo(ctx, code.Repository, currentUser) if !allow { - return nil, ErrUnauthorized + return nil, ErrForbidden } return c.getRelations(ctx, code.RepositoryID, currentUser) diff --git a/component/dataset.go b/component/dataset.go index 875a2329..8c19ce02 100644 --- a/component/dataset.go +++ b/component/dataset.go @@ -417,7 +417,7 @@ func (c *datasetComponentImpl) Show(ctx context.Context, namespace, name, curren return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrUnauthorized + return nil, ErrForbidden } ns, err := c.repoComponent.GetNameSpaceInfo(ctx, namespace) @@ -488,7 +488,7 @@ func (c *datasetComponentImpl) Relations(ctx context.Context, namespace, name, c allow, _ := c.repoComponent.AllowReadAccessRepo(ctx, dataset.Repository, currentUser) if !allow { - return nil, ErrUnauthorized + return nil, ErrForbidden } return c.getRelations(ctx, dataset.RepositoryID, currentUser) diff --git a/component/errors.go b/component/errors.go index faafe384..e44d12a4 100644 --- a/component/errors.go +++ b/component/errors.go @@ -1,12 +1,22 @@ package component -import "errors" +import ( + "errors" + "fmt" +) var ( - ErrUnauthorized = errors.New("unauthorized") - ErrNotFound = errors.New("not found") + // not allowed for anoymous user (need to login first) + ErrUnauthorized = errors.New("unauthorized") + ErrNotFound = errors.New("not found") + // not enough permission for current user ErrForbidden = errors.New("forbidden") ErrUserNotFound = errors.New("user not found, please login first") ErrAlreadyExists = errors.New("the record already exists") ErrPermissionDenied = errors.New("permission denied") ) + +// ErrForbiddenMsg returns a new ErrForbidden with extra message +func ErrForbiddenMsg(msg string) error { + return fmt.Errorf("%s, %w", msg, ErrForbidden) +} diff --git a/component/model.go b/component/model.go index c0e3cd70..d9e7d2d8 100644 --- a/component/model.go +++ b/component/model.go @@ -408,7 +408,7 @@ func (c *modelComponentImpl) Show(ctx context.Context, namespace, name, currentU return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrUnauthorized + return nil, ErrForbidden } ns, err := c.repoComponent.GetNameSpaceInfo(ctx, namespace) @@ -906,7 +906,7 @@ func (c *modelComponentImpl) Deploy(ctx context.Context, deployReq types.DeployA // Check if the user is an admin isAdmin := c.repoComponent.IsAdminRole(user) if !isAdmin { - return -1, fmt.Errorf("need admin permission for Serverless deploy") + return -1, ErrForbiddenMsg("need admin permission for Serverless deploy") } } diff --git a/component/repo.go b/component/repo.go index ce101059..e58f9f1b 100644 --- a/component/repo.go +++ b/component/repo.go @@ -256,11 +256,11 @@ func (c *repoComponentImpl) CreateRepo(ctx context.Context, req types.CreateRepo return nil, nil, err } if !canWrite { - return nil, nil, fmt.Errorf("users do not have permission to create %s in this organization", req.RepoType) + return nil, nil, ErrForbiddenMsg("users do not have permission to create repo in this organization") } } else { if namespace.Path != user.Username { - return nil, nil, fmt.Errorf("users do not have permission to create %s in this namespace", req.RepoType) + return nil, nil, ErrForbiddenMsg("users do not have permission to create repo in this namespace") } } } @@ -332,11 +332,11 @@ func (c *repoComponentImpl) UpdateRepo(ctx context.Context, req types.UpdateRepo return nil, err } if !canWrite { - return nil, errors.New("users do not have permission to update repo in this organization") + return nil, ErrForbiddenMsg("users do not have permission to update repo in this organization") } } else { if namespace.Path != user.Username { - return nil, errors.New("users do not have permission to update repo in this namespace") + return nil, ErrForbiddenMsg("users do not have permission to update repo in this namespace") } } } @@ -397,11 +397,11 @@ func (c *repoComponentImpl) DeleteRepo(ctx context.Context, req types.DeleteRepo return nil, err } if !canWrite { - return nil, errors.New("users do not have permission to delete repo in this organization") + return nil, ErrForbiddenMsg("users do not have permission to delete repo in this organization") } } else { if namespace.Path != user.Username { - return nil, errors.New("users do not have permission to delete repo in this namespace") + return nil, ErrForbiddenMsg("users do not have permission to delete repo in this namespace") } } @@ -663,7 +663,7 @@ func (c *repoComponentImpl) UpdateFile(ctx context.Context, req *types.UpdateFil return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanWrite { - return nil, ErrUnauthorized + return nil, ErrForbiddenMsg("users do not have permission to update file in this repo") } user, err = c.userStore.FindByUsername(ctx, req.Username) @@ -749,7 +749,7 @@ func (c *repoComponentImpl) DeleteFile(ctx context.Context, req *types.DeleteFil return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanWrite { - return nil, ErrUnauthorized + return nil, ErrForbiddenMsg("users do not have permission to delete file in this repo") } user, err = c.userStore.FindByUsername(ctx, req.Username) @@ -889,7 +889,7 @@ func (c *repoComponentImpl) LastCommit(ctx context.Context, req *types.GetCommit return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrForbidden + return nil, ErrForbiddenMsg("users do not have permission to get last commit in this repo") } if req.Ref == "" { @@ -919,7 +919,7 @@ func (c *repoComponentImpl) FileRaw(ctx context.Context, req *types.GetFileReq) return "", fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return "", ErrUnauthorized + return "", ErrForbiddenMsg("users do not have permission to get file raw in this repo") } if repo.Source != types.LocalSource && strings.ToLower(req.Path) == "readme.md" { @@ -963,7 +963,7 @@ func (c *repoComponentImpl) DownloadFile(ctx context.Context, req *types.GetFile return nil, 0, "", fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, 0, "", ErrUnauthorized + return nil, 0, "", ErrForbiddenMsg("users do not have permission to download file in this repo") } err = c.repoStore.UpdateRepoFileDownloads(ctx, repo, time.Now(), 1) @@ -1013,7 +1013,7 @@ func (c *repoComponentImpl) Branches(ctx context.Context, req *types.GetBranches return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrUnauthorized + return nil, ErrForbiddenMsg("users do not have permission to get branches in this repo") } getBranchesReq := gitserver.GetBranchesReq{ @@ -1044,7 +1044,7 @@ func (c *repoComponentImpl) Tags(ctx context.Context, req *types.GetTagsReq) ([] return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrUnauthorized + return nil, ErrForbiddenMsg("users do not have permission to get tags in this repo") } tags, err := c.repoStore.Tags(ctx, repo.ID) @@ -1065,7 +1065,7 @@ func (c *repoComponentImpl) UpdateTags(ctx context.Context, namespace, name stri return fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanWrite { - return ErrUnauthorized + return ErrForbiddenMsg("users do not have permission to update tags in this repo") } tagScope := getTagScopeByRepoType(repoType) @@ -1089,7 +1089,7 @@ func (c *repoComponentImpl) Tree(ctx context.Context, req *types.GetFileReq) ([] return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrForbidden + return nil, ErrForbiddenMsg("users do not have permission to get tree in this repo") } if repo.Source != types.LocalSource { @@ -1154,7 +1154,7 @@ func (c *repoComponentImpl) TreeV2(ctx context.Context, req *types.GetTreeReques return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrForbidden + return nil, ErrForbiddenMsg("users do not have permission to get tree in this repo") } if req.Limit == 0 { @@ -1241,7 +1241,7 @@ func (c *repoComponentImpl) LogsTree(ctx context.Context, req *types.GetLogsTree return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrForbidden + return nil, ErrForbiddenMsg("users do not have permission to get logs tree in this repo") } if req.Limit == 0 { @@ -1338,10 +1338,10 @@ func (c *repoComponentImpl) SDKListFiles(ctx context.Context, repoType types.Rep canRead, err := c.AllowReadAccessRepo(ctx, repo, userName) if err != nil { - return nil, ErrUnauthorized + return nil, err } if !canRead { - return nil, ErrUnauthorized + return nil, ErrForbiddenMsg("users do not have permission to access this repo") } if ref == "" { @@ -1397,7 +1397,7 @@ func (c *repoComponentImpl) HeadDownloadFile(ctx context.Context, req *types.Get return nil, err } if !canRead { - return nil, ErrUnauthorized + return nil, ErrForbiddenMsg("users do not have permission to download file in this repo") } if req.Ref == "" { req.Ref = repo.DefaultBranch @@ -1432,7 +1432,7 @@ func (c *repoComponentImpl) SDKDownloadFile(ctx context.Context, req *types.GetF return nil, 0, "", err } if !canRead { - return nil, 0, "", ErrUnauthorized + return nil, 0, "", ErrForbiddenMsg("users do not have permission to download file in this repo") } if req.Ref == "" { req.Ref = repo.DefaultBranch @@ -1522,7 +1522,7 @@ func (c *repoComponentImpl) FileInfo(ctx context.Context, req *types.GetFileReq) return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrUnauthorized + return nil, ErrForbiddenMsg("users do not have permission to get file info in this repo") } if req.Ref == "" { @@ -1718,7 +1718,7 @@ func (c *repoComponentImpl) GetCommitWithDiff(ctx context.Context, req *types.Ge return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } if !permission.CanRead { - return nil, ErrUnauthorized + return nil, ErrForbiddenMsg("users do not have permission to get commit in this repo") } getCommitReq := gitserver.GetRepoLastCommitReq{ Namespace: req.Namespace, // user name or org name @@ -2440,10 +2440,10 @@ func (c *repoComponentImpl) AllowAccessDeploy(ctx context.Context, req types.Dep } deploy, err := c.deployTaskStore.GetDeployByID(ctx, req.DeployID) if err != nil { - return false, err + return false, fmt.Errorf("fail to get deploy by ID: %v, %w", req.DeployID, err) } if deploy == nil { - return false, fmt.Errorf("fail to get deploy by ID: %v", req.DeployID) + return false, fmt.Errorf("deploy not found by ID: %v", req.DeployID) } if req.DeployType == types.ServerlessType { return c.checkAccessDeployForServerless(ctx, repo.ID, req.CurrentUser, deploy) @@ -2460,7 +2460,7 @@ func (c *repoComponentImpl) checkAccessDeployForUser(ctx context.Context, repoID } if deploy.UserID != user.ID { // deny access due to deploy was not created by - return false, &types.PermissionError{Message: "deploy was not created by user"} + return false, ErrForbiddenMsg("deploy was not created by user") } if deploy.RepoID != repoID { // deny access for invalid repo @@ -2476,7 +2476,7 @@ func (c *repoComponentImpl) checkAccessDeployForServerless(ctx context.Context, } isAdmin := c.IsAdminRole(user) if !isAdmin { - return false, errors.New("need admin permission to see Serverless deploy instances") + return false, ErrForbiddenMsg("need admin permission to see Serverless deploy instances") } if deploy.RepoID != repoID { // deny access for invalid repo @@ -2600,7 +2600,7 @@ func (c *repoComponentImpl) SyncMirror(ctx context.Context, repoType types.Repos } if !admin { - return fmt.Errorf("users do not have permission to delete mirror for this repo") + return ErrForbiddenMsg("need be owner or admin role to sync mirror for this repo") } repo, err := c.repoStore.FindByPath(ctx, repoType, namespace, name) if err != nil { @@ -2638,7 +2638,7 @@ func (c *repoComponentImpl) SyncMirror(ctx context.Context, repoType types.Repos func (c *repoComponentImpl) checkDeployPermissionForUser(ctx context.Context, deployReq types.DeployActReq) (*database.User, *database.Deploy, error) { user, err := c.userStore.FindByUsername(ctx, deployReq.CurrentUser) if err != nil { - return nil, nil, &types.PermissionError{Message: "user does not exist"} + return nil, nil, fmt.Errorf("deploy permission check user failed, %w", err) } deploy, err := c.deployTaskStore.GetDeployByID(ctx, deployReq.DeployID) if err != nil { @@ -2648,7 +2648,7 @@ func (c *repoComponentImpl) checkDeployPermissionForUser(ctx context.Context, de return nil, nil, fmt.Errorf("do not found user deploy %v", deployReq.DeployID) } if deploy.UserID != user.ID { - return nil, nil, &types.PermissionError{Message: "deploy was not created by user"} + return nil, nil, ErrForbiddenMsg("deploy was not created by user") } return &user, deploy, nil } @@ -2656,11 +2656,11 @@ func (c *repoComponentImpl) checkDeployPermissionForUser(ctx context.Context, de func (c *repoComponentImpl) checkDeployPermissionForServerless(ctx context.Context, deployReq types.DeployActReq) (*database.User, *database.Deploy, error) { user, err := c.userStore.FindByUsername(ctx, deployReq.CurrentUser) if err != nil { - return nil, nil, fmt.Errorf("user does not exist, %w", err) + return nil, nil, fmt.Errorf("deploy permission check user failed, %w", err) } isAdmin := c.IsAdminRole(user) if !isAdmin { - return nil, nil, fmt.Errorf("need admin permission for Serverless deploy") + return nil, nil, ErrForbiddenMsg("need admin permission for Serverless deploy") } deploy, err := c.deployTaskStore.GetDeployByID(ctx, deployReq.DeployID) if err != nil { @@ -2785,7 +2785,7 @@ func (c *repoComponentImpl) AllFiles(ctx context.Context, req types.GetAllFilesR } if !read { - return nil, fmt.Errorf("users do not have permission to get all files for this repo") + return nil, ErrForbiddenMsg("users do not have permission to get all files for this repo") } } allFiles, err := getAllFiles(req.Namespace, req.Name, "", req.RepoType, req.Ref, c.git.GetRepoFileTree) diff --git a/component/repo_test.go b/component/repo_test.go index c6087090..1974a8e0 100644 --- a/component/repo_test.go +++ b/component/repo_test.go @@ -587,7 +587,7 @@ func TestRepoComponent_FileRaw(t *testing.T) { Path: c.path, CurrentUser: currentUser, }) - require.Equal(t, ErrUnauthorized, err) + require.True(t, errors.Is(err, ErrForbidden)) return } @@ -1612,7 +1612,7 @@ func TestRepoComponent_LastCommit(t *testing.T) { actualCommit, err := repoComp.LastCommit(context.Background(), &types.GetCommitsReq{}) require.Nil(t, actualCommit) - require.Equal(t, err, ErrForbidden) + require.ErrorIs(t, err, ErrForbidden) }) } @@ -1666,7 +1666,7 @@ func TestRepoComponent_Tree(t *testing.T) { actualTree, err := repoComp.Tree(context.Background(), &types.GetFileReq{}) require.Nil(t, actualTree) - require.Equal(t, err, ErrForbidden) + require.ErrorIs(t, err, ErrForbidden) }) } @@ -1866,7 +1866,7 @@ func TestRepoComponent_TreeV2(t *testing.T) { actualTree, err := repoComp.TreeV2(context.Background(), &types.GetTreeRequest{}) require.Nil(t, actualTree) - require.Equal(t, err, ErrForbidden) + require.ErrorIs(t, err, ErrForbidden) }) } @@ -1979,7 +1979,7 @@ func TestRepoComponent_LogsTree(t *testing.T) { actualTree, err := repoComp.LogsTree(context.Background(), &types.GetLogsTreeRequest{}) require.Nil(t, actualTree) - require.Equal(t, err, ErrForbidden) + require.ErrorIs(t, err, ErrForbidden) }) }