Skip to content

Commit

Permalink
chore: use field_validator; use pathlib
Browse files: browse the repository at this point in the history
  • Loading branch information
guiparpinelli committed Dec 13, 2023
1 parent 0aa926a commit 6a0a47c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 70 deletions.
49 changes: 22 additions & 27 deletions src/nalgonda/custom_tools/build_directory_tree.py
Original file line number | Diff line number | Diff line change
@@ -1,46 +1,41 @@
import os
from pathlib import Path

from agency_swarm import BaseTool
from pydantic import Field
from pydantic import Field, field_validator

from nalgonda.custom_tools.utils import check_directory_traversal


class BuildDirectoryTree(BaseTool):
"""Print the structure of directories and files."""

start_directory: str = Field(
default_factory=lambda: os.getcwd(),
start_directory: Path = Field(
default_factory=Path.cwd,
description="The starting directory for the tree, defaults to the current working directory.",
)
file_extensions: list[str] | None = Field(
default_factory=lambda: None,
description="List of file extensions to include in the tree. If None, all files will be included.",
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the tree. If empty, all files will be included.",
)

def run(self) -> str:
"""Run the tool."""
self._validate_start_directory()
tree_str = self.print_tree()
return tree_str
_validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal)

def print_tree(self):
"""Recursively print the tree of directories and files using os.walk."""
def run(self) -> str:
"""Recursively print the tree of directories and files using pathlib."""
tree_str = ""
start_path = self.start_directory.resolve()

for root, _, files in os.walk(self.start_directory, topdown=True):
level = root.replace(self.start_directory, "").count(os.sep)
def recurse(directory: Path, level: int = 0) -> None:
nonlocal tree_str
indent = " " * 4 * level
tree_str += f"{indent}{os.path.basename(root)}\n"
tree_str += f"{indent}{directory.name}\n"
sub_indent = " " * 4 * (level + 1)

for f in files:
if not self.file_extensions or f.endswith(tuple(self.file_extensions)):
tree_str += f"{sub_indent}{f}\n"
for path in sorted(directory.iterdir()):
if path.is_dir():
recurse(path, level + 1)
elif path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions):
tree_str += f"{sub_indent}{path.name}\n"

recurse(start_path)
return tree_str

def _validate_start_directory(self):
"""Do not allow directory traversal."""
if ".." in self.start_directory or (
self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp")
):
raise ValueError("Directory traversal is not allowed.")
48 changes: 23 additions & 25 deletions src/nalgonda/custom_tools/print_all_files_in_directory.py
Original file line number | Diff line number | Diff line change
@@ -1,44 +1,42 @@
import os
from pathlib import Path

from agency_swarm import BaseTool
from pydantic import Field
from pydantic import Field, field_validator

from nalgonda.custom_tools.utils import check_directory_traversal


class PrintAllFilesInDirectory(BaseTool):
"""Print the contents of all files in a start_directory recursively."""

start_directory: str = Field(
default_factory=lambda: os.getcwd(),
start_directory: Path = Field(
default_factory=Path.cwd,
description="Directory to search for Python files, by default the current working directory.",
)
file_extensions: list[str] | None = Field(
default_factory=lambda: None,
description="List of file extensions to include in the output. If None, all files will be included.",
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the output. If empty, all files will be included.",
)

def run(self) -> str:
"""Run the tool."""
self._validate_start_directory()
_validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal)

def run(self) -> str:
"""
Recursively searches for files within `start_directory` and compiles their contents into a single string.
"""
output = []
for root, _, files in os.walk(self.start_directory, topdown=True):
for file in files:
if not self.file_extensions or file.endswith(tuple(self.file_extensions)):
file_path = os.path.join(root, file)
output.append(f"{file_path}:\n```\n{self.read_file(file_path)}\n```\n")
start_path = self.start_directory.resolve()

for path in start_path.rglob("*"):
if path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions):
output.append(f"{str(path)}:\n```\n{self.read_file(path)}\n```\n")

return "\n".join(output)

@staticmethod
def read_file(file_path):
def read_file(file_path: Path):
"""Read and return the contents of a file."""
try:
with open(file_path, "r") as file:
return file.read()
return file_path.read_text()
except IOError as e:
return f"Error reading file {file_path}: {e}"

def _validate_start_directory(self):
"""Do not allow directory traversal."""
if ".." in self.start_directory or (
self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp")
):
raise ValueError("Directory traversal is not allowed.")
12 changes: 7 additions & 5 deletions tests/custom_tools/test_build_directory_tree.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
import os
from pathlib import Path

from nalgonda.custom_tools import BuildDirectoryTree

Expand All @@ -7,7 +7,7 @@ def test_build_directory_tree_with_py_extension(temp_dir):
"""
Test if BuildDirectoryTree correctly lists only .py files in the directory tree.
"""
bdt = BuildDirectoryTree(start_directory=str(temp_dir), file_extensions=[".py"])
bdt = BuildDirectoryTree(start_directory=temp_dir, file_extensions={".py"})
expected_output = f"{temp_dir.name}\n sub\n test.py\n"
assert bdt.run() == expected_output

Expand All @@ -16,7 +16,7 @@ def test_build_directory_tree_with_multiple_extensions(temp_dir):
"""
Test if BuildDirectoryTree lists files with multiple specified extensions.
"""
bdt = BuildDirectoryTree(start_directory=str(temp_dir), file_extensions=[".py", ".txt"])
bdt = BuildDirectoryTree(start_directory=temp_dir, file_extensions={".py", ".txt"})
expected_output = {
f"{temp_dir.name}",
" sub",
Expand All @@ -32,5 +32,7 @@ def test_build_directory_tree_default_settings():
Test if BuildDirectoryTree uses the correct default settings.
"""
bdt = BuildDirectoryTree()
assert bdt.start_directory == os.getcwd()
assert bdt.file_extensions is None
assert bdt.start_directory == Path.cwd()
assert bdt.file_extensions == set()


23 changes: 10 additions & 13 deletions tests/custom_tools/test_print_all_files_in_directory.py
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,11 @@
import os

from nalgonda.custom_tools import PrintAllFilesInDirectory


def test_print_all_files_no_extension_filter(temp_dir):
"""
Test if PrintAllFilesInDirectory correctly prints contents of all files when no file extension filter is applied.
"""
pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir))
pafid = PrintAllFilesInDirectory(start_directory=temp_dir)
expected_output = {
f"{temp_dir}/sub/test.py:\n```\nprint('hello')\n```",
f"{temp_dir}/sub/test.txt:\n```\nhello world\n```",
Expand All @@ -20,17 +18,17 @@ def test_print_all_files_with_py_extension(temp_dir):
"""
Test if PrintAllFilesInDirectory correctly prints contents of .py files only.
"""
pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir), file_extensions=[".py"])
expected_output = f"{os.path.join(temp_dir, 'sub', 'test.py')}:\n```\nprint('hello')\n```\n"
pafid = PrintAllFilesInDirectory(start_directory=temp_dir, file_extensions={".py"})
expected_output = f"{temp_dir.joinpath('sub', 'test.py')}:\n```\nprint('hello')\n```\n"
assert pafid.run() == expected_output


def test_print_all_files_with_txt_extension(temp_dir):
"""
Test if PrintAllFilesInDirectory correctly prints contents of .txt files only.
"""
pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir), file_extensions=[".txt"])
expected_output = f"{os.path.join(temp_dir, 'sub', 'test.txt')}:\n```\nhello world\n```\n"
pafid = PrintAllFilesInDirectory(start_directory=temp_dir, file_extensions={".txt"})
expected_output = f"{temp_dir.joinpath('sub', 'test.txt')}:\n```\nhello world\n```\n"
assert pafid.run() == expected_output


Expand All @@ -39,12 +37,11 @@ def test_print_all_files_error_reading_file(temp_dir):
Test if PrintAllFilesInDirectory handles errors while reading a file.
"""
# Create an unreadable file
unreadable_file = os.path.join(temp_dir, "unreadable_file.txt")
with open(unreadable_file, "w") as f:
f.write("content")
os.chmod(unreadable_file, 0o000) # make the file unreadable
unreadable_file = temp_dir.joinpath("unreadable_file.txt")
unreadable_file.write_text("content")
unreadable_file.chmod(0o000) # make the file unreadable

pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir), file_extensions=[".txt"])
pafid = PrintAllFilesInDirectory(start_directory=temp_dir, file_extensions={".txt"})
assert "Error reading file" in pafid.run()

os.chmod(unreadable_file, 0o644) # reset file permissions for cleanup
unreadable_file.chmod(0o644) # reset file permissions for cleanup

0 comments on commit 6a0a47c

Please sign in to comment.