diff --git a/internal/backends/python/python.go b/internal/backends/python/python.go index 8746b46a..acac94e9 100644 --- a/internal/backends/python/python.go +++ b/internal/backends/python/python.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "regexp" + "runtime" "strings" "github.com/BurntSushi/toml" @@ -21,6 +22,23 @@ import ( var normalizationPattern = regexp.MustCompile(`[-_.]+`) +type extraIndex struct { + // url is the location of the index + url string + // os is the operating system to override the index for, leave empty + // to override on any operating system + os string +} + +var torchCpu = extraIndex{ + url: "https://download.pytorch.org/whl/cpu", + os: "linux", +} + +var extraIndexMap = map[string][]extraIndex{ + "torch": []extraIndex{torchCpu}, +} + // this generates a mapping of pypi packages <-> modules // moduleToPypiPackage pypiPackageToModules are provided // pypiEntryInfoResponse is a wrapper around pypiEntryInfo @@ -743,6 +761,21 @@ func makePythonUvBackend() api.LanguageBackend { return pkgs } + addExtraIndexes := func(pkgName string) { + extraIndexes, ok := extraIndexMap[string(pkgName)] + if ok { + uvIndex := os.Getenv("UV_INDEX") + + for _, index := range extraIndexes { + if strings.HasPrefix(runtime.GOOS, index.os) { + uvIndex = index.url + " " + uvIndex + } + } + + os.Setenv("UV_INDEX", uvIndex) + } + } + b := api.LanguageBackend{ Name: "python3-uv", Specfile: "pyproject.toml", @@ -824,7 +857,10 @@ func makePythonUvBackend() api.LanguageBackend { } cmd = append(cmd, pep440Join(name, spec)) + + addExtraIndexes(string(name)) } + util.RunCmd(cmd) }, Lock: func(ctx context.Context) { @@ -849,6 +885,11 @@ func makePythonUvBackend() api.LanguageBackend { span, ctx := tracer.StartSpanFromContext(ctx, "uv install") defer span.Finish() + pkgs := listUvSpecfile() + for pkg := range pkgs { + addExtraIndexes(string(pkg)) + } + util.RunCmd([]string{"uv", "sync"}) }, ListSpecfile: func(mergeAllGroups bool) map[api.PkgName]api.PkgSpec {