diff --git a/src/nalgonda/custom_tools/build_directory_tree.py b/src/nalgonda/custom_tools/build_directory_tree.py index 67b8f2d9..87e180e6 100644 --- a/src/nalgonda/custom_tools/build_directory_tree.py +++ b/src/nalgonda/custom_tools/build_directory_tree.py @@ -1,46 +1,41 @@ -import os +from pathlib import Path from agency_swarm import BaseTool -from pydantic import Field +from pydantic import Field, field_validator + +from nalgonda.custom_tools.utils import check_directory_traversal class BuildDirectoryTree(BaseTool): """Print the structure of directories and files.""" - start_directory: str = Field( - default_factory=lambda: os.getcwd(), + start_directory: Path = Field( + default_factory=Path.cwd, description="The starting directory for the tree, defaults to the current working directory.", ) - file_extensions: list[str] | None = Field( - default_factory=lambda: None, - description="List of file extensions to include in the tree. If None, all files will be included.", + file_extensions: set[str] = Field( + default_factory=set, + description="Set of file extensions to include in the tree. If empty, all files will be included.", ) - def run(self) -> str: - """Run the tool.""" - self._validate_start_directory() - tree_str = self.print_tree() - return tree_str + _validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal) - def print_tree(self): - """Recursively print the tree of directories and files using os.walk.""" + def run(self) -> str: + """Recursively print the tree of directories and files using pathlib.""" tree_str = "" + start_path = self.start_directory.resolve() - for root, _, files in os.walk(self.start_directory, topdown=True): - level = root.replace(self.start_directory, "").count(os.sep) + def recurse(directory: Path, level: int = 0) -> None: + nonlocal tree_str indent = " " * 4 * level - tree_str += f"{indent}{os.path.basename(root)}\n" + tree_str += f"{indent}{directory.name}\n" sub_indent = " " * 4 * (level + 1) - for f in files: - if not self.file_extensions or f.endswith(tuple(self.file_extensions)): - tree_str += f"{sub_indent}{f}\n" + for path in sorted(directory.iterdir()): + if path.is_dir(): + recurse(path, level + 1) + elif path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions): + tree_str += f"{sub_indent}{path.name}\n" + recurse(start_path) return tree_str - - def _validate_start_directory(self): - """Do not allow directory traversal.""" - if ".." in self.start_directory or ( - self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp") - ): - raise ValueError("Directory traversal is not allowed.") diff --git a/src/nalgonda/custom_tools/print_all_files_in_directory.py b/src/nalgonda/custom_tools/print_all_files_in_directory.py index a4ebb278..4a31292e 100644 --- a/src/nalgonda/custom_tools/print_all_files_in_directory.py +++ b/src/nalgonda/custom_tools/print_all_files_in_directory.py @@ -1,44 +1,42 @@ -import os +from pathlib import Path from agency_swarm import BaseTool -from pydantic import Field +from pydantic import Field, field_validator + +from nalgonda.custom_tools.utils import check_directory_traversal class PrintAllFilesInDirectory(BaseTool): """Print the contents of all files in a start_directory recursively.""" - start_directory: str = Field( - default_factory=lambda: os.getcwd(), + start_directory: Path = Field( + default_factory=Path.cwd, description="Directory to search for Python files, by default the current working directory.", ) - file_extensions: list[str] | None = Field( - default_factory=lambda: None, - description="List of file extensions to include in the output. If None, all files will be included.", + file_extensions: set[str] = Field( + default_factory=set, + description="Set of file extensions to include in the output. If empty, all files will be included.", ) - def run(self) -> str: - """Run the tool.""" - self._validate_start_directory() + _validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal) + def run(self) -> str: + """ + Recursively searches for files within `start_directory` and compiles their contents into a single string. + """ output = [] - for root, _, files in os.walk(self.start_directory, topdown=True): - for file in files: - if not self.file_extensions or file.endswith(tuple(self.file_extensions)): - file_path = os.path.join(root, file) - output.append(f"{file_path}:\n```\n{self.read_file(file_path)}\n```\n") + start_path = self.start_directory.resolve() + + for path in start_path.rglob("*"): + if path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions): + output.append(f"{str(path)}:\n```\n{self.read_file(path)}\n```\n") + return "\n".join(output) @staticmethod - def read_file(file_path): + def read_file(file_path: Path): + """Read and return the contents of a file.""" try: - with open(file_path, "r") as file: - return file.read() + return file_path.read_text() except IOError as e: return f"Error reading file {file_path}: {e}" - - def _validate_start_directory(self): - """Do not allow directory traversal.""" - if ".." in self.start_directory or ( - self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp") - ): - raise ValueError("Directory traversal is not allowed.") diff --git a/tests/custom_tools/test_build_directory_tree.py b/tests/custom_tools/test_build_directory_tree.py index be263a0a..c59a0853 100644 --- a/tests/custom_tools/test_build_directory_tree.py +++ b/tests/custom_tools/test_build_directory_tree.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path from nalgonda.custom_tools import BuildDirectoryTree @@ -7,7 +7,7 @@ def test_build_directory_tree_with_py_extension(temp_dir): """ Test if BuildDirectoryTree correctly lists only .py files in the directory tree. """ - bdt = BuildDirectoryTree(start_directory=str(temp_dir), file_extensions=[".py"]) + bdt = BuildDirectoryTree(start_directory=temp_dir, file_extensions={".py"}) expected_output = f"{temp_dir.name}\n sub\n test.py\n" assert bdt.run() == expected_output @@ -16,7 +16,7 @@ def test_build_directory_tree_with_multiple_extensions(temp_dir): """ Test if BuildDirectoryTree lists files with multiple specified extensions. """ - bdt = BuildDirectoryTree(start_directory=str(temp_dir), file_extensions=[".py", ".txt"]) + bdt = BuildDirectoryTree(start_directory=temp_dir, file_extensions={".py", ".txt"}) expected_output = { f"{temp_dir.name}", " sub", @@ -32,5 +32,7 @@ def test_build_directory_tree_default_settings(): Test if BuildDirectoryTree uses the correct default settings. """ bdt = BuildDirectoryTree() - assert bdt.start_directory == os.getcwd() - assert bdt.file_extensions is None + assert bdt.start_directory == Path.cwd() + assert bdt.file_extensions == set() + + diff --git a/tests/custom_tools/test_print_all_files_in_directory.py b/tests/custom_tools/test_print_all_files_in_directory.py index 9cf9641a..14bddfaf 100644 --- a/tests/custom_tools/test_print_all_files_in_directory.py +++ b/tests/custom_tools/test_print_all_files_in_directory.py @@ -1,5 +1,3 @@ -import os - from nalgonda.custom_tools import PrintAllFilesInDirectory @@ -7,7 +5,7 @@ def test_print_all_files_no_extension_filter(temp_dir): """ Test if PrintAllFilesInDirectory correctly prints contents of all files when no file extension filter is applied. """ - pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir)) + pafid = PrintAllFilesInDirectory(start_directory=temp_dir) expected_output = { f"{temp_dir}/sub/test.py:\n```\nprint('hello')\n```", f"{temp_dir}/sub/test.txt:\n```\nhello world\n```", @@ -20,8 +18,8 @@ def test_print_all_files_with_py_extension(temp_dir): """ Test if PrintAllFilesInDirectory correctly prints contents of .py files only. """ - pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir), file_extensions=[".py"]) - expected_output = f"{os.path.join(temp_dir, 'sub', 'test.py')}:\n```\nprint('hello')\n```\n" + pafid = PrintAllFilesInDirectory(start_directory=temp_dir, file_extensions={".py"}) + expected_output = f"{temp_dir.joinpath('sub', 'test.py')}:\n```\nprint('hello')\n```\n" assert pafid.run() == expected_output @@ -29,8 +27,8 @@ def test_print_all_files_with_txt_extension(temp_dir): """ Test if PrintAllFilesInDirectory correctly prints contents of .txt files only. """ - pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir), file_extensions=[".txt"]) - expected_output = f"{os.path.join(temp_dir, 'sub', 'test.txt')}:\n```\nhello world\n```\n" + pafid = PrintAllFilesInDirectory(start_directory=temp_dir, file_extensions={".txt"}) + expected_output = f"{temp_dir.joinpath('sub', 'test.txt')}:\n```\nhello world\n```\n" assert pafid.run() == expected_output @@ -39,12 +37,11 @@ def test_print_all_files_error_reading_file(temp_dir): Test if PrintAllFilesInDirectory handles errors while reading a file. """ # Create an unreadable file - unreadable_file = os.path.join(temp_dir, "unreadable_file.txt") - with open(unreadable_file, "w") as f: - f.write("content") - os.chmod(unreadable_file, 0o000) # make the file unreadable + unreadable_file = temp_dir.joinpath("unreadable_file.txt") + unreadable_file.write_text("content") + unreadable_file.chmod(0o000) # make the file unreadable - pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir), file_extensions=[".txt"]) + pafid = PrintAllFilesInDirectory(start_directory=temp_dir, file_extensions={".txt"}) assert "Error reading file" in pafid.run() - os.chmod(unreadable_file, 0o644) # reset file permissions for cleanup + unreadable_file.chmod(0o644) # reset file permissions for cleanup