Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
max-muoto committed Feb 3, 2025
1 parent 8fa9b37 commit 2b2ba63
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
41 changes: 18 additions & 23 deletions python/tach/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

from tach import __version__, cache, icons
from tach import __version__, cache, extension, icons
from tach import filesystem as fs
from tach.check_external import check_external
from tach.colors import BCOLORS
from tach.constants import CONFIG_FILE_NAME, TOOL_NAME
from tach.errors import (
Expand All @@ -21,17 +20,7 @@
TachSetupError,
TachVisibilityError,
)
from tach.extension import (
ProjectConfig,
check,
check_computation_cache,
create_computation_cache_key,
detect_unused_dependencies,
format_diagnostics,
run_server,
serialize_diagnostics_json,
update_computation_cache,
)
from tach.extension import ProjectConfig
from tach.filesystem import install_pre_commit
from tach.logging import CallInfo, init_logging, logger
from tach.modularity import export_report, upload_report_to_gauge
Expand Down Expand Up @@ -458,7 +447,7 @@ def replay(self):
def check_cache_for_action(
project_root: Path, project_config: ProjectConfig, action: str
) -> CachedOutput:
cache_key = create_computation_cache_key(
cache_key = extension.create_computation_cache_key(
project_root=str(project_root),
source_roots=[
str(project_root / source_root)
Expand All @@ -470,7 +459,7 @@ def check_cache_for_action(
env_dependencies=project_config.cache.env_dependencies,
backend=project_config.cache.backend,
)
cache_result = check_computation_cache(
cache_result = extension.check_computation_cache(
project_root=str(project_root), cache_key=cache_key
)
if cache_result:
Expand Down Expand Up @@ -503,7 +492,7 @@ def tach_check(
try:
exact |= project_config.exact

diagnostics = check(
diagnostics = extension.check(
project_root=project_root,
project_config=project_config,
dependencies=dependencies,
Expand All @@ -514,21 +503,25 @@ def tach_check(

if output_format == "json":
try:
print(serialize_diagnostics_json(diagnostics, pretty_print=True))
print(
extension.serialize_diagnostics_json(diagnostics, pretty_print=True)
)
except ValueError as e:
json.dump({"error": str(e)}, sys.stdout)
sys.exit(1 if has_errors else 0)

if diagnostics:
print(
format_diagnostics(project_root=project_root, diagnostics=diagnostics),
extension.format_diagnostics(
project_root=project_root, diagnostics=diagnostics
),
file=sys.stderr,
)
exit_code = 1 if has_errors else 0

# If we're checking in exact mode, we want to verify that there are no unused dependencies
if dependencies and exact:
unused_dependencies = detect_unused_dependencies(
unused_dependencies = extension.detect_unused_dependencies(
project_root=project_root,
project_config=project_config,
exclude_paths=exclude_paths,
Expand Down Expand Up @@ -569,15 +562,17 @@ def tach_check_external(
},
)
try:
diagnostics = check_external(
diagnostics = extension.check_external_dependencies(
project_root=project_root,
project_config=project_config,
exclude_paths=exclude_paths,
)

if diagnostics:
print(
format_diagnostics(project_root=project_root, diagnostics=diagnostics),
extension.format_diagnostics(
project_root=project_root, diagnostics=diagnostics
),
file=sys.stderr,
)

Expand Down Expand Up @@ -907,7 +902,7 @@ def tach_test(
)

if results.tests_ran_to_completion:
update_computation_cache(
extension.update_computation_cache(
str(project_root),
cache_key=cached_output.key,
value=(
Expand Down Expand Up @@ -1008,7 +1003,7 @@ def tach_server(
sys.exit(1)

try:
run_server(project_root, project_config)
extension.run_server(project_root, project_config)
except TachSetupError as e:
print(f"Failed to setup LSP server: {e}")
sys.exit(1)
Expand Down
15 changes: 15 additions & 0 deletions python/tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import pathlib
from pathlib import Path
from unittest.mock import Mock

Expand All @@ -8,6 +9,8 @@
from tach import cli
from tach.extension import ProjectConfig

_VALID_TACH_TOML = pathlib.Path(__file__).parent / "example" / "valid" / "tach.toml"


@pytest.fixture
def mock_check(mocker) -> Mock:
Expand Down Expand Up @@ -64,3 +67,15 @@ def test_execute_with_valid_exclude(capfd, mock_check, mock_project_config):
assert sys_exit.value.code == 0
assert "✅" in captured.out
assert "All modules validated!" in captured.out


def test_tach_server_with_config(mocker):
mock_run_server = mocker.patch("tach.extension.run_server", autospec=True)
cli.tach_server(
project_root=Path(),
project_config=ProjectConfig(),
config_path=_VALID_TACH_TOML,
)
# Verify server was run with the custom config.
mock_run_server.assert_called_once()
assert "domain_four.py" in mock_run_server.call_args[0][1].exclude

0 comments on commit 2b2ba63

Please sign in to comment.