Skip to content

Commit

Permalink
[uv] use torch cpu integration instead of env var
Browse files Browse the repository at this point in the history
  • Loading branch information
rmehri01 committed Dec 31, 2024
1 parent 8f4d06f commit de569af
Showing 1 changed file with 55 additions and 34 deletions.
89 changes: 55 additions & 34 deletions internal/backends/python/python.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"os"
"os/exec"
"regexp"
"runtime"
"slices"
"strings"

"github.com/BurntSushi/toml"
Expand All @@ -22,22 +22,7 @@ import (

var normalizationPattern = regexp.MustCompile(`[-_.]+`)

type extraIndex struct {
// url is the location of the index
url string
// os is the operating system to override the index for, leave empty
// to override on any operating system
os string
}

var torchCpu = extraIndex{
url: "https://download.pytorch.org/whl/cpu",
os: "linux",
}

var extraIndexMap = map[string][]extraIndex{
"torch": {torchCpu},
}
var torchOverrides = []string{"torch", "torchvision"}

// this generates a mapping of pypi packages <-> modules
// moduleToPypiPackage pypiPackageToModules are provided
Expand Down Expand Up @@ -73,6 +58,12 @@ type pyprojectTOMLGroup struct {
Dependencies map[string]interface{} `json:"dependencies"`
}

type pyprojectUVIndex struct {
Name string `json:"name"`
Url string `json:"url"`
Explicit bool `json:"explicit"`
}

// pyprojectTOML represents the relevant parts of a pyproject.toml
// file.
type pyprojectTOML struct {
Expand All @@ -95,6 +86,7 @@ type pyprojectTOML struct {
} `toml:"poetry"`
Uv *struct {
Sources map[string]interface{} `toml:"sources"`
Index []pyprojectUVIndex `toml:"index"`
} `toml:"uv"`
} `toml:"tool"`
}
Expand Down Expand Up @@ -769,20 +761,39 @@ func makePythonUvBackend() api.LanguageBackend {

return pkgs
}
addExtraIndexes := func(pkgName string) {
extraIndexes, ok := extraIndexMap[pkgName]
if ok {
uvIndex := os.Getenv("UV_INDEX")

for _, index := range extraIndexes {
if strings.HasPrefix(runtime.GOOS, index.os) {
uvIndex = index.url + " " + uvIndex
}
}
addTorchOverride := func() {
// check if the override is already added
b, err := os.ReadFile("pyproject.toml")
if err != nil {
return
}
if strings.Contains(string(b), "https://download.pytorch.org/whl/cpu") {
return
}

os.Setenv("UV_INDEX", uvIndex)
pyprojectFile, err := os.OpenFile("pyproject.toml", os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
return
}
defer pyprojectFile.Close()

torchCPUOverride := `
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu" },
]
torchvision = [
{ index = "pytorch-cpu" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
`
if _, err = pyprojectFile.WriteString(torchCPUOverride); err != nil {
return
}
os.Setenv("UV_INDEX_STRATEGY", "unsafe-best-match")
}

b := api.LanguageBackend{
Expand Down Expand Up @@ -857,6 +868,7 @@ func makePythonUvBackend() api.LanguageBackend {
}
}

hasTorch := false
cmd := []string{"uv", "add"}
for name, coords := range pkgs {
if found, ok := moduleToPypiPackageAliases[string(name)]; ok {
Expand All @@ -867,12 +879,14 @@ func makePythonUvBackend() api.LanguageBackend {
}

cmd = append(cmd, pep440Join(coords))
addExtraIndexes(string(name))

if slices.Contains(torchOverrides, string(name)) {
hasTorch = true
}
}

specPkgs := listUvSpecfile()
for pkg := range specPkgs {
addExtraIndexes(string(pkg))
if hasTorch {
addTorchOverride()
}

util.RunCmd(cmd)
Expand All @@ -899,9 +913,16 @@ func makePythonUvBackend() api.LanguageBackend {
span, ctx := tracer.StartSpanFromContext(ctx, "uv install")
defer span.Finish()

hasTorch := false
pkgs := listUvSpecfile()
for pkg := range pkgs {
addExtraIndexes(string(pkg))
if slices.Contains(torchOverrides, string(pkg)) {
hasTorch = true
}
}

if hasTorch {
addTorchOverride()
}

util.RunCmd([]string{"uv", "sync"})
Expand Down

0 comments on commit de569af

Please sign in to comment.