Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SchawnnDev committed Apr 7, 2024
1 parent 7d3a9ee commit ff8fcba
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 28 deletions.
2 changes: 1 addition & 1 deletion internal/codegen/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (cg *PythonCodeGenerator) VisitImport(importStmt *Import) error {
// VisitImportWhat visits an ImportWhat node
func (cg *PythonCodeGenerator) VisitImportWhat(importWhat *ImportWhat) error {
if importWhat.Name == "" {
return errors.New("import what name cannot be empty")
return errors.New("import what \"name\" cannot be empty")
}

cg.append(importWhat.Name)
Expand Down
13 changes: 8 additions & 5 deletions internal/controller/model/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (ac AddController) Run(args []string, customArgs downloadermodel.Args) erro
return ac.Run(args, customArgs)
}

warningMessage, err := ac.processAdd(selectedModel, customArgs, ac.AuthorizeDownload)
warningMessage, err := ac.processAdd(selectedModel, customArgs)
if warningMessage != "" {
app.UI().Warning().Println(warningMessage)
}
Expand Down Expand Up @@ -111,23 +111,23 @@ func (ac AddController) getRequestedModel(args []string, authorizationKey string
}

// processAdd processes the selected model and tries to add it
func (ac AddController) processAdd(selectedModel model.Model, customArgs downloadermodel.Args, yes bool) (warning string, err error) {
func (ac AddController) processAdd(selectedModel model.Model, customArgs downloadermodel.Args) (warning string, err error) {
var updatedModel model.Model

// Download model is only available for model.Source == huggingface
if selectedModel.Source == model.HUGGING_FACE {
// User choose if he wishes to install the model directly
message := fmt.Sprintf("Do you wish to directly download %s?", selectedModel.Name)
selectedModel.AddToBinaryFile = !customArgs.OnlyConfiguration && (yes || app.UI().AskForUsersConfirmation(message))
selectedModel.AddToBinaryFile = !customArgs.OnlyConfiguration && (ac.AuthorizeDownload || app.UI().AskForUsersConfirmation(message))

// Validate model for download
warningMessage, valid, err := config.Validate(selectedModel, yes)
warningMessage, valid, err := config.Validate(selectedModel, ac.AuthorizeDownload)
if !valid {
return warningMessage, err
}

// Try to download model
updatedModel, err := ac.downloadModel(selectedModel, customArgs)
updatedModel, err = ac.downloadModel(selectedModel, customArgs)
if err != nil {
return warning, err
}
Expand Down Expand Up @@ -168,6 +168,9 @@ func (ac AddController) processAdd(selectedModel model.Model, customArgs downloa
selectedModel.Class = customArgs.ModelClass
selectedModel.Module = huggingface.Module(customArgs.ModelModule)
selectedModel.PipelineTag = huggingface.TextToImage // FIXME: should not be hardcoded
updatedModel = selectedModel
} else {
return "", fmt.Errorf("model source %s is not supported", selectedModel.Source)
}

// Add models to configuration file
Expand Down
104 changes: 87 additions & 17 deletions internal/controller/model/add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/easy-model-fusion/emf-cli/pkg/huggingface"
"github.com/easy-model-fusion/emf-cli/test"
"github.com/easy-model-fusion/emf-cli/test/mock"
"os"
"testing"
)

Expand Down Expand Up @@ -206,6 +207,36 @@ func TestGetRequestedModel_WithValidArg(t *testing.T) {
test.AssertEqual(t, requestedModel.Name, expectedModel.Name)
}

// Tests getRequestedModel with valid model passed in arguments and single file enabled
func TestGetRequestedModel_WithSingleFile(t *testing.T) {
// Init
ac := AddController{
SingleFile: true,
}
var existingModels model.Models
existingModels = append(existingModels, model.Model{Name: "model1"})
existingModels = append(existingModels, model.Model{Name: "model3"})
args := []string{"model2"}
expectedModel := model.Model{Name: "model2"}

// Create full test suite with a configuration file
ts := test.TestSuite{}
_ = ts.CreateFullTestSuite(t)
defer ts.CleanTestSuite(t)
err := setupConfigFile(existingModels)
test.AssertEqual(t, err, nil, "No error expected on setting configuration file")

// Get requested model
requestedModel, err := ac.getRequestedModel(args, "")

// Assertions
test.AssertEqual(t, err, nil)
test.AssertEqual(t, requestedModel.Name, expectedModel.Name)
test.AssertEqual(t, requestedModel.Source, model.CUSTOM)
test.AssertEqual(t, requestedModel.AddToBinaryFile, true)
test.AssertEqual(t, requestedModel.IsDownloaded, true)
}

// Tests getRequestedModel with existing model requested
func TestGetRequestedModel_WithInvalidArg(t *testing.T) {
// Init
Expand Down Expand Up @@ -401,15 +432,54 @@ func TestGetRequestedModel_WithInvalidConfigPath(t *testing.T) {
test.AssertNotEqual(t, err, nil)
}

// Tests process add for single file
func TestProcessAdd_SingleFile(t *testing.T) {
// Init
ac := AddController{
SingleFile: true,
}
var existingModels model.Models
existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"})
downloaderArgs := downloadermodel.Args{
ModelModule: string(huggingface.DIFFUSERS),
ModelClass: "Test",
ModelName: "model2",
}
selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.CUSTOM, Class: "test", Path: "test.safetensors"}

// Create full test suite with a configuration file
ts := test.TestSuite{}
_ = ts.CreateFullTestSuite(t)
defer ts.CleanTestSuite(t)
err := setupConfigFile(existingModels)
test.AssertEqual(t, err, nil, "No error expected on setting configuration file")

// create a sample file
err = os.WriteFile("test.safetensors", []byte("test"), 0644)
test.AssertEqual(t, err, nil)

// Process add
warning, err := ac.processAdd(selectedModel, downloaderArgs)
test.AssertEqual(t, err, nil)
models, err := config.GetModels()

// Assertions
test.AssertEqual(t, err, nil)
test.AssertEqual(t, warning, "")
test.AssertEqual(t, len(models), 3)
test.AssertEqual(t, models[2].Name, "model2")
}

// Tests process add
func TestProcessAdd(t *testing.T) {
func TestProcessAdd_HuggingFace(t *testing.T) {
// Init
ac := AddController{}
var existingModels model.Models
existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"})
downloaderArgs := downloadermodel.Args{}
selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}
selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"}

// Create full test suite with a configuration file
ts := test.TestSuite{}
Expand All @@ -423,7 +493,7 @@ func TestProcessAdd(t *testing.T) {
app.SetDownloader(&downloader)

// Process add
warning, err := ac.processAdd(selectedModel, downloaderArgs, true)
warning, err := ac.processAdd(selectedModel, downloaderArgs)
test.AssertEqual(t, err, nil)
models, err := config.GetModels()

Expand All @@ -439,10 +509,10 @@ func TestProcessAdd_WithAccessToken(t *testing.T) {
// Init
ac := AddController{}
var existingModels model.Models
existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE})
existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE})
downloaderArgs := downloadermodel.Args{AccessToken: "testToken"}
selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}
selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE}

// Create full test suite with a configuration file
ts := test.TestSuite{}
Expand All @@ -456,7 +526,7 @@ func TestProcessAdd_WithAccessToken(t *testing.T) {
app.SetDownloader(&downloader)

// Process add
warning, err := ac.processAdd(selectedModel, downloaderArgs, true)
warning, err := ac.processAdd(selectedModel, downloaderArgs)
test.AssertEqual(t, err, nil)
token, err := dotenv.GetEnvValue("ACCESS_TOKEN_MODEL2")
test.AssertEqual(t, err, nil)
Expand All @@ -475,10 +545,10 @@ func TestProcessAdd_WithInvalidModel(t *testing.T) {
// Init
ac := AddController{}
var existingModels model.Models
existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE})
existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE})
downloaderArgs := downloadermodel.Args{}
selectedModel := model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}
selectedModel := model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE}

// Create full test suite with a configuration file
ts := test.TestSuite{}
Expand All @@ -492,7 +562,7 @@ func TestProcessAdd_WithInvalidModel(t *testing.T) {
app.SetDownloader(&downloader)

// Process add
warning, err := ac.processAdd(selectedModel, downloaderArgs, true)
warning, err := ac.processAdd(selectedModel, downloaderArgs)
test.AssertEqual(t, err, nil)
models, err := config.GetModels()

Expand All @@ -507,10 +577,10 @@ func TestProcessAdd_WithFailedDownload(t *testing.T) {
// Init
ac := AddController{}
var existingModels model.Models
existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"})
existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE})
existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE})
downloaderArgs := downloadermodel.Args{}
selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}
selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE}

// Create full test suite with a configuration file
ts := test.TestSuite{}
Expand All @@ -524,7 +594,7 @@ func TestProcessAdd_WithFailedDownload(t *testing.T) {
app.SetDownloader(&downloader)

// Process add
warning, err := ac.processAdd(selectedModel, downloaderArgs, true)
warning, err := ac.processAdd(selectedModel, downloaderArgs)
test.AssertNotEqual(t, err, nil)
models, err := config.GetModels()

Expand Down
6 changes: 4 additions & 2 deletions internal/controller/model/remove_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ import (
)

func TestMain(m *testing.M) {
app.Init("", "")
app.InitGit("", "")
app.SetUI(&mock.MockUI{})
app.SetGit(&mock.MockGit{})
app.SetPython(&mock.MockPython{})
app.SetDownloader(&mock.MockDownloader{})
os.Exit(m.Run())
}

Expand Down
11 changes: 8 additions & 3 deletions internal/model/generation.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ var generationExcludedCharacters = []string{"-", "/", "."}
func (m *Model) GetFormattedModelName() string {
name := cases.Title(language.English, cases.Compact).String(m.Name)

// Check if name is empty, to avoid panic
if name == "" {
return name
}

// Remove special characters
for _, specialCharacter := range generationExcludedCharacters {
name = strings.ReplaceAll(name, specialCharacter, "")
Expand Down Expand Up @@ -151,9 +156,6 @@ func (m *Model) GenSuperInitParamsWithModule() []codegen.FunctionCallParameter {
Name: "device",
Value: "Devices.GPU",
},
{
Value: "**kwargs",
},
}

if m.Source == "CUSTOM" {
Expand All @@ -166,6 +168,9 @@ func (m *Model) GenSuperInitParamsWithModule() []codegen.FunctionCallParameter {

switch m.Module {
case huggingface.DIFFUSERS:
params = append(params, codegen.FunctionCallParameter{
Value: "**kwargs",
})
return params
case huggingface.TRANSFORMERS:
params = append(params, codegen.FunctionCallParameter{
Expand Down

0 comments on commit ff8fcba

Please sign in to comment.