Skip to content

Commit

Permalink
Refactor runtime framework router (#249)
Browse files Browse the repository at this point in the history
* use admin role to operate runtime framework

* refactor runtime framework router

---------

Co-authored-by: James <[email protected]>
  • Loading branch information
ganisback and James authored Jan 23, 2025
1 parent 9a83786 commit 358b126
Show file tree
Hide file tree
Showing 16 changed files with 31,603 additions and 15,384 deletions.
117 changes: 89 additions & 28 deletions _mocks/opencsg.com/csghub-server/component/mock_ModelComponent.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 11 additions & 10 deletions _mocks/opencsg.com/csghub-server/component/mock_RepoComponent.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 26 additions & 4 deletions api/handler/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -1196,8 +1196,9 @@ func (h *ModelHandler) ListAllRuntimeFramework(ctx *gin.Context) {
// @Param body body types.RuntimeFrameworkModels true "body"
// @Success 200 {object} types.Response{} "OK"
// @Failure 400 {object} types.APIBadRequest "Bad request"
// @Failure 403 {object} types.APIForbidden "Forbidden"
// @Failure 500 {object} types.APIInternalServerError "Internal server error"
// @Router /runtime_framework/{id} [post]
// @Router /runtime_framework/{id}/models [put]
func (h *ModelHandler) UpdateModelRuntimeFrameworks(ctx *gin.Context) {
var req types.RuntimeFrameworkModels
if err := ctx.ShouldBindJSON(&req); err != nil {
Expand All @@ -1206,6 +1207,12 @@ func (h *ModelHandler) UpdateModelRuntimeFrameworks(ctx *gin.Context) {
return
}

currentUser := httpbase.GetCurrentUser(ctx)
if currentUser == "" {
httpbase.UnauthorizedError(ctx, component.ErrUserNotFound)
return
}

id, err := strconv.ParseInt(ctx.Param("id"), 10, 64)
if err != nil {
slog.Error("Bad request format", "error", err)
Expand All @@ -1227,8 +1234,12 @@ func (h *ModelHandler) UpdateModelRuntimeFrameworks(ctx *gin.Context) {

slog.Info("update runtime frameworks models", slog.Any("req", req), slog.Any("runtime framework id", id), slog.Any("deployType", deployType))

list, err := h.model.SetRuntimeFrameworkModes(ctx.Request.Context(), deployType, id, req.Models)
list, err := h.model.SetRuntimeFrameworkModes(ctx.Request.Context(), currentUser, deployType, id, req.Models)
if err != nil {
if errors.Is(err, component.ErrForbidden) {
httpbase.ForbiddenError(ctx, err)
return
}
slog.Error("Failed to set models runtime framework", slog.Any("error", err))
httpbase.ServerError(ctx, err)
return
Expand All @@ -1249,8 +1260,9 @@ func (h *ModelHandler) UpdateModelRuntimeFrameworks(ctx *gin.Context) {
// @Param body body types.RuntimeFrameworkModels true "body"
// @Success 200 {object} types.Response{} "OK"
// @Failure 400 {object} types.APIBadRequest "Bad request"
// @Failure 403 {object} types.APIForbidden "Forbidden"
// @Failure 500 {object} types.APIInternalServerError "Internal server error"
// @Router /runtime_framework/{id} [delete]
// @Router /runtime_framework/{id}/models [delete]
func (h *ModelHandler) DeleteModelRuntimeFrameworks(ctx *gin.Context) {
var req types.RuntimeFrameworkModels
if err := ctx.ShouldBindJSON(&req); err != nil {
Expand All @@ -1259,6 +1271,12 @@ func (h *ModelHandler) DeleteModelRuntimeFrameworks(ctx *gin.Context) {
return
}

currentUser := httpbase.GetCurrentUser(ctx)
if currentUser == "" {
httpbase.UnauthorizedError(ctx, component.ErrUserNotFound)
return
}

id, err := strconv.ParseInt(ctx.Param("id"), 10, 64)
if err != nil {
slog.Error("Bad request format", "error", err)
Expand All @@ -1280,8 +1298,12 @@ func (h *ModelHandler) DeleteModelRuntimeFrameworks(ctx *gin.Context) {

slog.Info("update runtime frameworks models", slog.Any("req", req), slog.Any("runtime framework id", id), slog.Any("deployType", deployType))

list, err := h.model.DeleteRuntimeFrameworkModes(ctx.Request.Context(), deployType, id, req.Models)
list, err := h.model.DeleteRuntimeFrameworkModes(ctx.Request.Context(), currentUser, deployType, id, req.Models)
if err != nil {
if errors.Is(err, component.ErrForbidden) {
httpbase.ForbiddenError(ctx, err)
return
}
slog.Error("Failed to set models runtime framework", slog.Any("error", err))
httpbase.ServerError(ctx, err)
return
Expand Down
4 changes: 2 additions & 2 deletions api/handler/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ func TestModelHandler_UpdateModelRuntimeFramework(t *testing.T) {

tester.WithUser().WithQuery("deploy_type", "").AddPagination(1, 10).WithParam("id", "1")
tester.mocks.model.EXPECT().SetRuntimeFrameworkModes(
tester.ctx, types.InferenceType, int64(1), []string{"foo"},
tester.ctx, "u", types.InferenceType, int64(1), []string{"foo"},
).Return([]string{"bar"}, nil)
tester.WithBody(t, types.RuntimeFrameworkModels{
Models: []string{"foo"},
Expand All @@ -447,7 +447,7 @@ func TestModelHandler_DeleteModelRuntimeFramework(t *testing.T) {

tester.WithUser().WithQuery("deploy_type", "").AddPagination(1, 10).WithParam("id", "1")
tester.mocks.model.EXPECT().DeleteRuntimeFrameworkModes(
tester.ctx, types.InferenceType, int64(1), []string{"foo"},
tester.ctx, "u", types.InferenceType, int64(1), []string{"foo"},
).Return([]string{"bar"}, nil)
tester.WithBody(t, types.RuntimeFrameworkModels{
Models: []string{"foo"},
Expand Down
Loading

0 comments on commit 358b126

Please sign in to comment.