diff --git a/python/tach/cli.py b/python/tach/cli.py index 3b76da72..3f0c06d9 100644 --- a/python/tach/cli.py +++ b/python/tach/cli.py @@ -8,9 +8,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from tach import __version__, cache, icons +from tach import __version__, cache, extension, icons from tach import filesystem as fs -from tach.check_external import check_external from tach.colors import BCOLORS from tach.constants import CONFIG_FILE_NAME, TOOL_NAME from tach.errors import ( @@ -21,17 +20,7 @@ TachSetupError, TachVisibilityError, ) -from tach.extension import ( - ProjectConfig, - check, - check_computation_cache, - create_computation_cache_key, - detect_unused_dependencies, - format_diagnostics, - run_server, - serialize_diagnostics_json, - update_computation_cache, -) +from tach.extension import ProjectConfig from tach.filesystem import install_pre_commit from tach.logging import CallInfo, init_logging, logger from tach.modularity import export_report, upload_report_to_gauge @@ -458,7 +447,7 @@ def replay(self): def check_cache_for_action( project_root: Path, project_config: ProjectConfig, action: str ) -> CachedOutput: - cache_key = create_computation_cache_key( + cache_key = extension.create_computation_cache_key( project_root=str(project_root), source_roots=[ str(project_root / source_root) @@ -470,7 +459,7 @@ def check_cache_for_action( env_dependencies=project_config.cache.env_dependencies, backend=project_config.cache.backend, ) - cache_result = check_computation_cache( + cache_result = extension.check_computation_cache( project_root=str(project_root), cache_key=cache_key ) if cache_result: @@ -503,7 +492,7 @@ def tach_check( try: exact |= project_config.exact - diagnostics = check( + diagnostics = extension.check( project_root=project_root, project_config=project_config, dependencies=dependencies, @@ -514,21 +503,25 @@ def tach_check( if output_format == "json": try: - print(serialize_diagnostics_json(diagnostics, pretty_print=True)) + print( + extension.serialize_diagnostics_json(diagnostics, pretty_print=True) + ) except ValueError as e: json.dump({"error": str(e)}, sys.stdout) sys.exit(1 if has_errors else 0) if diagnostics: print( - format_diagnostics(project_root=project_root, diagnostics=diagnostics), + extension.format_diagnostics( + project_root=project_root, diagnostics=diagnostics + ), file=sys.stderr, ) exit_code = 1 if has_errors else 0 # If we're checking in exact mode, we want to verify that there are no unused dependencies if dependencies and exact: - unused_dependencies = detect_unused_dependencies( + unused_dependencies = extension.detect_unused_dependencies( project_root=project_root, project_config=project_config, exclude_paths=exclude_paths, @@ -569,7 +562,7 @@ def tach_check_external( }, ) try: - diagnostics = check_external( + diagnostics = extension.check_external_dependencies( project_root=project_root, project_config=project_config, exclude_paths=exclude_paths, @@ -577,7 +570,9 @@ def tach_check_external( if diagnostics: print( - format_diagnostics(project_root=project_root, diagnostics=diagnostics), + extension.format_diagnostics( + project_root=project_root, diagnostics=diagnostics + ), file=sys.stderr, ) @@ -907,7 +902,7 @@ def tach_test( ) if results.tests_ran_to_completion: - update_computation_cache( + extension.update_computation_cache( str(project_root), cache_key=cached_output.key, value=( @@ -1008,7 +1003,7 @@ def tach_server( sys.exit(1) try: - run_server(project_root, project_config) + extension.run_server(project_root, project_config) except TachSetupError as e: print(f"Failed to setup LSP server: {e}") sys.exit(1) diff --git a/python/tests/test_cli.py b/python/tests/test_cli.py index 6ed04640..10309b1f 100644 --- a/python/tests/test_cli.py +++ b/python/tests/test_cli.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pathlib from pathlib import Path from unittest.mock import Mock @@ -8,6 +9,8 @@ from tach import cli from tach.extension import ProjectConfig +_VALID_TACH_TOML = pathlib.Path(__file__).parent / "example" / "valid" / "tach.toml" + @pytest.fixture def mock_check(mocker) -> Mock: @@ -64,3 +67,15 @@ def test_execute_with_valid_exclude(capfd, mock_check, mock_project_config): assert sys_exit.value.code == 0 assert "✅" in captured.out assert "All modules validated!" in captured.out + + +def test_tach_server_with_config(mocker): + mock_run_server = mocker.patch("tach.extension.run_server", autospec=True) + cli.tach_server( + project_root=Path(), + project_config=ProjectConfig(), + config_path=_VALID_TACH_TOML, + ) + # Verify server was run with the custom config. + mock_run_server.assert_called_once() + assert "domain_four.py" in mock_run_server.call_args[0][1].exclude