diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index d3f23ccee..766a5c457 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -3,6 +3,7 @@ import asyncio import traceback from collections.abc import Iterator, Mapping +from pathlib import Path from typing import Any, cast import ray @@ -33,7 +34,7 @@ ) from tracecat.git import prepare_git_url from tracecat.logger import logger -from tracecat.parse import traverse_leaves +from tracecat.parse import get_pyproject_toml_required_deps, traverse_leaves from tracecat.registry.actions.models import BoundRegistryAction from tracecat.registry.actions.service import RegistryActionsService from tracecat.secrets.common import apply_masks_object @@ -295,11 +296,39 @@ async def run_action_on_ray_cluster( additional_vars: dict[str, Any] = {} # Add git URL to pip dependencies if SHA is present + pip_deps = [] if ctx.git_url and ctx.git_url.ref: url = ctx.git_url.to_url() - additional_vars["pip"] = [url] + pip_deps.append(url) logger.trace("Adding git URL to runtime env", git_url=ctx.git_url, url=url) + # If we have a local registry, we need to add it to the runtime env + if config.TRACECAT__LOCAL_REPOSITORY_ENABLED: + local_repo_path = config.TRACECAT__LOCAL_REPOSITORY_CONTAINER_PATH + logger.info( + "Adding local repository and required dependencies to runtime env", + local_repo_path=local_repo_path, + ) + + # Try pyproject.toml first + pyproject_path = Path(local_repo_path) / "pyproject.toml" + if not pyproject_path.exists(): + logger.error( + "No pyproject.toml found in local repository", path=pyproject_path + ) + raise ValueError("No pyproject.toml found in local repository") + required_deps = await asyncio.to_thread( + get_pyproject_toml_required_deps, pyproject_path + ) + logger.debug( + "Found pyproject.toml with required dependencies", deps=required_deps + ) + pip_deps.extend([local_repo_path, *required_deps]) + + # Add pip dependencies to runtime env + if pip_deps: + additional_vars["pip"] = pip_deps + runtime_env = RuntimeEnv(env_vars=env_vars, **additional_vars) logger.info("Running action on ray cluster", runtime_env=runtime_env)