Skip to content

Commit

Permalink
add tests coverage, fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
ollmer committed Mar 4, 2025
1 parent d76c1cf commit a7f2605
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 5 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ lint-check:
test:
@uv run --all-extras pytest -s --color=yes -m "not slow" tests/

coverage:
@uv run --all-extras pytest --cov=tapeagents -s --color=yes -m "not slow" tests/

test-core:
@uv run pytest -s --color=yes tests/ --ignore-glob="tests/*/*"

Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ classifiers = [
dependencies = [
"anthropic>=0.49.0",
"browsergym~=0.13",
"coverage>=7.6.12",
"fastapi~=0.115",
"gradio~=5.11",
"hydra-core~=1.3",
Expand All @@ -36,6 +37,8 @@ dependencies = [
"podman~=5.0",
"pyautogui>=0.9.54",
"pydantic~=2.9",
"pytest-cov>=6.0.0",
"pytest-xdist>=3.6.1",
"pyyaml~=6.0",
"streamlit>=1.42.0",
"tavily-python~=0.3",
Expand Down Expand Up @@ -159,3 +162,6 @@ deps = [
"types-chardet>=5.0.4.6",
]
commands = [["mypy", "tapeagents"]]

[tool.coverage.run]
dynamic_context = "test_function"
3 changes: 2 additions & 1 deletion tapeagents/llms/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tapeagents.llms.base import LLMEvent, LLMOutput
from tapeagents.llms.cached import CachedLLM
from tapeagents.llms.litellm import logger
from tapeagents.utils import get_step_schemas_from_union_type
from tapeagents.utils import get_step_schemas_from_union_type, resize_base64_message


class Claude(CachedLLM):
Expand Down Expand Up @@ -90,6 +90,7 @@ def update_image_messages_format(self, messages: list) -> list:
texts = []
for submessage in message["content"]:
if submessage["type"] == "image_url":
submessage = resize_base64_message(submessage)
url = submessage["image_url"]["url"]
content_type, base64_image = url.split(";base64,")
img_message = {
Expand Down
15 changes: 14 additions & 1 deletion tapeagents/llms/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,20 @@ def _implementation():
logger.warning(f"STEP{i}: {diff_strings(aa, bb)}\n")
raise FatalError("prompt not found")
else:
logger.warning(f"prompt of size {len(prompt_key)} not found, skipping..")
messages_previews = []
for m in prompt.messages:
try:
if isinstance(m["content"], list):
msg = "[text,img]"
else:
m_dict = json.loads(m["content"])
msg = list(m_dict.keys())
messages_previews.append({"role": m["role"], "content": msg})
except Exception:
messages_previews.append({"role": m["role"], "content": "text"})
logger.warning(
f"prompt with {len(prompt.messages)} messages {messages_previews}, {len(prompt_key)} chars not found, skipping.."
)
raise FatalError("prompt not found")
yield LLMEvent(output=LLMOutput(content=output))

Expand Down
2 changes: 1 addition & 1 deletion tapeagents/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def replay_tapes(
raise FatalError("Tape mismatch")
ok += 1
except FatalError as e:
logger.error(colored(f"Fatal error: {e}, skip tape {tape.metadata.id}", "red"))
logger.error(colored(f"Fatal error: {e}, skip tape {i}/{len(tapes)} ({tape.metadata.id})", "red"))
fails += 1
if stop_on_error:
raise e
Expand Down
13 changes: 11 additions & 2 deletions tapeagents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import difflib
import fcntl
import importlib
import io
import json
import os
import tempfile
Expand Down Expand Up @@ -110,8 +111,16 @@ def get_step_schemas_from_union_type(cls, simplify: bool = True) -> str:


def image_base64_message(image_path: str) -> dict:
max_size = 1280
image = Image.open(image_path)
image_extension = os.path.splitext(image_path)[1][1:]
content_type = f"image/{image_extension}"
base64_image = encode_image(image_path)
message = {"type": "image_url", "image_url": {"url": f"data:{content_type};base64,{base64_image}"}}
return message


def resize_base64_message(message: dict, max_size: int = 1280) -> dict:
base64_image = message["image_url"]["url"].split(",", maxsplit=1)[1]
image = Image.open(io.BytesIO(base64.b64decode(base64_image)))
if image.size[0] > max_size or image.size[1] > max_size:
image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
with tempfile.NamedTemporaryFile() as tmp:
Expand Down
86 changes: 86 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a7f2605

Please sign in to comment.