Skip to content

Commit

Permalink
#162 Hotfix tidy build path from downloaded (#163)
Browse files Browse the repository at this point in the history
* fixed build downloaded model path bug

* fixed unit tests

* added path join method

* added unit test

* adapted unit test
  • Loading branch information
RayanMarmar authored May 15, 2024
1 parent bd61a14 commit e01607f
Show file tree
Hide file tree
Showing 22 changed files with 179 additions and 151 deletions.
4 changes: 2 additions & 2 deletions internal/config/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package config

import (
"github.com/easy-model-fusion/emf-cli/internal/app"
"github.com/easy-model-fusion/emf-cli/internal/utils/fileutil"
"github.com/easy-model-fusion/emf-cli/test"
"github.com/easy-model-fusion/emf-cli/test/mock"
"os"
"path/filepath"
"testing"

"github.com/spf13/viper"
Expand All @@ -30,7 +30,7 @@ func TestLoad(t *testing.T) {
}
defer os.RemoveAll(dname)
// Create a temporary config file for the test
file, err := os.Create(filepath.Join(dname, "config.yaml"))
file, err := os.Create(fileutil.PathJoin(dname, "config.yaml"))
defer func(file *os.File) {
err = file.Close()
if err != nil {
Expand Down
11 changes: 5 additions & 6 deletions internal/config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/easy-model-fusion/emf-cli/internal/utils/stringutil"
"github.com/spf13/viper"
"os"
"path/filepath"
)

// GetModels retrieves models from the configuration.
Expand Down Expand Up @@ -94,7 +93,7 @@ func RemoveItemPhysically(itemPath string) error {
// Cleaning up : removing every empty directory on the way to the item (from tail to head)
for i := len(directories) - 1; i >= 0; i-- {
// Build path to parent directory
path := filepath.Join(directories[:i+1]...)
path := fileutil.PathJoin(directories[:i+1]...)

// Delete directory if empty
err = fileutil.DeleteDirectoryIfEmpty(path)
Expand Down Expand Up @@ -123,7 +122,7 @@ func RemoveAllModels() (info string, err error) {

// Trying to remove every model
for _, item := range models {
modelPath := filepath.Join(app.DownloadDirectoryPath, item.Name)
modelPath := fileutil.PathJoin(app.DownloadDirectoryPath, item.Name)
spinner := app.UI().StartSpinner(fmt.Sprintf("Removing item %s...", item.Name))
err = RemoveItemPhysically(modelPath)
if err != nil {
Expand Down Expand Up @@ -161,7 +160,7 @@ func RemoveModelsByNames(models model.Models, modelsNamesToRemove []string) (war

// Trying to remove the models
for _, item := range modelsToRemove {
modelPath := filepath.Join(app.DownloadDirectoryPath, item.Name)
modelPath := fileutil.PathJoin(app.DownloadDirectoryPath, item.Name)
spinner := app.UI().StartSpinner(fmt.Sprintf("Removing item %s...", item.Name))
err = RemoveItemPhysically(modelPath)
if err != nil {
Expand Down Expand Up @@ -219,7 +218,7 @@ func Validate(current model.Model, yes bool) (warning string, success bool, err
}

// Removing model
modelPath := filepath.Join(app.DownloadDirectoryPath, current.Name)
modelPath := fileutil.PathJoin(app.DownloadDirectoryPath, current.Name)
spinner := app.UI().StartSpinner(fmt.Sprintf("Removing item %s...", current.Name))
err = RemoveItemPhysically(modelPath)
if err != nil {
Expand Down Expand Up @@ -293,5 +292,5 @@ func GenerateModelsPythonCode(models model.Models) error {
return err
}

return os.WriteFile(filepath.Join("sdk", "generated_models.py"), []byte(result), 0644)
return os.WriteFile(fileutil.PathJoin("sdk", "generated_models.py"), []byte(result), 0644)
}
47 changes: 23 additions & 24 deletions internal/config/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/easy-model-fusion/emf-cli/test/mock"
"gopkg.in/yaml.v3"
"os"
"path/filepath"
"testing"

"github.com/easy-model-fusion/emf-cli/test"
Expand All @@ -36,7 +35,7 @@ func setupConfigDir(t *testing.T) (string, string) {
}

// Set up a temporary config file with some initial models
initialConfigFile := filepath.Join(confDir, "config.yaml")
initialConfigFile := fileutil.PathJoin(confDir, "config.yaml")

return confDir, initialConfigFile
}
Expand Down Expand Up @@ -328,7 +327,7 @@ func TestRemoveModelPhysically_NotPhysical(t *testing.T) {
func TestRemoveModelPhysically_Success(t *testing.T) {
// Init
modelToRemove := getModel(0)
modelPath := filepath.Join(app.DownloadDirectoryPath, modelToRemove.Name)
modelPath := fileutil.PathJoin(app.DownloadDirectoryPath, modelToRemove.Name)

// Create temporary model
setupModelDirectory(t, modelPath)
Expand All @@ -352,13 +351,13 @@ func TestRemoveAllModels_Success(t *testing.T) {
models := []model.Model{getModel(0), getModel(1), getModel(2)}

// Create temporary models
modelPath0 := filepath.Join(app.DownloadDirectoryPath, models[0].Name)
modelPath0 := fileutil.PathJoin(app.DownloadDirectoryPath, models[0].Name)
setupModelDirectory(t, modelPath0)
defer os.RemoveAll(modelPath0)
modelPath1 := filepath.Join(app.DownloadDirectoryPath, models[1].Name)
modelPath1 := fileutil.PathJoin(app.DownloadDirectoryPath, models[1].Name)
setupModelDirectory(t, modelPath1)
defer os.RemoveAll(modelPath1)
modelPath2 := filepath.Join(app.DownloadDirectoryPath, models[2].Name)
modelPath2 := fileutil.PathJoin(app.DownloadDirectoryPath, models[2].Name)
setupModelDirectory(t, modelPath2)
defer os.RemoveAll(modelPath2)

Expand Down Expand Up @@ -418,11 +417,11 @@ func TestRemoveModels_Success(t *testing.T) {
models := []model.Model{getModel(0), getModel(1), getModel(2)}

// Create temporary models
modelPath0 := filepath.Join(app.DownloadDirectoryPath, models[0].Name)
modelPath0 := fileutil.PathJoin(app.DownloadDirectoryPath, models[0].Name)
setupModelDirectory(t, modelPath0)
modelPath1 := filepath.Join(app.DownloadDirectoryPath, models[1].Name)
modelPath1 := fileutil.PathJoin(app.DownloadDirectoryPath, models[1].Name)
setupModelDirectory(t, modelPath1)
modelPath2 := filepath.Join(app.DownloadDirectoryPath, models[2].Name)
modelPath2 := fileutil.PathJoin(app.DownloadDirectoryPath, models[2].Name)
setupModelDirectory(t, modelPath2)
defer os.RemoveAll(app.DownloadDirectoryPath)

Expand Down Expand Up @@ -484,11 +483,11 @@ func TestRemoveModels_WithInvalidModels(t *testing.T) {
models := []model.Model{getModel(0), getModel(1), getModel(2)}

// Create temporary models
modelPath0 := filepath.Join(app.DownloadDirectoryPath, models[0].Name)
modelPath0 := fileutil.PathJoin(app.DownloadDirectoryPath, models[0].Name)
setupModelDirectory(t, modelPath0)
modelPath1 := filepath.Join(app.DownloadDirectoryPath, models[1].Name)
modelPath1 := fileutil.PathJoin(app.DownloadDirectoryPath, models[1].Name)
setupModelDirectory(t, modelPath1)
modelPath2 := filepath.Join(app.DownloadDirectoryPath, models[2].Name)
modelPath2 := fileutil.PathJoin(app.DownloadDirectoryPath, models[2].Name)
setupModelDirectory(t, modelPath2)
defer os.RemoveAll(app.DownloadDirectoryPath)

Expand Down Expand Up @@ -578,9 +577,9 @@ func TestValidate_DownloadedAndBinaryFalse_ConfirmFalse(t *testing.T) {
test.AssertEqual(t, err, nil, "Error while loading configuration file.")

// Create a temporary directory representing the model base path
modelName := filepath.Join("microsoft", "phi-2")
modelDirectory := filepath.Join(app.DownloadDirectoryPath, modelName)
modelPath := filepath.Join(modelDirectory, "model")
modelName := fileutil.PathJoin("microsoft", "phi-2")
modelDirectory := fileutil.PathJoin(app.DownloadDirectoryPath, modelName)
modelPath := fileutil.PathJoin(modelDirectory, "model")
err = os.MkdirAll(modelPath, 0750)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -616,9 +615,9 @@ func TestValidate_DownloadedAndBinaryFalse_ConfirmTrueAndRemove(t *testing.T) {
test.AssertEqual(t, err, nil, "Error while loading configuration file.")

// Create a temporary directory representing the model base path
modelName := filepath.Join("microsoft", "phi-2")
modelDirectory := filepath.Join(app.DownloadDirectoryPath, modelName)
modelPath := filepath.Join(modelDirectory, "model")
modelName := fileutil.PathJoin("microsoft", "phi-2")
modelDirectory := fileutil.PathJoin(app.DownloadDirectoryPath, modelName)
modelPath := fileutil.PathJoin(modelDirectory, "model")
err = os.MkdirAll(modelPath, 0750)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -659,9 +658,9 @@ func TestValidate_Downloaded_ConfirmFalse(t *testing.T) {
test.AssertEqual(t, err, nil, "Error while loading configuration file.")

// Create a temporary directory representing the model base path
modelName := filepath.Join("microsoft", "phi-2")
modelDirectory := filepath.Join(app.DownloadDirectoryPath, modelName)
modelPath := filepath.Join(modelDirectory, "model")
modelName := fileutil.PathJoin("microsoft", "phi-2")
modelDirectory := fileutil.PathJoin(app.DownloadDirectoryPath, modelName)
modelPath := fileutil.PathJoin(modelDirectory, "model")
err = os.MkdirAll(modelPath, 0750)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -697,9 +696,9 @@ func TestValidate_Downloaded_ConfirmTrue(t *testing.T) {
test.AssertEqual(t, err, nil, "Error while loading configuration file.")

// Create a temporary directory representing the model base path
modelName := filepath.Join("microsoft", "phi-2")
modelDirectory := filepath.Join(app.DownloadDirectoryPath, modelName)
modelPath := filepath.Join(modelDirectory, "model")
modelName := fileutil.PathJoin("microsoft", "phi-2")
modelDirectory := fileutil.PathJoin(app.DownloadDirectoryPath, modelName)
modelPath := fileutil.PathJoin(modelDirectory, "model")
err = os.MkdirAll(modelPath, 0750)
if err != nil {
t.Fatal(err)
Expand Down
4 changes: 2 additions & 2 deletions internal/controller/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ import (
"github.com/easy-model-fusion/emf-cli/internal/app"
"github.com/easy-model-fusion/emf-cli/internal/config"
"github.com/easy-model-fusion/emf-cli/internal/sdk"
"github.com/easy-model-fusion/emf-cli/internal/utils/fileutil"
"github.com/easy-model-fusion/emf-cli/internal/utils/stringutil"
"github.com/spf13/viper"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
Expand Down Expand Up @@ -193,7 +193,7 @@ func (bc BuildController) Build(libraryPath string) (err error) {
func (bc BuildController) createModelsSymbolicLink() error {
// Create symbolic link to models
modelsPath := "models"
distPath := filepath.Join(bc.DestinationDir, "models")
distPath := fileutil.PathJoin(bc.DestinationDir, "models")

app.UI().Info().Println(fmt.Sprintf("Creating symbolic link from %s to %s", modelsPath, distPath))

Expand Down
31 changes: 15 additions & 16 deletions internal/controller/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ import (
"github.com/easy-model-fusion/emf-cli/sdk"
"github.com/spf13/viper"
"os"
"path/filepath"
)

type InitController struct{}

var initDependenciesPath = filepath.Join("sdk", "requirements.txt")
var initDependenciesPath = fileutil.PathJoin("sdk", "requirements.txt")

// Run runs the init command
func (ic InitController) Run(args []string, useTorchCuda bool, customTag string) error {
Expand Down Expand Up @@ -83,7 +82,7 @@ func (ic InitController) createProject(projectName string, useTorchCuda bool, cu

// Create virtual environment
spinner := app.UI().StartSpinner("Creating virtual environment")
err = app.Python().CreateVirtualEnv(pythonPath, filepath.Join(projectName, ".venv"))
err = app.Python().CreateVirtualEnv(pythonPath, fileutil.PathJoin(projectName, ".venv"))
if err != nil {
spinner.Fail("Unable to create venv: ", err)
return err
Expand Down Expand Up @@ -128,27 +127,27 @@ func (ic InitController) createProjectFiles(projectName, sdkTag string) (err err
}()

// Copy main.py, config.yaml & .gitignore
err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, "main.py", filepath.Join(projectName, "main.py"))
err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, "main.py", fileutil.PathJoin(projectName, "main.py"))
if err != nil {
return err
}

err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, "config.yaml", filepath.Join(projectName, "config.yaml"))
err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, "config.yaml", fileutil.PathJoin(projectName, "config.yaml"))
if err != nil {
return err
}

err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, ".gitignore", filepath.Join(projectName, ".gitignore"))
err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, ".gitignore", fileutil.PathJoin(projectName, ".gitignore"))
if err != nil {
return err
}

err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, "README.md", filepath.Join(projectName, "README.md"))
err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, "README.md", fileutil.PathJoin(projectName, "README.md"))
if err != nil {
return err
}

err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, "requirements.txt", filepath.Join(projectName, "requirements.txt"))
err = fileutil.CopyEmbeddedFile(sdk.EmbeddedFiles, "requirements.txt", fileutil.PathJoin(projectName, "requirements.txt"))
if err != nil {
return err
}
Expand All @@ -168,13 +167,13 @@ func (ic InitController) createProjectFiles(projectName, sdkTag string) (err err
}

// Create sdk folder
err = os.Mkdir(filepath.Join(projectName, "sdk"), os.ModePerm)
err = os.Mkdir(fileutil.PathJoin(projectName, "sdk"), os.ModePerm)
if err != nil {
return err
}

// Create models folder
err = os.Mkdir(filepath.Join(projectName, "models"), os.ModePerm)
err = os.Mkdir(fileutil.PathJoin(projectName, "models"), os.ModePerm)
if err != nil {
return err
}
Expand All @@ -185,13 +184,13 @@ func (ic InitController) createProjectFiles(projectName, sdkTag string) (err err
// installDependencies installs the dependencies for the project
func (ic InitController) installDependencies(projectName string, useTorchCuda bool) (err error) {
// Install dependencies
pipPath, err := app.Python().FindVEnvExecutable(filepath.Join(projectName, ".venv"), "pip")
pipPath, err := app.Python().FindVEnvExecutable(fileutil.PathJoin(projectName, ".venv"), "pip")
if err != nil {
return err
}

spinner := app.UI().StartSpinner("Installing dependencies")
err = app.Python().InstallDependencies(pipPath, filepath.Join(projectName, initDependenciesPath))
err = app.Python().InstallDependencies(pipPath, fileutil.PathJoin(projectName, initDependenciesPath))
if err != nil {
spinner.Fail("Unable to install dependencies: ", err)
return err
Expand Down Expand Up @@ -239,7 +238,7 @@ func (ic InitController) cloneSDK(projectName, tag string) (err error) {

// Clone SDK
spinner := app.UI().StartSpinner("Cloning SDK")
err = app.G().CloneSDK(tag, filepath.Join(projectName, "sdk"))
err = app.G().CloneSDK(tag, fileutil.PathJoin(projectName, "sdk"))
if err != nil {
spinner.Fail("Unable to clone sdk: ", err)
return err
Expand All @@ -249,21 +248,21 @@ func (ic InitController) cloneSDK(projectName, tag string) (err error) {
spinner = app.UI().StartSpinner("Reorganizing SDK files")

// Move files from sdk/sdk to sdk/
err = fileutil.MoveFiles(filepath.Join(projectName, "sdk", "sdk"), filepath.Join(projectName, "sdk"))
err = fileutil.MoveFiles(fileutil.PathJoin(projectName, "sdk", "sdk"), fileutil.PathJoin(projectName, "sdk"))
if err != nil {
spinner.Fail("Unable to move SDK files: ", err)
return err
}

// remove sdk/sdk folder
err = os.RemoveAll(filepath.Join(projectName, "sdk", "sdk"))
err = os.RemoveAll(fileutil.PathJoin(projectName, "sdk", "sdk"))
if err != nil {
spinner.Fail("Unable to remove sdk/sdk folder: ", err)
return err
}

// remove .github/ folder
err = os.RemoveAll(filepath.Join(projectName, "sdk", ".github"))
err = os.RemoveAll(fileutil.PathJoin(projectName, "sdk", ".github"))
if err != nil {
spinner.Fail("Unable to remove .github folder: ", err)
return err
Expand Down
7 changes: 3 additions & 4 deletions internal/controller/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import (
"github.com/easy-model-fusion/emf-cli/internal/utils/fileutil"
"github.com/spf13/viper"
"os"
"path/filepath"
"time"
)

Expand Down Expand Up @@ -215,21 +214,21 @@ clone:
spinner = app.UI().StartSpinner("Reorganizing SDK files")

// Move files from sdk/sdk to sdk/
err = fileutil.MoveFiles(filepath.Join("sdk", "sdk"), "sdk")
err = fileutil.MoveFiles(fileutil.PathJoin("sdk", "sdk"), "sdk")
if err != nil {
spinner.Fail("Unable to move SDK files: ", err)
return err
}

// remove sdk/sdk folder
err = os.RemoveAll(filepath.Join("sdk", "sdk"))
err = os.RemoveAll(fileutil.PathJoin("sdk", "sdk"))
if err != nil {
spinner.Fail("Unable to remove sdk/sdk folder: ", err)
return err
}

// remove .github/ folder
err = os.RemoveAll(filepath.Join("sdk", ".github"))
err = os.RemoveAll(fileutil.PathJoin("sdk", ".github"))
if err != nil {
spinner.Fail("Unable to remove .github folder: ", err)
return err
Expand Down
Loading

0 comments on commit e01607f

Please sign in to comment.