Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(app): Install local repository with third-party dependencies in executor #824

Merged
merged 4 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dependencies = [
"sqlmodel==0.0.18",
"temporalio==1.6.0",
"tenacity==8.3.0",
"tomli>=2.2.1",
"uv==0.4.10",
"uvicorn==0.29.0",
"virtualenv==20.27.0",
Expand Down
90 changes: 89 additions & 1 deletion tests/unit/test_parse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from pathlib import Path

from tracecat.expressions.functions import eval_jsonpath
from tracecat.parse import traverse_expressions, traverse_leaves
from tracecat.parse import (
get_pyproject_toml_required_deps,
traverse_expressions,
traverse_leaves,
)


def test_iter_dict_leaves():
Expand Down Expand Up @@ -84,3 +90,85 @@ def test_traverse_expressions():
assert list(traverse_expressions(data)) == []
data = {"test": {}, "list": []}
assert list(traverse_expressions(data)) == []


def test_parse_pyproject_toml_deps_basic(tmp_path: Path) -> None:
"""Test parsing a basic pyproject.toml with only direct dependencies."""
# Create a temporary pyproject.toml file
content = """
[project]
dependencies = [
"requests>=2.28.0",
"pydantic~=2.0",
]
"""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(content)

# Test parsing
deps = get_pyproject_toml_required_deps(pyproject)
assert len(deps) == 2
assert "requests>=2.28.0" in deps
assert "pydantic~=2.0" in deps


def test_parse_pyproject_toml_deps_with_optional(tmp_path: Path) -> None:
"""Test parsing a pyproject.toml with both direct and optional dependencies.
Note: Optional dependencies are not included in the result."""
content = """
[project]
dependencies = [
"requests>=2.28.0",
]
[project.optional-dependencies]
test = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
]
dev = [
"black>=23.0.0",
]
"""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(content)

# Test parsing - should only include direct dependencies
deps = get_pyproject_toml_required_deps(pyproject)
assert len(deps) == 1
assert "requests>=2.28.0" in deps


def test_parse_pyproject_toml_deps_empty_project(tmp_path: Path) -> None:
"""Test parsing a pyproject.toml with no dependencies."""
content = """
[project]
name = "test-project"
version = "0.1.0"
"""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(content)

# Test parsing
deps = get_pyproject_toml_required_deps(pyproject)
assert len(deps) == 0


def test_parse_pyproject_toml_deps_missing_file(tmp_path: Path) -> None:
"""Test handling of missing pyproject.toml file."""
nonexistent_file = tmp_path / "nonexistent.toml"
deps = get_pyproject_toml_required_deps(nonexistent_file)
assert deps == []


def test_parse_pyproject_toml_deps_invalid_toml(tmp_path: Path) -> None:
"""Test handling of invalid TOML content."""
content = """
[project
invalid toml content
"""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(content)

# Test parsing
deps = get_pyproject_toml_required_deps(pyproject)
assert deps == []
33 changes: 31 additions & 2 deletions tracecat/executor/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import traceback
from collections.abc import Iterator, Mapping
from pathlib import Path
from typing import Any, cast

import ray
Expand Down Expand Up @@ -33,7 +34,7 @@
)
from tracecat.git import prepare_git_url
from tracecat.logger import logger
from tracecat.parse import traverse_leaves
from tracecat.parse import get_pyproject_toml_required_deps, traverse_leaves
from tracecat.registry.actions.models import BoundRegistryAction
from tracecat.registry.actions.service import RegistryActionsService
from tracecat.secrets.common import apply_masks_object
Expand Down Expand Up @@ -295,11 +296,39 @@ async def run_action_on_ray_cluster(
additional_vars: dict[str, Any] = {}

# Add git URL to pip dependencies if SHA is present
pip_deps = []
if ctx.git_url and ctx.git_url.ref:
url = ctx.git_url.to_url()
additional_vars["pip"] = [url]
pip_deps.append(url)
logger.trace("Adding git URL to runtime env", git_url=ctx.git_url, url=url)

# If we have a local registry, we need to add it to the runtime env
if config.TRACECAT__LOCAL_REPOSITORY_ENABLED:
local_repo_path = config.TRACECAT__LOCAL_REPOSITORY_CONTAINER_PATH
logger.info(
"Adding local repository and required dependencies to runtime env",
local_repo_path=local_repo_path,
)

# Try pyproject.toml first
pyproject_path = Path(local_repo_path) / "pyproject.toml"
if not pyproject_path.exists():
logger.error(
"No pyproject.toml found in local repository", path=pyproject_path
)
raise ValueError("No pyproject.toml found in local repository")
required_deps = await asyncio.to_thread(
get_pyproject_toml_required_deps, pyproject_path
)
logger.debug(
"Found pyproject.toml with required dependencies", deps=required_deps
)
pip_deps.extend([local_repo_path, *required_deps])

# Add pip dependencies to runtime env
if pip_deps:
additional_vars["pip"] = pip_deps

runtime_env = RuntimeEnv(env_vars=env_vars, **additional_vars)

logger.info("Running action on ray cluster", runtime_env=runtime_env)
Expand Down
20 changes: 19 additions & 1 deletion tracecat/parse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import re
from collections.abc import Iterator
from typing import Any
from pathlib import Path
from typing import Any, cast
from urllib.parse import urlparse, urlunparse

import tomli

from tracecat.expressions import patterns
from tracecat.logger import logger


def insert_obj_by_path(
Expand Down Expand Up @@ -62,3 +66,17 @@ def safe_url(url: str) -> str:
# Note that we do not recommend passing credentials in the url.
cleaned_url = urlunparse((url_obj.scheme, url_obj.netloc, url_obj.path, "", "", ""))
return cleaned_url


def get_pyproject_toml_required_deps(pyproject_path: Path) -> list[str]:
"""Parse pyproject.toml to extract dependencies."""
try:
with pyproject_path.open("rb") as f:
pyproject = tomli.load(f)

# Get dependencies from pyproject.toml
project = pyproject.get("project", {})
return cast(list[str], project.get("dependencies", []))
except Exception as e:
logger.error("Error parsing pyproject.toml", error=e)
return []
7 changes: 2 additions & 5 deletions tracecat/registry/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,13 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None:
raise RegistryError(f"Local git repository not found: {repo_path}")

# Check that there's either pyproject.toml or setup.py
if (
not repo_path.joinpath("pyproject.toml").exists()
and not repo_path.joinpath("setup.py").exists()
):
if not repo_path.joinpath("pyproject.toml").exists():
# expand the path to the host path
if host_path := config.TRACECAT__LOCAL_REPOSITORY_PATH:
host_path = Path(host_path).expanduser()
logger.debug("Host path", host_path=host_path)
raise RegistryError(
"Local repository does not contain pyproject.toml or setup.py. "
"Local repository does not contain pyproject.toml. "
"Please ensure TRACECAT__LOCAL_REPOSITORY_PATH points to a valid Python package."
f"Host path: {host_path}"
)
Expand Down