Skip to content

Commit

Permalink
Fail python_packages on --x-fast
Browse files Browse the repository at this point in the history
So that we don't attempt to parse requirements specs and instead pass
them as is to monobase, i.e. `uv pip compile` for validation
  • Loading branch information
nevillelyh committed Feb 19, 2025
1 parent b49e405 commit dd187a2
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 64 deletions.
31 changes: 18 additions & 13 deletions pkg/dockerfile/fast_generator_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dockerfile

import (
"os"
"path"
"strings"
"testing"

Expand All @@ -10,11 +12,19 @@ import (
"github.com/replicate/cog/pkg/docker/dockertest"
)

func writeRequirements(t *testing.T, req string) string {
srcDir := t.TempDir()
reqFile := path.Join(srcDir, "requirements.txt")
err := os.WriteFile(reqFile, []byte(req), 0o644)
require.NoError(t, err)
return reqFile
}

func TestGenerate(t *testing.T) {
dir := t.TempDir()
build := config.Build{
PythonVersion: "3.8",
PythonPackages: []string{"torch==2.5.1"},
PythonVersion: "3.8",
PythonRequirements: writeRequirements(t, "torch==2.5.1"),
}
config := config.Config{
Build: &build,
Expand Down Expand Up @@ -55,11 +65,8 @@ func TestGenerate(t *testing.T) {
func TestGenerateUVCacheMount(t *testing.T) {
dir := t.TempDir()
build := config.Build{
PythonVersion: "3.8",
PythonPackages: []string{
"torch==2.5.1",
"catboost==1.2.7",
},
PythonVersion: "3.8",
PythonRequirements: writeRequirements(t, "torch==2.5.1\ncatboost==1.2.7"),
}
config := config.Config{
Build: &build,
Expand Down Expand Up @@ -142,10 +149,8 @@ func TestGenerateCUDA(t *testing.T) {
func TestGeneratePythonPackages(t *testing.T) {
dir := t.TempDir()
build := config.Build{
PythonVersion: "3.8",
PythonPackages: []string{
"catboost==1.2.7",
},
PythonVersion: "3.8",
PythonRequirements: writeRequirements(t, "catboost==1.2.7"),
}
config := config.Config{
Build: &build,
Expand Down Expand Up @@ -186,8 +191,8 @@ func TestGeneratePythonPackages(t *testing.T) {
func TestGenerateVerboseEnv(t *testing.T) {
dir := t.TempDir()
build := config.Build{
PythonVersion: "3.8",
PythonPackages: []string{"torch==2.5.1"},
PythonVersion: "3.8",
PythonRequirements: writeRequirements(t, "torch==2.5.1"),
}
config := config.Config{
Build: &build,
Expand Down
67 changes: 17 additions & 50 deletions pkg/requirements/requirements.go
Original file line number Diff line number Diff line change
@@ -1,85 +1,52 @@
package requirements

import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"sort"

"github.com/replicate/cog/pkg/config"
)

const REQUIREMENTS_FILE = "requirements.txt"

func GenerateRequirements(tmpDir string, cfg *config.Config) (string, error) {
// Deduplicate packages between the requirements.txt and the python packages directive.
packageNames := make(map[string]string)

// Read the python packages configuration.
for _, requirement := range cfg.Build.PythonPackages {
packageName, err := config.PackageName(requirement)
if err != nil {
return "", err
}
packageNames[packageName] = requirement
if len(cfg.Build.PythonPackages) > 0 {
return "", fmt.Errorf("python_packages is no longer supported, use python_requirements instead")
}

// Read the python requirements.
if cfg.Build.PythonRequirements != "" {
fh, err := os.Open(cfg.Build.PythonRequirements)
if err != nil {
return "", err
}
scanner := bufio.NewScanner(fh)
for scanner.Scan() {
requirement := scanner.Text()
packageName, err := config.PackageName(requirement)
if err != nil {
return "", err
}
packageNames[packageName] = requirement
}
}

// If we don't have any packages skip further processing
if len(packageNames) == 0 {
// No Python requirements
if cfg.Build.PythonRequirements == "" {
return "", nil
}

// Sort the package names by alphabetical order.
keys := make([]string, 0, len(packageNames))
for k := range packageNames {
keys = append(keys, k)
}
sort.Strings(keys)

// Render the expected contents
requirementsContent := ""
for _, k := range keys {
requirementsContent += packageNames[k] + "\n"
bs, err := os.ReadFile(cfg.Build.PythonRequirements)
if err != nil {
return "", err
}
requirements := string(bs)

// Check against the old requirements contents
// Check against the old requirements
requirementsFile := filepath.Join(tmpDir, REQUIREMENTS_FILE)
_, err := os.Stat(requirementsFile)
if !errors.Is(err, os.ErrNotExist) {
bytes, err := os.ReadFile(requirementsFile)
if _, err := os.Stat(requirementsFile); err == nil {
bs, err = os.ReadFile(requirementsFile)
if err != nil {
return "", err
}
oldRequirementsContents := string(bytes)
if oldRequirementsContents == requirementsFile {
if string(bs) == requirements {
return requirementsFile, nil
}
} else if !errors.Is(err, os.ErrNotExist) {
return "", err
}

// Write out a new requirements file
err = os.WriteFile(requirementsFile, []byte(requirementsContent), 0o644)
err = os.WriteFile(requirementsFile, []byte(requirements), 0o644)
if err != nil {
return "", err
}
return requirementsFile, nil
return requirementsFile, err
}

func CurrentRequirements(tmpDir string) (string, error) {
Expand Down
21 changes: 20 additions & 1 deletion pkg/requirements/requirements_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package requirements

import (
"os"
"path"
"path/filepath"
"testing"

Expand All @@ -9,14 +11,31 @@ import (
"github.com/replicate/cog/pkg/config"
)

func TestGenerateRequirements(t *testing.T) {
func TestPythonPackages(t *testing.T) {
tmpDir := t.TempDir()
build := config.Build{
PythonPackages: []string{"torch==2.5.1"},
}
config := config.Config{
Build: &build,
}
_, err := GenerateRequirements(tmpDir, &config)
require.ErrorContains(t, err, "python_packages is no longer supported, use python_requirements instead")
}

func TestPythonRequirements(t *testing.T) {
srcDir := t.TempDir()
reqFile := path.Join(srcDir, "requirements.txt")
err := os.WriteFile(reqFile, []byte("torch==2.5.1"), 0o644)
require.NoError(t, err)

build := config.Build{
PythonRequirements: reqFile,
}
config := config.Config{
Build: &build,
}
tmpDir := t.TempDir()
requirementsFile, err := GenerateRequirements(tmpDir, &config)
require.NoError(t, err)
require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile)
Expand Down

0 comments on commit dd187a2

Please sign in to comment.