Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[uv] use torch cpu integration instead of env var #323

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 142 additions & 32 deletions internal/backends/python/python.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ package python
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"regexp"
"runtime"
"slices"
"strings"

"github.com/BurntSushi/toml"
Expand All @@ -22,22 +23,7 @@ import (

var normalizationPattern = regexp.MustCompile(`[-_.]+`)

type extraIndex struct {
// url is the location of the index
url string
// os is the operating system to override the index for, leave empty
// to override on any operating system
os string
}

var torchCpu = extraIndex{
url: "https://download.pytorch.org/whl/cpu",
os: "linux",
}

var extraIndexMap = map[string][]extraIndex{
"torch": {torchCpu},
}
var torchOverrides = []string{"torch", "torchvision"}

// this generates a mapping of pypi packages <-> modules
// moduleToPypiPackage pypiPackageToModules are provided
Expand Down Expand Up @@ -73,6 +59,12 @@ type pyprojectTOMLGroup struct {
Dependencies map[string]interface{} `json:"dependencies"`
}

type pyprojectUVIndex struct {
Name string `json:"name"`
Url string `json:"url"`
Explicit bool `json:"explicit"`
}

// pyprojectTOML represents the relevant parts of a pyproject.toml
// file.
type pyprojectTOML struct {
Expand All @@ -95,6 +87,7 @@ type pyprojectTOML struct {
} `toml:"poetry"`
Uv *struct {
Sources map[string]interface{} `toml:"sources"`
Index []pyprojectUVIndex `toml:"index"`
} `toml:"uv"`
} `toml:"tool"`
}
Expand Down Expand Up @@ -769,20 +762,121 @@ func makePythonUvBackend() api.LanguageBackend {

return pkgs
}
addExtraIndexes := func(pkgName string) {
extraIndexes, ok := extraIndexMap[pkgName]
if ok {
uvIndex := os.Getenv("UV_INDEX")

for _, index := range extraIndexes {
if strings.HasPrefix(runtime.GOOS, index.os) {
uvIndex = index.url + " " + uvIndex
addTorchOverride := func() error {
if !util.TomlEditorIsAvailable() {
return errors.New(
"toml-editor is not on the PATH, please install it with " +
"`nix profile install github:replit/toml-editor` or " +
"`cargo install --git https://github.com/replit/toml-editor` and ensure it's on the PATH.",
)
}

// check if the torch cpu index is already present
getOp := util.TomlEditorOp{
Op: "get",
Path: "tool/uv/index",
}
response, err := util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{getOp})
if err != nil {
return err
}
if len(response.Results) != 1 {
return fmt.Errorf("expected one result")
}

hasIndex := false
result := response.Results[0]
if result != nil {
if arr, ok := result.([]interface{}); ok {
for _, value := range arr {
index, ok := value.(map[string]interface{})
if !ok {
continue
}
name, ok := index["name"]
if !ok {
continue
}
hasIndex = name == "pytorch-cpu"
if hasIndex {
break
}
}
}
}
if !hasIndex {
value := map[string]interface{}{"name": "pytorch-cpu", "url": "https://download.pytorch.org/whl/cpu", "explicit": true}
valueBytes, err := json.Marshal(value)
if err != nil {
return err
}
addOp := util.TomlEditorOp{
Op: "add",
TableHeaderPath: "tool/uv/index/[[]]",
Value: string(valueBytes),
}
_, err = util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{addOp})
if err != nil {
return err
}
}

for _, name := range torchOverrides {
// check if the source is already present
getOp := util.TomlEditorOp{
Op: "get",
Path: fmt.Sprintf("tool/uv/sources/%s", name),
}
response, err := util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{getOp})
if err != nil {
util.Log(fmt.Sprintf("error while checking override '%s': %s", name, err))
continue
}
if len(response.Results) != 1 {
util.Log(fmt.Sprintf("error while checking override '%s': expected one result", name))
continue
}

os.Setenv("UV_INDEX", uvIndex)
hasSource := false
result := response.Results[0]
if result != nil {
if arr, ok := result.([]interface{}); ok {
for _, value := range arr {
source, ok := value.(map[string]interface{})
if !ok {
continue
}
name, ok := source["index"]
if !ok {
continue
}
hasSource = name == "pytorch-cpu"
if hasSource {
break
}
}
}
}
if !hasSource {
value := map[string]interface{}{"index": "pytorch-cpu", "marker": "platform_system == 'Linux'"}
valueBytes, err := json.Marshal(value)
if err != nil {
return err
}
addOp := util.TomlEditorOp{
Op: "add",
TableHeaderPath: "tool/uv/sources",
Path: fmt.Sprintf("%s/[]", name),
Value: string(valueBytes),
}
_, err = util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{addOp})
if err != nil {
return err
}
}
}
os.Setenv("UV_INDEX_STRATEGY", "unsafe-best-match")

return nil
}

b := api.LanguageBackend{
Expand Down Expand Up @@ -857,6 +951,7 @@ func makePythonUvBackend() api.LanguageBackend {
}
}

hasTorch := false
cmd := []string{"uv", "add"}
for name, coords := range pkgs {
if found, ok := moduleToPypiPackageAliases[string(name)]; ok {
Expand All @@ -867,12 +962,17 @@ func makePythonUvBackend() api.LanguageBackend {
}

cmd = append(cmd, pep440Join(coords))
addExtraIndexes(string(name))

if slices.Contains(torchOverrides, string(name)) {
hasTorch = true
}
}

specPkgs := listUvSpecfile()
for pkg := range specPkgs {
addExtraIndexes(string(pkg))
if hasTorch {
err := addTorchOverride()
if err != nil {
util.DieSubprocess("%s", err)
}
}

util.RunCmd(cmd)
Expand All @@ -899,9 +999,19 @@ func makePythonUvBackend() api.LanguageBackend {
span, ctx := tracer.StartSpanFromContext(ctx, "uv install")
defer span.Finish()

hasTorch := false
pkgs := listUvSpecfile()
for pkg := range pkgs {
addExtraIndexes(string(pkg))
if slices.Contains(torchOverrides, string(pkg)) {
hasTorch = true
}
}

if hasTorch {
err := addTorchOverride()
if err != nil {
util.DieSubprocess("%s", err)
}
}

util.RunCmd([]string{"uv", "sync"})
Expand Down
77 changes: 77 additions & 0 deletions internal/util/toml-editor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package util

import (
"encoding/json"
"errors"
"fmt"
"io"
"os/exec"
)

// TomlEditorOp is the format of the JSON sent to toml-editor
type TomlEditorOp struct {
Op string `json:"op"`
Path string `json:"path,omitempty"`
TableHeaderPath string `json:"table_header_path,omitempty"`
Value string `json:"value,omitempty"`
}

// TomlEditorResponse is the format of the JSON sent from toml-editor
type TomlEditorResponse struct {
Status string `json:"status"`
Message string `json:"message"`
Results []interface{} `json:"results"`
}

func TomlEditorIsAvailable() bool {
_, err := exec.LookPath("toml-editor")
return err == nil
}

func ExecTomlEditor(tomlPath string, ops []TomlEditorOp) (*TomlEditorResponse, error) {
cmd := exec.Command("toml-editor", "--path", tomlPath)
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
err = cmd.Start()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}

encoder := json.NewEncoder(stdin)
err = encoder.Encode(ops)
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
decoder := json.NewDecoder(stdout)
var tomlEditorResponse TomlEditorResponse
err = decoder.Decode(&tomlEditorResponse)
if err != nil {
if !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("unexpected toml-editor output: %s", err)
}
}
if tomlEditorResponse.Status != "success" {
input, _ := json.Marshal(ops)
return nil, fmt.Errorf("toml-editor error with input %s: %s", input, tomlEditorResponse.Message)
}

stdout.Close()
err = stdin.Close()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}

err = cmd.Wait()
if err != nil {
input, _ := json.Marshal(ops)
return nil, fmt.Errorf("toml-editor error with input %s: %s", input, err)
}

return &tomlEditorResponse, nil
}
Loading