Skip to content

Commit

Permalink
Refactored TestSetExplainer and TestSetLlm to avoid code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertSamoilescu committed Jan 27, 2025
1 parent 1df90d1 commit 5ce0906
Showing 1 changed file with 57 additions and 59 deletions.
116 changes: 57 additions & 59 deletions scheduler/pkg/agent/repository/mlserver/mlserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,49 @@ import (
"github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"
)

func TestSetExplainer(t *testing.T) {
func TestSetModelSettings(t *testing.T) {
g := NewGomegaWithT(t)

envoyHost := "0.0.0.0"
envoyPort := 9000

type test struct {
name string
data []byte
explainerSpec *scheduler.ExplainerSpec
expected *ModelSettings
name string
data []byte
spec interface{}
handler func(*MLServerRepositoryHandler, string, interface{}, string, int) error
assertion func(*GomegaWithT, *ModelSettings)
}

getStrPr := func(str string) *string { return &str }

// define assertion functions
validateExplainer := func(g *GomegaWithT, expected *ModelSettings) func(*GomegaWithT, *ModelSettings) {
return func(g *GomegaWithT, settings *ModelSettings) {
g.Expect(settings.Parameters.Extra[explainerTypeKey]).To(Equal(expected.Parameters.Extra[explainerTypeKey]))
g.Expect(settings.Parameters.Extra[inferUriKey]).To(Equal(expected.Parameters.Extra[inferUriKey]))

}
}

validateLlm := func(g *GomegaWithT, expected *ModelSettings) func(*GomegaWithT, *ModelSettings) {
return func(g *GomegaWithT, settings *ModelSettings) {
g.Expect(settings.Parameters.Extra[inferUriKey]).To(Equal(expected.Parameters.Extra[inferUriKey]))
g.Expect(settings.Parameters.Extra["prompt_utils"]).To(Equal(expected.Parameters.Extra["prompt_utils"]))
}
}

tests := []test{
{
name: "basic",
data: []byte(`{"name": "iris","implementation": "mlserver_sklearn.SKLearnModel",
"parameters": {"version": "1", "extra":{}}}`),
explainerSpec: &scheduler.ExplainerSpec{
name: "set explainer - basic",
data: []byte(`{"name": "iris","implementation": "mlserver_sklearn.SKLearnModel", "parameters": {"version": "1", "extra":{}}}`),
spec: &scheduler.ExplainerSpec{
Type: "anchor_tabular",
ModelRef: getStrPr("mymodel"),
},
expected: &ModelSettings{
handler: func(m *MLServerRepositoryHandler, path string, spec interface{}, host string, port int) error {
return m.SetExplainer(path, spec.(*scheduler.ExplainerSpec), host, port)
},
assertion: validateExplainer(g, &ModelSettings{
Name: "iris",
Implementation: "mlserver_sklearn.SKLearnModel",
Parameters: &ModelParameters{
Expand All @@ -55,17 +75,19 @@ func TestSetExplainer(t *testing.T) {
inferUriKey: "http://0.0.0.0:9000/v2/models/mymodel/infer",
},
},
},
}),
},
{
name: "explainer parameters",
data: []byte(`{"name": "iris","implementation": "mlserver_sklearn.SKLearnModel",
"parameters": {"version": "1", "extra":{"init_parameters":{"threshold":0.95}}}}`),
explainerSpec: &scheduler.ExplainerSpec{
name: "set explainer - parameters",
data: []byte(`{"name": "iris","implementation": "mlserver_sklearn.SKLearnModel", "parameters": {"version": "1", "extra":{"init_parameters":{"threshold":0.95}}}}`),
spec: &scheduler.ExplainerSpec{
Type: "anchor_tabular",
ModelRef: getStrPr("mymodel"),
},
expected: &ModelSettings{
handler: func(m *MLServerRepositoryHandler, path string, spec interface{}, host string, port int) error {
return m.SetExplainer(path, spec.(*scheduler.ExplainerSpec), host, port)
},
assertion: validateExplainer(g, &ModelSettings{
Name: "iris",
Implementation: "mlserver_sklearn.SKLearnModel",
Parameters: &ModelParameters{
Expand All @@ -75,49 +97,18 @@ func TestSetExplainer(t *testing.T) {
inferUriKey: "http://0.0.0.0:9000/v2/models/mymodel/infer",
},
},
},
}),
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
modelRepoPath := t.TempDir()
settingsFile := filepath.Join(modelRepoPath, mlserverConfigFilename)
err := os.WriteFile(settingsFile, test.data, os.ModePerm)
g.Expect(err).To(BeNil())
m := &MLServerRepositoryHandler{}
err = m.SetExplainer(modelRepoPath, test.explainerSpec, envoyHost, envoyPort)
g.Expect(err).To(BeNil())
modelSettings, err := m.loadModelSettingsFromFile(settingsFile)
g.Expect(err).To(BeNil())
g.Expect(modelSettings.Parameters.Extra[explainerTypeKey]).To(Equal(test.expected.Parameters.Extra[explainerTypeKey]))
g.Expect(modelSettings.Parameters.Extra[inferUriKey]).To(Equal(test.expected.Parameters.Extra[inferUriKey]))
})
}
}

func TestSetLlm(t *testing.T) {
g := NewGomegaWithT(t)

envoyHost := "0.0.0.0"
envoyPort := 9000
type test struct {
name string
data []byte
llmSpec *scheduler.LlmSpec
expected *ModelSettings
}

getStrPr := func(str string) *string { return &str }
tests := []test{
{
name: "chat-completion",
data: []byte(`{"name": "chat-completion","implementation": "mlserver_prompt_utils.runtime.PromptRuntime",
"parameters": {"version": "1", "extra":{"prompt_utils": {"model_type": "chat.completions"}}}}`),
llmSpec: &scheduler.LlmSpec{
data: []byte(`{"name": "chat-completion","implementation": "mlserver_prompt_utils.runtime.PromptRuntime", "parameters": {"version": "1", "extra":{"prompt_utils": {"model_type": "chat.completions"}}}}`),
spec: &scheduler.LlmSpec{
ModelRef: getStrPr("mymodel"),
},
expected: &ModelSettings{
handler: func(m *MLServerRepositoryHandler, path string, spec interface{}, host string, port int) error {
return m.SetLlm(path, spec.(*scheduler.LlmSpec), host, port)
},
assertion: validateLlm(g, &ModelSettings{
Name: "chat-completion",
Implementation: "mlserver_prompt_utils.runtime.PromptRuntime",
Parameters: &ModelParameters{
Expand All @@ -129,23 +120,30 @@ func TestSetLlm(t *testing.T) {
},
},
},
},
}),
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
modelRepoPath := t.TempDir()
settingsFile := filepath.Join(modelRepoPath, mlserverConfigFilename)

// write the model settings to a file
err := os.WriteFile(settingsFile, test.data, os.ModePerm)
g.Expect(err).To(BeNil())

m := &MLServerRepositoryHandler{}
err = m.SetLlm(modelRepoPath, test.llmSpec, envoyHost, envoyPort)

// call the andler
err = test.handler(m, modelRepoPath, test.spec, envoyHost, envoyPort)
g.Expect(err).To(BeNil())

// load model settings and perofrm assertion
modelSettings, err := m.loadModelSettingsFromFile(settingsFile)
g.Expect(err).To(BeNil())
g.Expect(modelSettings.Parameters.Extra[inferUriKey]).To(Equal(test.expected.Parameters.Extra[inferUriKey]))
g.Expect(modelSettings.Parameters.Extra["prompt_utils"]).To(Equal(test.expected.Parameters.Extra["prompt_utils"]))

test.assertion(g, modelSettings)

})
}
Expand Down

0 comments on commit 5ce0906

Please sign in to comment.