Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fail python_packages on --x-fast #2162

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions pkg/dockerfile/fast_generator_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dockerfile

import (
"os"
"path"
"strings"
"testing"

Expand All @@ -10,11 +12,19 @@ import (
"github.com/replicate/cog/pkg/docker/dockertest"
)

func writeRequirements(t *testing.T, req string) string {
srcDir := t.TempDir()
reqFile := path.Join(srcDir, "requirements.txt")
err := os.WriteFile(reqFile, []byte(req), 0o644)
require.NoError(t, err)
return reqFile
}

func TestGenerate(t *testing.T) {
dir := t.TempDir()
build := config.Build{
PythonVersion: "3.8",
PythonPackages: []string{"torch==2.5.1"},
PythonVersion: "3.8",
PythonRequirements: writeRequirements(t, "torch==2.5.1"),
}
config := config.Config{
Build: &build,
Expand Down Expand Up @@ -55,11 +65,8 @@ func TestGenerate(t *testing.T) {
func TestGenerateUVCacheMount(t *testing.T) {
dir := t.TempDir()
build := config.Build{
PythonVersion: "3.8",
PythonPackages: []string{
"torch==2.5.1",
"catboost==1.2.7",
},
PythonVersion: "3.8",
PythonRequirements: writeRequirements(t, "torch==2.5.1\ncatboost==1.2.7"),
}
config := config.Config{
Build: &build,
Expand Down Expand Up @@ -142,10 +149,8 @@ func TestGenerateCUDA(t *testing.T) {
func TestGeneratePythonPackages(t *testing.T) {
dir := t.TempDir()
build := config.Build{
PythonVersion: "3.8",
PythonPackages: []string{
"catboost==1.2.7",
},
PythonVersion: "3.8",
PythonRequirements: writeRequirements(t, "catboost==1.2.7"),
}
config := config.Config{
Build: &build,
Expand Down Expand Up @@ -186,8 +191,8 @@ func TestGeneratePythonPackages(t *testing.T) {
func TestGenerateVerboseEnv(t *testing.T) {
dir := t.TempDir()
build := config.Build{
PythonVersion: "3.8",
PythonPackages: []string{"torch==2.5.1"},
PythonVersion: "3.8",
PythonRequirements: writeRequirements(t, "torch==2.5.1"),
}
config := config.Config{
Build: &build,
Expand Down
67 changes: 17 additions & 50 deletions pkg/requirements/requirements.go
Original file line number Diff line number Diff line change
@@ -1,85 +1,52 @@
package requirements

import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"sort"

"github.com/replicate/cog/pkg/config"
)

const REQUIREMENTS_FILE = "requirements.txt"

func GenerateRequirements(tmpDir string, cfg *config.Config) (string, error) {
// Deduplicate packages between the requirements.txt and the python packages directive.
packageNames := make(map[string]string)

// Read the python packages configuration.
for _, requirement := range cfg.Build.PythonPackages {
packageName, err := config.PackageName(requirement)
if err != nil {
return "", err
}
packageNames[packageName] = requirement
if len(cfg.Build.PythonPackages) > 0 {
return "", fmt.Errorf("python_packages is no longer supported, use python_requirements instead")
}

// Read the python requirements.
if cfg.Build.PythonRequirements != "" {
fh, err := os.Open(cfg.Build.PythonRequirements)
if err != nil {
return "", err
}
scanner := bufio.NewScanner(fh)
for scanner.Scan() {
requirement := scanner.Text()
packageName, err := config.PackageName(requirement)
if err != nil {
return "", err
}
packageNames[packageName] = requirement
}
}

// If we don't have any packages skip further processing
if len(packageNames) == 0 {
// No Python requirements
if cfg.Build.PythonRequirements == "" {
return "", nil
}

// Sort the package names by alphabetical order.
keys := make([]string, 0, len(packageNames))
for k := range packageNames {
keys = append(keys, k)
}
sort.Strings(keys)

// Render the expected contents
requirementsContent := ""
for _, k := range keys {
requirementsContent += packageNames[k] + "\n"
bs, err := os.ReadFile(cfg.Build.PythonRequirements)
if err != nil {
return "", err
}
requirements := string(bs)

// Check against the old requirements contents
// Check against the old requirements
requirementsFile := filepath.Join(tmpDir, REQUIREMENTS_FILE)
_, err := os.Stat(requirementsFile)
if !errors.Is(err, os.ErrNotExist) {
bytes, err := os.ReadFile(requirementsFile)
if _, err := os.Stat(requirementsFile); err == nil {
bs, err = os.ReadFile(requirementsFile)
if err != nil {
return "", err
}
oldRequirementsContents := string(bytes)
if oldRequirementsContents == requirementsFile {
if string(bs) == requirements {
return requirementsFile, nil
}
} else if !errors.Is(err, os.ErrNotExist) {
return "", err
}

// Write out a new requirements file
err = os.WriteFile(requirementsFile, []byte(requirementsContent), 0o644)
err = os.WriteFile(requirementsFile, []byte(requirements), 0o644)
if err != nil {
return "", err
}
return requirementsFile, nil
return requirementsFile, err
}

func CurrentRequirements(tmpDir string) (string, error) {
Expand Down
21 changes: 20 additions & 1 deletion pkg/requirements/requirements_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package requirements

import (
"os"
"path"
"path/filepath"
"testing"

Expand All @@ -9,14 +11,31 @@ import (
"github.com/replicate/cog/pkg/config"
)

func TestGenerateRequirements(t *testing.T) {
func TestPythonPackages(t *testing.T) {
tmpDir := t.TempDir()
build := config.Build{
PythonPackages: []string{"torch==2.5.1"},
}
config := config.Config{
Build: &build,
}
_, err := GenerateRequirements(tmpDir, &config)
require.ErrorContains(t, err, "python_packages is no longer supported, use python_requirements instead")
}

func TestPythonRequirements(t *testing.T) {
srcDir := t.TempDir()
reqFile := path.Join(srcDir, "requirements.txt")
err := os.WriteFile(reqFile, []byte("torch==2.5.1"), 0o644)
require.NoError(t, err)

build := config.Build{
PythonRequirements: reqFile,
}
config := config.Config{
Build: &build,
}
tmpDir := t.TempDir()
requirementsFile, err := GenerateRequirements(tmpDir, &config)
require.NoError(t, err)
require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
build:
python_version: "3.12"
python_packages:
- "torch==2.5.0"
- "beautifulsoup4==4.12.3"
python_requirements: requirements.txt
system_packages:
- "git"
predict: "predict.py:Predictor"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch==2.5.0
beautifulsoup4==4.12.3