diff --git a/python/tests/test_report.py b/python/tests/test_report.py index be3ccdba..2257579c 100644 --- a/python/tests/test_report.py +++ b/python/tests/test_report.py @@ -7,74 +7,123 @@ from tach.errors import TachError from tach.extension import ProjectConfig +from tach.parsing.config import parse_project_config from tach.report import report @pytest.fixture -def mock_project_config() -> ProjectConfig: +def empty_config() -> ProjectConfig: return ProjectConfig() @pytest.fixture -def mock_cwd(tmp_path): - try: - original_path = os.getcwd() - os.chdir(tmp_path) - yield tmp_path - finally: - os.chdir(original_path) +def tmp_project(tmp_path): + """Create a temporary project directory structure""" + project_dir = tmp_path / "project" + project_dir.mkdir() + original_cwd = os.getcwd() + os.chdir(project_dir) + yield project_dir + os.chdir(original_cwd) -# The code assumes that the cwd is within the project root -# due to pre-condition checks +@pytest.fixture +def example_valid_dir(example_dir): + original_cwd = os.getcwd() + os.chdir(example_dir / "valid") + yield example_dir / "valid" + os.chdir(original_cwd) -def test_valid_path(mock_project_config, mock_cwd): - mock_path = mock_cwd / "test.py" - mock_path.touch() +def test_valid_file(empty_config, tmp_project): + test_file = tmp_project / "test.py" + test_file.touch() result = report( - project_root=mock_cwd, + project_root=tmp_project, path=Path("test.py"), - project_config=mock_project_config, + project_config=empty_config, ) assert result -def test_valid_dir(mock_project_config, mock_cwd): - mock_path = mock_cwd / "test" - mock_path.mkdir() +def test_valid_directory(empty_config, tmp_project): + test_dir = tmp_project / "test" + test_dir.mkdir() result = report( - project_root=mock_cwd, + project_root=tmp_project, path=Path("test"), - project_config=mock_project_config, + project_config=empty_config, ) assert result -def test_valid_dir_trailing_slash(mock_project_config, mock_cwd): - mock_path = mock_cwd / "test" - mock_path.mkdir() +def test_valid_directory_trailing_slash(empty_config, tmp_project): + test_dir = tmp_project / "test" + test_dir.mkdir() result = report( - project_root=mock_cwd, + project_root=tmp_project, path=Path("test/"), - project_config=mock_project_config, + project_config=empty_config, ) assert result -def test_invalid_root(mock_project_config, mock_cwd): +def test_invalid_project_root(empty_config, tmp_project): with pytest.raises(TachError): report( project_root=Path("Invalid!!"), path=Path("."), - project_config=mock_project_config, + project_config=empty_config, ) -def test_invalid_path(mock_project_config, mock_cwd): +def test_invalid_path(empty_config, tmp_project): with pytest.raises(TachError): report( - project_root=mock_cwd, + project_root=tmp_project, path=Path("Invalid!!"), - project_config=mock_project_config, + project_config=empty_config, ) + + +def test_report_valid_domain_one(example_valid_dir): + project_config = parse_project_config(example_valid_dir) + result = report( + project_root=example_valid_dir, + path=Path("domain_one"), + project_config=project_config, + ) + + dependencies, usages = result.split("Usages of 'domain_one'") + assert "domain_two.x" in dependencies + assert "domain_one.x" in usages + + +def test_report_valid_domain_two(example_valid_dir): + project_config = parse_project_config(example_valid_dir) + result = report( + project_root=example_valid_dir, + path=Path("domain_two"), + project_config=project_config, + ) + + dependencies, usages = result.split("Usages of 'domain_two'") + assert "domain_four.ok" in dependencies + assert "domain_three.x" in dependencies + assert "domain_two.x" in usages + + +def test_report_raw_output(example_valid_dir): + project_config = parse_project_config(example_valid_dir) + result = report( + project_root=example_valid_dir, + path=Path("domain_one"), + project_config=project_config, + raw=True, + ) + assert result.strip() == ( + """# Module Dependencies +domain_two +# Module Usages +.""" + ) diff --git a/src/commands/report.rs b/src/commands/report.rs index 7dfb3735..b23b00ab 100644 --- a/src/commands/report.rs +++ b/src/commands/report.rs @@ -3,6 +3,8 @@ use std::fmt::Debug; use std::io; use std::path::{Path, PathBuf}; +use rayon::prelude::*; + use thiserror::Error; use crate::colors::*; @@ -14,6 +16,7 @@ use crate::filesystem::{ file_to_module_path, validate_project_modules, walk_pyfiles, FileSystemError, }; use crate::imports::{get_project_imports, ImportParseError, NormalizedImport}; +use crate::interrupt::check_interrupt; use crate::modules::{build_module_tree, error::ModuleTreeError}; struct Dependency { @@ -36,6 +39,8 @@ pub enum ReportCreationError { NothingToReport, #[error("Module tree build error: {0}")] ModuleTree(#[from] ModuleTreeError), + #[error("Operation interrupted")] + Interrupted, } pub type Result = std::result::Result; @@ -217,6 +222,9 @@ pub fn create_dependency_report( &source_roots, project_config.all_modules().cloned().collect(), ); + + check_interrupt().map_err(|_| ReportCreationError::Interrupted)?; + let module_tree = build_module_tree( &source_roots, &valid_modules, @@ -232,91 +240,120 @@ pub fn create_dependency_report( let mut report = DependencyReport::new(path.display().to_string()); - for pyfile in walk_pyfiles(project_root.to_str().unwrap()) { - let absolute_pyfile = project_root.join(&pyfile); - let file_module_path = file_to_module_path(&source_roots, &absolute_pyfile)?; - let file_module = module_tree.find_nearest(&file_module_path); - - match get_project_imports( - &source_roots, - &absolute_pyfile, - project_config.ignore_type_checking_imports, - project_config.include_string_imports, - ) { - Ok(project_imports) => { - let is_in_target_path = is_module_prefix(&module_path, &file_module_path); - - if is_in_target_path && !skip_dependencies { - // Add external dependencies - report.dependencies.extend( - project_imports - .imports - .into_iter() - .filter_map(|import| { - if let Some(import_module) = - module_tree.find_nearest(&import.module_path) - { - if import_module == target_module { - return None; // Skip internal imports - } - - // Check if module is in include list - include_dependency_modules.as_ref().map_or( - Some((import.clone(), import_module.clone())), - |included_modules| { - if included_modules.contains(&import_module.full_path) { - Some((import.clone(), import_module.clone())) - } else { - None + for source_root in &source_roots { + check_interrupt().map_err(|_| ReportCreationError::Interrupted)?; + + let source_root_results: Vec<_> = walk_pyfiles(&source_root.display().to_string()) + .par_bridge() + .filter_map(|pyfile| { + if check_interrupt().is_err() { + return None; + } + + let absolute_pyfile = source_root.join(&pyfile); + let file_module_path = match file_to_module_path(&source_roots, &absolute_pyfile) { + Ok(path) => path, + Err(_) => return None, + }; + let file_module = module_tree.find_nearest(&file_module_path); + + match get_project_imports( + &source_roots, + &absolute_pyfile, + project_config.ignore_type_checking_imports, + project_config.include_string_imports, + ) { + Ok(project_imports) => { + let is_in_target_path = is_module_prefix(&module_path, &file_module_path); + let mut dependencies = Vec::new(); + let mut usages = Vec::new(); + + if is_in_target_path && !skip_dependencies { + // Add dependencies + dependencies.extend( + project_imports + .imports + .iter() + .filter_map(|import| { + if let Some(import_module) = + module_tree.find_nearest(&import.module_path) + { + if import_module == target_module { + return None; } - }, - ) - } else { - None // Skip imports that don't match any module - } - }) - .map(|(import, import_module)| Dependency { - file_path: pyfile.clone(), - absolute_path: absolute_pyfile.clone(), - import, - source_module: target_module.full_path.clone(), - target_module: import_module.full_path.clone(), - }), - ); - } else if !is_in_target_path && !skip_usages { - // Add external usages - report.usages.extend( - project_imports - .imports - .into_iter() - .filter(|import| { - if !is_module_prefix(&module_path, &import.module_path) { - return false; // Skip imports not targeting our path - } - - // Check if using module is in include list - file_module.as_ref().map_or(false, |m| { - include_usage_modules - .as_ref() - .map_or(true, |included_modules| { - included_modules.contains(&m.full_path) + include_dependency_modules.as_ref().map_or( + Some((import.clone(), import_module.clone())), + |included_modules| { + if included_modules + .contains(&import_module.full_path) + { + Some(( + import.clone(), + import_module.clone(), + )) + } else { + None + } + }, + ) + } else { + None + } + }) + .map(|(import, import_module)| Dependency { + file_path: pyfile.clone(), + absolute_path: absolute_pyfile.clone(), + import, + source_module: target_module.full_path.clone(), + target_module: import_module.full_path.clone(), + }), + ); + } else if !is_in_target_path && !skip_usages { + // Add usages + usages.extend( + project_imports + .imports + .iter() + .filter(|import| { + if !is_module_prefix(&module_path, &import.module_path) { + return false; + } + file_module.as_ref().map_or(false, |m| { + include_usage_modules.as_ref().map_or( + true, + |included_modules| { + included_modules.contains(&m.full_path) + }, + ) }) - }) - }) - .map(|import| Dependency { - file_path: pyfile.clone(), - absolute_path: absolute_pyfile.clone(), - import, - source_module: file_module - .as_ref() - .map_or(String::new(), |m| m.full_path.clone()), - target_module: target_module.full_path.clone(), - }), - ); + }) + .map(|import| Dependency { + file_path: pyfile.clone(), + absolute_path: absolute_pyfile.clone(), + import: import.clone(), + source_module: file_module + .as_ref() + .map_or(String::new(), |m| m.full_path.clone()), + target_module: target_module.full_path.clone(), + }), + ); + } + + Some((dependencies, usages, None)) + } + Err(err) => Some((Vec::new(), Vec::new(), Some(err.to_string()))), } - } - Err(err) => { - report.warnings.push(err.to_string()); + }) + .collect(); + + check_interrupt().map_err(|_| ReportCreationError::Interrupted)?; + + // Combine results + for (dependencies, usages, warning) in source_root_results { + report.dependencies.extend(dependencies); + report.usages.extend(usages); + if let Some(warning) = warning { + report.warnings.push(warning); } } }