diff --git a/internal/codegen/generator.go b/internal/codegen/generator.go index d12ce29..1e3eeb9 100644 --- a/internal/codegen/generator.go +++ b/internal/codegen/generator.go @@ -313,7 +313,7 @@ func (cg *PythonCodeGenerator) VisitImport(importStmt *Import) error { // VisitImportWhat visits an ImportWhat node func (cg *PythonCodeGenerator) VisitImportWhat(importWhat *ImportWhat) error { if importWhat.Name == "" { - return errors.New("import what name cannot be empty") + return errors.New("import what \"name\" cannot be empty") } cg.append(importWhat.Name) diff --git a/internal/controller/model/add.go b/internal/controller/model/add.go index 7a462e6..96158eb 100644 --- a/internal/controller/model/add.go +++ b/internal/controller/model/add.go @@ -30,7 +30,7 @@ func (ac AddController) Run(args []string, customArgs downloadermodel.Args) erro return ac.Run(args, customArgs) } - warningMessage, err := ac.processAdd(selectedModel, customArgs, ac.AuthorizeDownload) + warningMessage, err := ac.processAdd(selectedModel, customArgs) if warningMessage != "" { app.UI().Warning().Println(warningMessage) } @@ -111,23 +111,23 @@ func (ac AddController) getRequestedModel(args []string, authorizationKey string } // processAdd processes the selected model and tries to add it -func (ac AddController) processAdd(selectedModel model.Model, customArgs downloadermodel.Args, yes bool) (warning string, err error) { +func (ac AddController) processAdd(selectedModel model.Model, customArgs downloadermodel.Args) (warning string, err error) { var updatedModel model.Model // Download model is only available for model.Source == huggingface if selectedModel.Source == model.HUGGING_FACE { // User choose if he wishes to install the model directly message := fmt.Sprintf("Do you wish to directly download %s?", selectedModel.Name) - selectedModel.AddToBinaryFile = !customArgs.OnlyConfiguration && (yes || app.UI().AskForUsersConfirmation(message)) + selectedModel.AddToBinaryFile = !customArgs.OnlyConfiguration && (ac.AuthorizeDownload || app.UI().AskForUsersConfirmation(message)) // Validate model for download - warningMessage, valid, err := config.Validate(selectedModel, yes) + warningMessage, valid, err := config.Validate(selectedModel, ac.AuthorizeDownload) if !valid { return warningMessage, err } // Try to download model - updatedModel, err := ac.downloadModel(selectedModel, customArgs) + updatedModel, err = ac.downloadModel(selectedModel, customArgs) if err != nil { return warning, err } @@ -168,6 +168,9 @@ func (ac AddController) processAdd(selectedModel model.Model, customArgs downloa selectedModel.Class = customArgs.ModelClass selectedModel.Module = huggingface.Module(customArgs.ModelModule) selectedModel.PipelineTag = huggingface.TextToImage // FIXME: should not be hardcoded + updatedModel = selectedModel + } else { + return "", fmt.Errorf("model source %s is not supported", selectedModel.Source) } // Add models to configuration file diff --git a/internal/controller/model/add_test.go b/internal/controller/model/add_test.go index 7dbd40e..a1ebe06 100644 --- a/internal/controller/model/add_test.go +++ b/internal/controller/model/add_test.go @@ -10,6 +10,7 @@ import ( "github.com/easy-model-fusion/emf-cli/pkg/huggingface" "github.com/easy-model-fusion/emf-cli/test" "github.com/easy-model-fusion/emf-cli/test/mock" + "os" "testing" ) @@ -206,6 +207,36 @@ func TestGetRequestedModel_WithValidArg(t *testing.T) { test.AssertEqual(t, requestedModel.Name, expectedModel.Name) } +// Tests getRequestedModel with valid model passed in arguments and single file enabled +func TestGetRequestedModel_WithSingleFile(t *testing.T) { + // Init + ac := AddController{ + SingleFile: true, + } + var existingModels model.Models + existingModels = append(existingModels, model.Model{Name: "model1"}) + existingModels = append(existingModels, model.Model{Name: "model3"}) + args := []string{"model2"} + expectedModel := model.Model{Name: "model2"} + + // Create full test suite with a configuration file + ts := test.TestSuite{} + _ = ts.CreateFullTestSuite(t) + defer ts.CleanTestSuite(t) + err := setupConfigFile(existingModels) + test.AssertEqual(t, err, nil, "No error expected on setting configuration file") + + // Get requested model + requestedModel, err := ac.getRequestedModel(args, "") + + // Assertions + test.AssertEqual(t, err, nil) + test.AssertEqual(t, requestedModel.Name, expectedModel.Name) + test.AssertEqual(t, requestedModel.Source, model.CUSTOM) + test.AssertEqual(t, requestedModel.AddToBinaryFile, true) + test.AssertEqual(t, requestedModel.IsDownloaded, true) +} + // Tests getRequestedModel with existing model requested func TestGetRequestedModel_WithInvalidArg(t *testing.T) { // Init @@ -401,15 +432,54 @@ func TestGetRequestedModel_WithInvalidConfigPath(t *testing.T) { test.AssertNotEqual(t, err, nil) } +// Tests process add for single file +func TestProcessAdd_SingleFile(t *testing.T) { + // Init + ac := AddController{ + SingleFile: true, + } + var existingModels model.Models + existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"}) + existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"}) + downloaderArgs := downloadermodel.Args{ + ModelModule: string(huggingface.DIFFUSERS), + ModelClass: "Test", + ModelName: "model2", + } + selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.CUSTOM, Class: "test", Path: "test.safetensors"} + + // Create full test suite with a configuration file + ts := test.TestSuite{} + _ = ts.CreateFullTestSuite(t) + defer ts.CleanTestSuite(t) + err := setupConfigFile(existingModels) + test.AssertEqual(t, err, nil, "No error expected on setting configuration file") + + // create a sample file + err = os.WriteFile("test.safetensors", []byte("test"), 0644) + test.AssertEqual(t, err, nil) + + // Process add + warning, err := ac.processAdd(selectedModel, downloaderArgs) + test.AssertEqual(t, err, nil) + models, err := config.GetModels() + + // Assertions + test.AssertEqual(t, err, nil) + test.AssertEqual(t, warning, "") + test.AssertEqual(t, len(models), 3) + test.AssertEqual(t, models[2].Name, "model2") +} + // Tests process add -func TestProcessAdd(t *testing.T) { +func TestProcessAdd_HuggingFace(t *testing.T) { // Init ac := AddController{} var existingModels model.Models - existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}) - existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}) + existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"}) + existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"}) downloaderArgs := downloadermodel.Args{} - selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"} + selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Source: model.HUGGING_FACE, Class: "test"} // Create full test suite with a configuration file ts := test.TestSuite{} @@ -423,7 +493,7 @@ func TestProcessAdd(t *testing.T) { app.SetDownloader(&downloader) // Process add - warning, err := ac.processAdd(selectedModel, downloaderArgs, true) + warning, err := ac.processAdd(selectedModel, downloaderArgs) test.AssertEqual(t, err, nil) models, err := config.GetModels() @@ -439,10 +509,10 @@ func TestProcessAdd_WithAccessToken(t *testing.T) { // Init ac := AddController{} var existingModels model.Models - existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}) - existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}) + existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE}) + existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE}) downloaderArgs := downloadermodel.Args{AccessToken: "testToken"} - selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"} + selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE} // Create full test suite with a configuration file ts := test.TestSuite{} @@ -456,7 +526,7 @@ func TestProcessAdd_WithAccessToken(t *testing.T) { app.SetDownloader(&downloader) // Process add - warning, err := ac.processAdd(selectedModel, downloaderArgs, true) + warning, err := ac.processAdd(selectedModel, downloaderArgs) test.AssertEqual(t, err, nil) token, err := dotenv.GetEnvValue("ACCESS_TOKEN_MODEL2") test.AssertEqual(t, err, nil) @@ -475,10 +545,10 @@ func TestProcessAdd_WithInvalidModel(t *testing.T) { // Init ac := AddController{} var existingModels model.Models - existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}) - existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}) + existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE}) + existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE}) downloaderArgs := downloadermodel.Args{} - selectedModel := model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"} + selectedModel := model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE} // Create full test suite with a configuration file ts := test.TestSuite{} @@ -492,7 +562,7 @@ func TestProcessAdd_WithInvalidModel(t *testing.T) { app.SetDownloader(&downloader) // Process add - warning, err := ac.processAdd(selectedModel, downloaderArgs, true) + warning, err := ac.processAdd(selectedModel, downloaderArgs) test.AssertEqual(t, err, nil) models, err := config.GetModels() @@ -507,10 +577,10 @@ func TestProcessAdd_WithFailedDownload(t *testing.T) { // Init ac := AddController{} var existingModels model.Models - existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}) - existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"}) + existingModels = append(existingModels, model.Model{Name: "model1", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE}) + existingModels = append(existingModels, model.Model{Name: "model3", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE}) downloaderArgs := downloadermodel.Args{} - selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test"} + selectedModel := model.Model{Name: "model2", PipelineTag: huggingface.TextToImage, Module: huggingface.DIFFUSERS, Class: "test", Source: model.HUGGING_FACE} // Create full test suite with a configuration file ts := test.TestSuite{} @@ -524,7 +594,7 @@ func TestProcessAdd_WithFailedDownload(t *testing.T) { app.SetDownloader(&downloader) // Process add - warning, err := ac.processAdd(selectedModel, downloaderArgs, true) + warning, err := ac.processAdd(selectedModel, downloaderArgs) test.AssertNotEqual(t, err, nil) models, err := config.GetModels() diff --git a/internal/controller/model/remove_test.go b/internal/controller/model/remove_test.go index e7c5f1d..55be01d 100644 --- a/internal/controller/model/remove_test.go +++ b/internal/controller/model/remove_test.go @@ -12,8 +12,10 @@ import ( ) func TestMain(m *testing.M) { - app.Init("", "") - app.InitGit("", "") + app.SetUI(&mock.MockUI{}) + app.SetGit(&mock.MockGit{}) + app.SetPython(&mock.MockPython{}) + app.SetDownloader(&mock.MockDownloader{}) os.Exit(m.Run()) } diff --git a/internal/model/generation.go b/internal/model/generation.go index 5545fd8..94d14c1 100644 --- a/internal/model/generation.go +++ b/internal/model/generation.go @@ -15,6 +15,11 @@ var generationExcludedCharacters = []string{"-", "/", "."} func (m *Model) GetFormattedModelName() string { name := cases.Title(language.English, cases.Compact).String(m.Name) + // Check if name is empty, to avoid panic + if name == "" { + return name + } + // Remove special characters for _, specialCharacter := range generationExcludedCharacters { name = strings.ReplaceAll(name, specialCharacter, "") @@ -151,9 +156,6 @@ func (m *Model) GenSuperInitParamsWithModule() []codegen.FunctionCallParameter { Name: "device", Value: "Devices.GPU", }, - { - Value: "**kwargs", - }, } if m.Source == "CUSTOM" { @@ -166,6 +168,9 @@ func (m *Model) GenSuperInitParamsWithModule() []codegen.FunctionCallParameter { switch m.Module { case huggingface.DIFFUSERS: + params = append(params, codegen.FunctionCallParameter{ + Value: "**kwargs", + }) return params case huggingface.TRANSFORMERS: params = append(params, codegen.FunctionCallParameter{