-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: use field_validator; use pathlib
- Loading branch information
1 parent
0aa926a
commit 6a0a47c
Showing
4 changed files
with
62 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,41 @@ | ||
import os | ||
from pathlib import Path | ||
|
||
from agency_swarm import BaseTool | ||
from pydantic import Field | ||
from pydantic import Field, field_validator | ||
|
||
from nalgonda.custom_tools.utils import check_directory_traversal | ||
|
||
|
||
class BuildDirectoryTree(BaseTool): | ||
"""Print the structure of directories and files.""" | ||
|
||
start_directory: str = Field( | ||
default_factory=lambda: os.getcwd(), | ||
start_directory: Path = Field( | ||
default_factory=Path.cwd, | ||
description="The starting directory for the tree, defaults to the current working directory.", | ||
) | ||
file_extensions: list[str] | None = Field( | ||
default_factory=lambda: None, | ||
description="List of file extensions to include in the tree. If None, all files will be included.", | ||
file_extensions: set[str] = Field( | ||
default_factory=set, | ||
description="Set of file extensions to include in the tree. If empty, all files will be included.", | ||
) | ||
|
||
def run(self) -> str: | ||
"""Run the tool.""" | ||
self._validate_start_directory() | ||
tree_str = self.print_tree() | ||
return tree_str | ||
_validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal) | ||
|
||
def print_tree(self): | ||
"""Recursively print the tree of directories and files using os.walk.""" | ||
def run(self) -> str: | ||
"""Recursively print the tree of directories and files using pathlib.""" | ||
tree_str = "" | ||
start_path = self.start_directory.resolve() | ||
|
||
for root, _, files in os.walk(self.start_directory, topdown=True): | ||
level = root.replace(self.start_directory, "").count(os.sep) | ||
def recurse(directory: Path, level: int = 0) -> None: | ||
nonlocal tree_str | ||
indent = " " * 4 * level | ||
tree_str += f"{indent}{os.path.basename(root)}\n" | ||
tree_str += f"{indent}{directory.name}\n" | ||
sub_indent = " " * 4 * (level + 1) | ||
|
||
for f in files: | ||
if not self.file_extensions or f.endswith(tuple(self.file_extensions)): | ||
tree_str += f"{sub_indent}{f}\n" | ||
for path in sorted(directory.iterdir()): | ||
if path.is_dir(): | ||
recurse(path, level + 1) | ||
elif path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions): | ||
tree_str += f"{sub_indent}{path.name}\n" | ||
|
||
recurse(start_path) | ||
return tree_str | ||
|
||
def _validate_start_directory(self): | ||
"""Do not allow directory traversal.""" | ||
if ".." in self.start_directory or ( | ||
self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp") | ||
): | ||
raise ValueError("Directory traversal is not allowed.") |
48 changes: 23 additions & 25 deletions
48
src/nalgonda/custom_tools/print_all_files_in_directory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,42 @@ | ||
import os | ||
from pathlib import Path | ||
|
||
from agency_swarm import BaseTool | ||
from pydantic import Field | ||
from pydantic import Field, field_validator | ||
|
||
from nalgonda.custom_tools.utils import check_directory_traversal | ||
|
||
|
||
class PrintAllFilesInDirectory(BaseTool): | ||
"""Print the contents of all files in a start_directory recursively.""" | ||
|
||
start_directory: str = Field( | ||
default_factory=lambda: os.getcwd(), | ||
start_directory: Path = Field( | ||
default_factory=Path.cwd, | ||
description="Directory to search for Python files, by default the current working directory.", | ||
) | ||
file_extensions: list[str] | None = Field( | ||
default_factory=lambda: None, | ||
description="List of file extensions to include in the output. If None, all files will be included.", | ||
file_extensions: set[str] = Field( | ||
default_factory=set, | ||
description="Set of file extensions to include in the output. If empty, all files will be included.", | ||
) | ||
|
||
def run(self) -> str: | ||
"""Run the tool.""" | ||
self._validate_start_directory() | ||
_validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal) | ||
|
||
def run(self) -> str: | ||
""" | ||
Recursively searches for files within `start_directory` and compiles their contents into a single string. | ||
""" | ||
output = [] | ||
for root, _, files in os.walk(self.start_directory, topdown=True): | ||
for file in files: | ||
if not self.file_extensions or file.endswith(tuple(self.file_extensions)): | ||
file_path = os.path.join(root, file) | ||
output.append(f"{file_path}:\n```\n{self.read_file(file_path)}\n```\n") | ||
start_path = self.start_directory.resolve() | ||
|
||
for path in start_path.rglob("*"): | ||
if path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions): | ||
output.append(f"{str(path)}:\n```\n{self.read_file(path)}\n```\n") | ||
|
||
return "\n".join(output) | ||
|
||
@staticmethod | ||
def read_file(file_path): | ||
def read_file(file_path: Path): | ||
"""Read and return the contents of a file.""" | ||
try: | ||
with open(file_path, "r") as file: | ||
return file.read() | ||
return file_path.read_text() | ||
except IOError as e: | ||
return f"Error reading file {file_path}: {e}" | ||
|
||
def _validate_start_directory(self): | ||
"""Do not allow directory traversal.""" | ||
if ".." in self.start_directory or ( | ||
self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp") | ||
): | ||
raise ValueError("Directory traversal is not allowed.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters