Skip to content

Commit

Permalink
[uv] use torch cpu integration instead of env var
Browse files Browse the repository at this point in the history
  • Loading branch information
rmehri01 committed Jan 4, 2025
1 parent 8f4d06f commit 945210c
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 32 deletions.
124 changes: 92 additions & 32 deletions internal/backends/python/python.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import (
"io"
"os"
"os/exec"
"reflect"
"regexp"
"runtime"
"slices"
"strings"

"github.com/BurntSushi/toml"
Expand All @@ -22,22 +23,7 @@ import (

var normalizationPattern = regexp.MustCompile(`[-_.]+`)

type extraIndex struct {
// url is the location of the index
url string
// os is the operating system to override the index for, leave empty
// to override on any operating system
os string
}

var torchCpu = extraIndex{
url: "https://download.pytorch.org/whl/cpu",
os: "linux",
}

var extraIndexMap = map[string][]extraIndex{
"torch": {torchCpu},
}
var torchOverrides = []string{"torch", "torchvision"}

// this generates a mapping of pypi packages <-> modules
// moduleToPypiPackage pypiPackageToModules are provided
Expand Down Expand Up @@ -73,6 +59,12 @@ type pyprojectTOMLGroup struct {
Dependencies map[string]interface{} `json:"dependencies"`
}

type pyprojectUVIndex struct {
Name string `json:"name"`
Url string `json:"url"`
Explicit bool `json:"explicit"`
}

// pyprojectTOML represents the relevant parts of a pyproject.toml
// file.
type pyprojectTOML struct {
Expand All @@ -95,6 +87,7 @@ type pyprojectTOML struct {
} `toml:"poetry"`
Uv *struct {
Sources map[string]interface{} `toml:"sources"`
Index []pyprojectUVIndex `toml:"index"`
} `toml:"uv"`
} `toml:"tool"`
}
Expand Down Expand Up @@ -769,20 +762,77 @@ func makePythonUvBackend() api.LanguageBackend {

return pkgs
}
addExtraIndexes := func(pkgName string) {
extraIndexes, ok := extraIndexMap[pkgName]
if ok {
uvIndex := os.Getenv("UV_INDEX")

for _, index := range extraIndexes {
if strings.HasPrefix(runtime.GOOS, index.os) {
uvIndex = index.url + " " + uvIndex
appendTableOnce := func(path string, value map[string]interface{}) error {
// check if the value is already added
getOp := util.TomlEditorOp{
Op: "get",
Path: path,
}
response, err := util.ExecTomlEditor([]util.TomlEditorOp{getOp})
if err != nil {
return err
}
if len(response.Results) != 1 {
return fmt.Errorf("expected one result")
}

result := response.Results[0]
if result != nil {
if arr, ok := result.([]interface{}); ok {
if slices.ContainsFunc(arr, func(val interface{}) bool {
gotValue, ok := val.(map[string]interface{})
if !ok {
return false
}

return reflect.DeepEqual(gotValue, value)
}) {
return nil
}
}
}

// if the value isn't already present, add it
valueBytes, err := json.Marshal(value)
if err != nil {
return err
}
addOp := util.TomlEditorOp{
Op: "add",
TableHeaderPath: fmt.Sprintf("%s/[[]]", path),
Value: string(valueBytes),
}
_, err = util.ExecTomlEditor([]util.TomlEditorOp{addOp})
if err != nil {
return err
}

return nil
}
addTorchOverride := func() {
if !util.TomlEditorIsAvailable() {
util.DieSubprocess(
"toml-editor is not on the PATH, please install it with " +
"`nix profile install github:replit/toml-editor` or " +
"`cargo install --git https://github.com/replit/toml-editor` and ensure it's on the PATH.",
)
}

for _, name := range torchOverrides {
path := fmt.Sprintf("tool/uv/sources/%s", name)
value := map[string]interface{}{"index": "pytorch-cpu", "marker": "platform_system == 'Linux'"}
err := appendTableOnce(path, value)
if err != nil {
util.DieSubprocess("%s", err)
}
}

os.Setenv("UV_INDEX", uvIndex)
path := "tool/uv/index"
value := map[string]interface{}{"name": "pytorch-cpu", "url": "https://download.pytorch.org/whl/cpu", "explicit": true}
err := appendTableOnce(path, value)
if err != nil {
util.DieSubprocess("%s", err)
}
os.Setenv("UV_INDEX_STRATEGY", "unsafe-best-match")
}

b := api.LanguageBackend{
Expand Down Expand Up @@ -857,6 +907,7 @@ func makePythonUvBackend() api.LanguageBackend {
}
}

hasTorch := false
cmd := []string{"uv", "add"}
for name, coords := range pkgs {
if found, ok := moduleToPypiPackageAliases[string(name)]; ok {
Expand All @@ -867,12 +918,14 @@ func makePythonUvBackend() api.LanguageBackend {
}

cmd = append(cmd, pep440Join(coords))
addExtraIndexes(string(name))

if slices.Contains(torchOverrides, string(name)) {
hasTorch = true
}
}

specPkgs := listUvSpecfile()
for pkg := range specPkgs {
addExtraIndexes(string(pkg))
if hasTorch {
addTorchOverride()
}

util.RunCmd(cmd)
Expand All @@ -899,9 +952,16 @@ func makePythonUvBackend() api.LanguageBackend {
span, ctx := tracer.StartSpanFromContext(ctx, "uv install")
defer span.Finish()

hasTorch := false
pkgs := listUvSpecfile()
for pkg := range pkgs {
addExtraIndexes(string(pkg))
if slices.Contains(torchOverrides, string(pkg)) {
hasTorch = true
}
}

if hasTorch {
addTorchOverride()
}

util.RunCmd([]string{"uv", "sync"})
Expand Down
77 changes: 77 additions & 0 deletions internal/util/toml-editor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package util

import (
"encoding/json"
"errors"
"fmt"
"io"
"os/exec"
)

// TomlEditorOp is the format of the JSON sent to toml-editor
type TomlEditorOp struct {
Op string `json:"op"`
Path string `json:"path,omitempty"`
TableHeaderPath string `json:"table_header_path,omitempty"`
DottedPath string `json:"dotted_path,omitempty"`
Value string `json:"value,omitempty"`
}

// TomlEditorResponse is the format of the JSON sent from toml-editor
type TomlEditorResponse struct {
Status string `json:"status"`
Message string `json:"message"`
Results []interface{} `json:"results"`
}

func TomlEditorIsAvailable() bool {
_, err := exec.LookPath("toml-editor")
return err == nil
}

func ExecTomlEditor(ops []TomlEditorOp) (*TomlEditorResponse, error) {
cmd := exec.Command("toml-editor", "--path", "pyproject.toml")
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
err = cmd.Start()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}

encoder := json.NewEncoder(stdin)
err = encoder.Encode(ops)
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
err = stdin.Close()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
decoder := json.NewDecoder(stdout)
var tomlEditorResponse TomlEditorResponse
err = decoder.Decode(&tomlEditorResponse)
if err != nil {
if !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("unexpected toml-editor output: %s", err)
}
}
if tomlEditorResponse.Status != "success" {
input, _ := json.Marshal(ops)
return nil, fmt.Errorf("toml-editor error with input %s: %s", input, tomlEditorResponse.Message)
}

stdout.Close()
err = cmd.Wait()
if err != nil {
input, _ := json.Marshal(ops)
return nil, fmt.Errorf("toml-editor error with input %s: %s", input, err)
}

return &tomlEditorResponse, nil
}

0 comments on commit 945210c

Please sign in to comment.