Skip to content

Commit

Permalink
[UT] add runtime framework/arch and userlike UT code (#192)
Browse files Browse the repository at this point in the history
Co-authored-by: Haihui.Wang <[email protected]>
  • Loading branch information
SeanHH86 and wanghh2000 authored Nov 26, 2024
1 parent 41e810e commit c51fbd6
Show file tree
Hide file tree
Showing 6 changed files with 434 additions and 0 deletions.
6 changes: 6 additions & 0 deletions builder/store/database/runtime_architecture.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ func NewRuntimeArchitecturesStore() RuntimeArchitecturesStore {
}
}

func NewRuntimeArchitecturesStoreWithDB(db *DB) RuntimeArchitecturesStore {
return &runtimeArchitecturesStoreImpl{
db: db,
}
}

type RuntimeArchitecture struct {
ID int64 `bun:",pk,autoincrement" json:"id"`
RuntimeFrameworkID int64 `bun:",notnull" json:"runtime_framework_id"`
Expand Down
102 changes: 102 additions & 0 deletions builder/store/database/runtime_architecture_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package database_test

import (
"context"
"testing"

"github.com/stretchr/testify/require"
"opencsg.com/csghub-server/builder/store/database"
"opencsg.com/csghub-server/common/tests"
)

func TestRuntimeArchitecturesStore_AddAndListByRuntimeFrameworkID(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

raStore := database.NewRuntimeArchitecturesStoreWithDB(db)

err := raStore.Add(ctx, database.RuntimeArchitecture{
RuntimeFrameworkID: 1,
ArchitectureName: "Qwen2ForCausalLM",
})
require.Nil(t, err)

res, err := raStore.ListByRuntimeFrameworkID(ctx, 1)
require.Nil(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "Qwen2ForCausalLM", res[0].ArchitectureName)
}

func TestRuntimeArchitecturesStore_DeleteByRuntimeIDAndArchName(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

raStore := database.NewRuntimeArchitecturesStoreWithDB(db)
err := raStore.Add(ctx, database.RuntimeArchitecture{
RuntimeFrameworkID: 1,
ArchitectureName: "Qwen2ForCausalLM",
})
require.Nil(t, err)

err = raStore.DeleteByRuntimeIDAndArchName(ctx, 1, "Qwen2ForCausalLM")
require.Nil(t, err)

arch, err := raStore.FindByRuntimeIDAndArchName(ctx, 1, "Qwen2ForCausalLM")
require.Equal(t, nil, err)
require.Nil(t, nil, arch)
}

func TestRuntimeArchitecturesStore_FindByRuntimeIDAndArchName(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

raStore := database.NewRuntimeArchitecturesStoreWithDB(db)
err := raStore.Add(ctx, database.RuntimeArchitecture{
RuntimeFrameworkID: 1,
ArchitectureName: "Qwen2ForCausalLM",
})
require.Nil(t, err)

res, err := raStore.FindByRuntimeIDAndArchName(ctx, 1, "Qwen2ForCausalLM")
require.Nil(t, err)
require.Equal(t, "Qwen2ForCausalLM", res.ArchitectureName)
require.Equal(t, int64(1), res.RuntimeFrameworkID)
}

func TestRuntimeArchitecturesStore_ListByRArchName(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

raStore := database.NewRuntimeArchitecturesStoreWithDB(db)
err := raStore.Add(ctx, database.RuntimeArchitecture{
RuntimeFrameworkID: 1,
ArchitectureName: "Qwen2ForCausalLM",
})
require.Nil(t, err)

res, err := raStore.ListByRArchName(ctx, "Qwen2ForCausalLM")
require.Nil(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "Qwen2ForCausalLM", res[0].ArchitectureName)
}

func TestRuntimeArchitecturesStore_ListByRArchNameAndModel(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

raStore := database.NewRuntimeArchitecturesStoreWithDB(db)
err := raStore.Add(ctx, database.RuntimeArchitecture{
RuntimeFrameworkID: 1,
ArchitectureName: "Qwen2ForCausalLM",
})
require.Nil(t, err)

res, err := raStore.ListByRArchNameAndModel(ctx, "Qwen2ForCausalLM", "model1")
require.Nil(t, err)
require.Equal(t, 1, len(res))
}
6 changes: 6 additions & 0 deletions builder/store/database/runtime_framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ func NewRuntimeFrameworksStore() RuntimeFrameworksStore {
}
}

func NewRuntimeFrameworksStoreWithDB(db *DB) RuntimeFrameworksStore {
return &runtimeFrameworksStoreImpl{
db: db,
}
}

type RuntimeFramework struct {
ID int64 `bun:",pk,autoincrement" json:"id"`
FrameName string `bun:",notnull" json:"frame_name"`
Expand Down
145 changes: 145 additions & 0 deletions builder/store/database/runtime_framework_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package database_test

import (
"context"
"database/sql"
"testing"

"github.com/stretchr/testify/require"
"opencsg.com/csghub-server/builder/store/database"
"opencsg.com/csghub-server/common/tests"
)

func TestRuntimeFrameworksStore_AddAndUpdateAndList(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

rfStore := database.NewRuntimeFrameworksStoreWithDB(db)

err := rfStore.Add(ctx, database.RuntimeFramework{
ID: 1,
FrameName: "vllm",
Type: 1,
})
require.Nil(t, err)

rf, err := rfStore.Update(ctx, database.RuntimeFramework{
ID: 1,
FrameName: "vllm",
Type: 2,
})
require.Nil(t, err)
require.Equal(t, 2, rf.Type)

res, err := rfStore.List(ctx, 2)
require.Nil(t, err)
require.Equal(t, 1, len(res))
}

func TestRuntimeFrameworksStore_FindByIDAndDelete(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

rfStore := database.NewRuntimeFrameworksStoreWithDB(db)

err := rfStore.Add(ctx, database.RuntimeFramework{
ID: 1,
FrameName: "vllm",
Type: 1,
})
require.Nil(t, err)

res, err := rfStore.FindByID(ctx, 1)
require.Nil(t, err)
require.Equal(t, "vllm", res.FrameName)

err = rfStore.Delete(ctx, database.RuntimeFramework{
ID: 1,
FrameName: "vllm",
Type: 1,
})
require.Nil(t, err)

_, err = rfStore.FindByID(ctx, 1)
require.Equal(t, sql.ErrNoRows, err)
}

func TestRuntimeFrameworksStore_FindEnabledByID(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

rfStore := database.NewRuntimeFrameworksStoreWithDB(db)
err := rfStore.Add(ctx, database.RuntimeFramework{
ID: 1,
FrameName: "vllm",
Type: 1,
Enabled: 1,
})
require.Nil(t, err)

res, err := rfStore.FindEnabledByID(ctx, 1)
require.Nil(t, err)
require.Equal(t, "vllm", res.FrameName)
}

func TestRuntimeFrameworksStore_FindEnabledByName(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

rfStore := database.NewRuntimeFrameworksStoreWithDB(db)
err := rfStore.Add(ctx, database.RuntimeFramework{
ID: 1,
FrameName: "vllm",
Type: 1,
Enabled: 1,
})
require.Nil(t, err)

res, err := rfStore.FindEnabledByName(ctx, "vllm")
require.Nil(t, err)
require.Equal(t, "vllm", res.FrameName)
}

func TestRuntimeFrameworksStore_ListAll(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

rfStore := database.NewRuntimeFrameworksStoreWithDB(db)
err := rfStore.Add(ctx, database.RuntimeFramework{
ID: 1,
FrameName: "vllm",
Type: 1,
Enabled: 1,
})
require.Nil(t, err)

res, err := rfStore.ListAll(ctx)
require.Nil(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "vllm", res[0].FrameName)
}

func TestRuntimeFrameworksStore_ListByIDs(t *testing.T) {
db := tests.InitTestDB()
defer db.Close()
ctx := context.TODO()

rfStore := database.NewRuntimeFrameworksStoreWithDB(db)
err := rfStore.Add(ctx, database.RuntimeFramework{
ID: 1,
FrameName: "vllm",
Type: 1,
Enabled: 1,
})
require.Nil(t, err)

res, err := rfStore.ListByIDs(ctx, []int64{1})
require.Nil(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "vllm", res[0].FrameName)
}
6 changes: 6 additions & 0 deletions builder/store/database/user_like.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ func NewUserLikesStore() UserLikesStore {
}
}

func NewUserLikesStoreWithDB(db *DB) UserLikesStore {
return &userLikesStoreImpl{
db: db,
}
}

type UserLike struct {
ID int64 `bun:",pk,autoincrement" json:"id"`
UserID int64 `bun:",notnull" json:"user_id"`
Expand Down
Loading

0 comments on commit c51fbd6

Please sign in to comment.