diff --git a/pyproject.toml b/pyproject.toml index 2e4d0ef81..952d1aab3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ "sqlmodel==0.0.18", "temporalio==1.6.0", "tenacity==8.3.0", + "tomli>=2.2.1", "uv==0.4.10", "uvicorn==0.29.0", "virtualenv==20.27.0", diff --git a/tests/unit/test_parse.py b/tests/unit/test_parse.py index b8689b29b..4bf947a7d 100644 --- a/tests/unit/test_parse.py +++ b/tests/unit/test_parse.py @@ -1,5 +1,11 @@ +from pathlib import Path + from tracecat.expressions.functions import eval_jsonpath -from tracecat.parse import traverse_expressions, traverse_leaves +from tracecat.parse import ( + get_pyproject_toml_required_deps, + traverse_expressions, + traverse_leaves, +) def test_iter_dict_leaves(): @@ -84,3 +90,85 @@ def test_traverse_expressions(): assert list(traverse_expressions(data)) == [] data = {"test": {}, "list": []} assert list(traverse_expressions(data)) == [] + + +def test_parse_pyproject_toml_deps_basic(tmp_path: Path) -> None: + """Test parsing a basic pyproject.toml with only direct dependencies.""" + # Create a temporary pyproject.toml file + content = """ +[project] +dependencies = [ + "requests>=2.28.0", + "pydantic~=2.0", +] +""" + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text(content) + + # Test parsing + deps = get_pyproject_toml_required_deps(pyproject) + assert len(deps) == 2 + assert "requests>=2.28.0" in deps + assert "pydantic~=2.0" in deps + + +def test_parse_pyproject_toml_deps_with_optional(tmp_path: Path) -> None: + """Test parsing a pyproject.toml with both direct and optional dependencies. + Note: Optional dependencies are not included in the result.""" + content = """ +[project] +dependencies = [ + "requests>=2.28.0", +] +[project.optional-dependencies] +test = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", +] +dev = [ + "black>=23.0.0", +] +""" + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text(content) + + # Test parsing - should only include direct dependencies + deps = get_pyproject_toml_required_deps(pyproject) + assert len(deps) == 1 + assert "requests>=2.28.0" in deps + + +def test_parse_pyproject_toml_deps_empty_project(tmp_path: Path) -> None: + """Test parsing a pyproject.toml with no dependencies.""" + content = """ +[project] +name = "test-project" +version = "0.1.0" +""" + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text(content) + + # Test parsing + deps = get_pyproject_toml_required_deps(pyproject) + assert len(deps) == 0 + + +def test_parse_pyproject_toml_deps_missing_file(tmp_path: Path) -> None: + """Test handling of missing pyproject.toml file.""" + nonexistent_file = tmp_path / "nonexistent.toml" + deps = get_pyproject_toml_required_deps(nonexistent_file) + assert deps == [] + + +def test_parse_pyproject_toml_deps_invalid_toml(tmp_path: Path) -> None: + """Test handling of invalid TOML content.""" + content = """ +[project +invalid toml content +""" + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text(content) + + # Test parsing + deps = get_pyproject_toml_required_deps(pyproject) + assert deps == [] diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index d3f23ccee..766a5c457 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -3,6 +3,7 @@ import asyncio import traceback from collections.abc import Iterator, Mapping +from pathlib import Path from typing import Any, cast import ray @@ -33,7 +34,7 @@ ) from tracecat.git import prepare_git_url from tracecat.logger import logger -from tracecat.parse import traverse_leaves +from tracecat.parse import get_pyproject_toml_required_deps, traverse_leaves from tracecat.registry.actions.models import BoundRegistryAction from tracecat.registry.actions.service import RegistryActionsService from tracecat.secrets.common import apply_masks_object @@ -295,11 +296,39 @@ async def run_action_on_ray_cluster( additional_vars: dict[str, Any] = {} # Add git URL to pip dependencies if SHA is present + pip_deps = [] if ctx.git_url and ctx.git_url.ref: url = ctx.git_url.to_url() - additional_vars["pip"] = [url] + pip_deps.append(url) logger.trace("Adding git URL to runtime env", git_url=ctx.git_url, url=url) + # If we have a local registry, we need to add it to the runtime env + if config.TRACECAT__LOCAL_REPOSITORY_ENABLED: + local_repo_path = config.TRACECAT__LOCAL_REPOSITORY_CONTAINER_PATH + logger.info( + "Adding local repository and required dependencies to runtime env", + local_repo_path=local_repo_path, + ) + + # Try pyproject.toml first + pyproject_path = Path(local_repo_path) / "pyproject.toml" + if not pyproject_path.exists(): + logger.error( + "No pyproject.toml found in local repository", path=pyproject_path + ) + raise ValueError("No pyproject.toml found in local repository") + required_deps = await asyncio.to_thread( + get_pyproject_toml_required_deps, pyproject_path + ) + logger.debug( + "Found pyproject.toml with required dependencies", deps=required_deps + ) + pip_deps.extend([local_repo_path, *required_deps]) + + # Add pip dependencies to runtime env + if pip_deps: + additional_vars["pip"] = pip_deps + runtime_env = RuntimeEnv(env_vars=env_vars, **additional_vars) logger.info("Running action on ray cluster", runtime_env=runtime_env) diff --git a/tracecat/parse.py b/tracecat/parse.py index b8aec6c05..d57868952 100644 --- a/tracecat/parse.py +++ b/tracecat/parse.py @@ -1,9 +1,13 @@ import re from collections.abc import Iterator -from typing import Any +from pathlib import Path +from typing import Any, cast from urllib.parse import urlparse, urlunparse +import tomli + from tracecat.expressions import patterns +from tracecat.logger import logger def insert_obj_by_path( @@ -62,3 +66,17 @@ def safe_url(url: str) -> str: # Note that we do not recommend passing credentials in the url. cleaned_url = urlunparse((url_obj.scheme, url_obj.netloc, url_obj.path, "", "", "")) return cleaned_url + + +def get_pyproject_toml_required_deps(pyproject_path: Path) -> list[str]: + """Parse pyproject.toml to extract dependencies.""" + try: + with pyproject_path.open("rb") as f: + pyproject = tomli.load(f) + + # Get dependencies from pyproject.toml + project = pyproject.get("project", {}) + return cast(list[str], project.get("dependencies", [])) + except Exception as e: + logger.error("Error parsing pyproject.toml", error=e) + return [] diff --git a/tracecat/registry/repository.py b/tracecat/registry/repository.py index 7a9591089..1c62e6f04 100644 --- a/tracecat/registry/repository.py +++ b/tracecat/registry/repository.py @@ -248,16 +248,13 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None: raise RegistryError(f"Local git repository not found: {repo_path}") # Check that there's either pyproject.toml or setup.py - if ( - not repo_path.joinpath("pyproject.toml").exists() - and not repo_path.joinpath("setup.py").exists() - ): + if not repo_path.joinpath("pyproject.toml").exists(): # expand the path to the host path if host_path := config.TRACECAT__LOCAL_REPOSITORY_PATH: host_path = Path(host_path).expanduser() logger.debug("Host path", host_path=host_path) raise RegistryError( - "Local repository does not contain pyproject.toml or setup.py. " + "Local repository does not contain pyproject.toml. " "Please ensure TRACECAT__LOCAL_REPOSITORY_PATH points to a valid Python package." f"Host path: {host_path}" )