diff --git a/internal/backends/python/python.go b/internal/backends/python/python.go index bc353370..d74c8362 100644 --- a/internal/backends/python/python.go +++ b/internal/backends/python/python.go @@ -9,7 +9,7 @@ import ( "os" "os/exec" "regexp" - "runtime" + "slices" "strings" "github.com/BurntSushi/toml" @@ -22,22 +22,7 @@ import ( var normalizationPattern = regexp.MustCompile(`[-_.]+`) -type extraIndex struct { - // url is the location of the index - url string - // os is the operating system to override the index for, leave empty - // to override on any operating system - os string -} - -var torchCpu = extraIndex{ - url: "https://download.pytorch.org/whl/cpu", - os: "linux", -} - -var extraIndexMap = map[string][]extraIndex{ - "torch": {torchCpu}, -} +var torchOverrides = []string{"torch", "torchvision"} // this generates a mapping of pypi packages <-> modules // moduleToPypiPackage pypiPackageToModules are provided @@ -73,6 +58,12 @@ type pyprojectTOMLGroup struct { Dependencies map[string]interface{} `json:"dependencies"` } +type pyprojectUVIndex struct { + Name string `json:"name"` + Url string `json:"url"` + Explicit bool `json:"explicit"` +} + // pyprojectTOML represents the relevant parts of a pyproject.toml // file. type pyprojectTOML struct { @@ -95,6 +86,7 @@ type pyprojectTOML struct { } `toml:"poetry"` Uv *struct { Sources map[string]interface{} `toml:"sources"` + Index []pyprojectUVIndex `toml:"index"` } `toml:"uv"` } `toml:"tool"` } @@ -769,20 +761,39 @@ func makePythonUvBackend() api.LanguageBackend { return pkgs } - addExtraIndexes := func(pkgName string) { - extraIndexes, ok := extraIndexMap[pkgName] - if ok { - uvIndex := os.Getenv("UV_INDEX") - - for _, index := range extraIndexes { - if strings.HasPrefix(runtime.GOOS, index.os) { - uvIndex = index.url + " " + uvIndex - } - } + addTorchOverride := func() { + // check if the override is already added + b, err := os.ReadFile("pyproject.toml") + if err != nil { + return + } + if strings.Contains(string(b), "https://download.pytorch.org/whl/cpu") { + return + } - os.Setenv("UV_INDEX", uvIndex) + pyprojectFile, err := os.OpenFile("pyproject.toml", os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + return + } + defer pyprojectFile.Close() + + torchCPUOverride := ` +[tool.uv.sources] +torch = [ + { index = "pytorch-cpu" }, +] +torchvision = [ + { index = "pytorch-cpu" }, +] + +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + ` + if _, err = pyprojectFile.WriteString(torchCPUOverride); err != nil { + return } - os.Setenv("UV_INDEX_STRATEGY", "unsafe-best-match") } b := api.LanguageBackend{ @@ -857,6 +868,7 @@ func makePythonUvBackend() api.LanguageBackend { } } + hasTorch := false cmd := []string{"uv", "add"} for name, coords := range pkgs { if found, ok := moduleToPypiPackageAliases[string(name)]; ok { @@ -867,12 +879,14 @@ func makePythonUvBackend() api.LanguageBackend { } cmd = append(cmd, pep440Join(coords)) - addExtraIndexes(string(name)) + + if slices.Contains(torchOverrides, string(name)) { + hasTorch = true + } } - specPkgs := listUvSpecfile() - for pkg := range specPkgs { - addExtraIndexes(string(pkg)) + if hasTorch { + addTorchOverride() } util.RunCmd(cmd) @@ -899,9 +913,16 @@ func makePythonUvBackend() api.LanguageBackend { span, ctx := tracer.StartSpanFromContext(ctx, "uv install") defer span.Finish() + hasTorch := false pkgs := listUvSpecfile() for pkg := range pkgs { - addExtraIndexes(string(pkg)) + if slices.Contains(torchOverrides, string(pkg)) { + hasTorch = true + } + } + + if hasTorch { + addTorchOverride() } util.RunCmd([]string{"uv", "sync"})