Skip to content

Commit

Permalink
feat: introduce first few tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl committed Apr 17, 2024
1 parent 66a7528 commit 82c8fcf
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 1 deletion.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
# evals
Using various instructor clients evaluating the quality and capabilities of extractions and reasoning.

Using various instructor clients evaluating the quality and capabilities of extractions and reasoning.

We'll run these tests and see what ends up failing often.

```
pip install -r requirements.txt
pytest
```

When contributing, just make sure everything is async and we'll handle the rest!

150 changes: 150 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
anthropic==0.24.0
anyio==4.3.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.2.0
Babel==2.14.0
beautifulsoup4==4.12.3
bleach==6.1.0
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.2
contourpy==1.2.0
cycler==0.12.1
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
diskcache==5.6.3
distro==1.9.0
docstring-parser==0.15
executing==2.0.1
fastjsonschema==2.19.1
filelock==3.13.3
fonttools==4.50.0
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
h11==0.14.0
httpcore==1.0.4
httpx==0.27.0
huggingface-hub==0.22.2
idna==3.6
importlib_metadata==7.1.0
instructor==1.0.2
ipykernel==6.29.3
ipython==8.22.2
ipywidgets==8.1.2
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.3
joblib==1.3.2
json5==0.9.24
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.4
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.13.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.4
jupyterlab_widgets==3.0.10
kiwisolver==1.4.5
langdetect==1.0.9
litellm==1.34.29
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.3
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.2
mlxtend==0.23.1
mpmath==1.3.0
multidict==6.0.5
nbclient==0.10.0
nbconvert==7.16.3
nbformat==5.10.3
nest-asyncio==1.6.0
networkx==3.2.1
notebook==7.1.2
notebook_shim==0.2.4
numpy==1.26.4
openai==1.14.3
overrides==7.7.0
packaging==24.0
pandas==2.2.1
pandocfilters==1.5.1
parso==0.8.3
pexpect==4.9.0
pillow==10.2.0
platformdirs==4.2.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
pydantic==2.7.0b1
pydantic_core==2.18.0
Pygments==2.17.2
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytz==2024.1
PyYAML==6.0.1
pyzmq==25.1.2
qtconsole==5.5.1
QtPy==2.4.1
referencing==0.34.0
regex==2023.12.25
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.0
scikit-learn==1.4.1.post1
scipy==1.12.0
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
tenacity==8.2.3
terminado==0.18.1
threadpoolctl==3.4.0
tiktoken==0.6.0
tinycss2==1.2.1
tokenizers==0.15.2
torch==2.2.2
tornado==6.4
tqdm==4.66.2
traitlets==5.14.2
typer==0.9.4
types-python-dateutil==2.9.0.20240316
typing_extensions==4.10.0
tzdata==2024.1
uri-template==1.3.0
urllib3==2.2.1
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
widgetsnbextension==4.0.10
yarl==1.9.4
zipp==3.18.1
36 changes: 36 additions & 0 deletions tests/test_classification_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from itertools import product
from typing import Literal
from util import clients
from pydantic import BaseModel

import pytest


# Structured-output target for the classification test: the client must
# respond with exactly one of the two allowed labels.
class ClassifySpam(BaseModel):
    label: Literal["spam", "not_spam"]


# (input text, expected label) pairs — one obviously spam, one obviously not,
# so any reasonable model should classify both correctly.
data = [
    ("I am a spammer who sends many emails every day", "spam"),
    ("I am a responsible person who does not spam", "not_spam"),
]


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("client, data", product(clients, data))
async def test_classification(client, data):
    """Check that each configured client labels obvious spam/ham correctly.

    Parametrized over the cartesian product of all clients in ``util.clients``
    and the (text, expected_label) pairs in ``data``.
    """
    # Unpack into ``text`` rather than ``input`` — the original name shadowed
    # the ``input`` builtin.
    text, expected = data
    prediction = await client.create(
        response_model=ClassifySpam,
        messages=[
            {
                "role": "system",
                "content": "Classify this text as 'spam' or 'not_spam'.",
            },
            {
                "role": "user",
                "content": text,
            },
        ],
    )
    assert prediction.label == expected
98 changes: 98 additions & 0 deletions tests/test_vision_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from openai import OpenAI
from io import StringIO
from typing import Annotated, Any, List
from pydantic import (
BaseModel,
BeforeValidator,
PlainSerializer,
InstanceOf,
WithJsonSchema,
)
import pandas as pd
import pytest
from itertools import product
from util import clients
from instructor import AsyncInstructor


def md_to_df(data: Any) -> Any:
if isinstance(data, str):
return (
pd.read_csv(
StringIO(data), # Get rid of whitespaces
sep="|",
index_col=1,
)
.dropna(axis=1, how="all")
.iloc[1:]
.map(lambda x: x.strip())
) # type: ignore
return data


MarkdownDataFrame = Annotated[
InstanceOf[pd.DataFrame],
BeforeValidator(md_to_df),
PlainSerializer(lambda x: x.to_markdown()),
WithJsonSchema(
{
"type": "string",
"description": """
The markdown representation of the table,
each one should be tidy, do not try to join tables
that should be seperate""",
}
),
]


# A single extracted table: a human-readable caption plus the table body,
# parsed from the model's markdown output into a DataFrame.
class Table(BaseModel):
    caption: str
    dataframe: MarkdownDataFrame


# An image may contain several distinct tables; the model returns each one.
class MultipleTables(BaseModel):
    tables: List[Table]


# Test fixtures: publicly hosted chart/table images to run extraction against.
urls = [
    "https://a.storyblok.com/f/47007/2400x1260/f816b031cb/uk-ireland-in-three-charts_chart_a.png/m/2880x0",
    "https://a.storyblok.com/f/47007/2400x2000/bf383abc3c/231031_uk-ireland-in-three-charts_table_v01_b.png/m/2880x0",
]


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("client, url", product(clients, urls))
async def test_extract(client: AsyncInstructor, url: str):
    """Extract all tables visible in the image at ``url`` into Table models.

    Parametrized over every client and every fixture image; non-GPT-4-Turbo
    clients are skipped because image inputs are only wired up for OpenAI here.
    """
    # NOTE(review): assumes every client was built with a ``model`` kwarg —
    # holds for util.clients; a KeyError here means a new client was added
    # without one.
    if client.kwargs["model"] != "gpt-4-turbo":
        pytest.skip("Only OpenAI supported for now, we need to support images for both")

    resp = await client.create(
        response_model=MultipleTables,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": """
                        First, analyze the image to determine the most appropriate headers for the tables.
                        Generate a descriptive h1 for the overall image, followed by a brief summary of the data it contains.
                        For each identified table, create an informative h2 title and a concise description of its contents.
                        Finally, output the markdown representation of each table.
                        Make sure to escape the markdown table properly, and make sure to include the caption and the dataframe.
                        including escaping all the newlines and quotes. Only return a markdown table in dataframe, nothing else.
                        """,
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": url},
                    },
                ],
            }
        ],
    )
    # Each fixture image contains at least one table, so extraction must
    # yield a non-empty result.
    assert len(resp.tables) > 0
39 changes: 39 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
import instructor
from enum import Enum


class Models(str, Enum):
    """Model identifiers for the providers under test.

    str-valued so members compare equal to (and can be passed directly as)
    the provider's plain model-name string.
    """

    GPT35TURBO = "gpt-3.5-turbo"
    GPT4TURBO = "gpt-4-turbo"
    CLAUDE3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE3_OPUS = "claude-3-opus-20240229"
    CLAUDE3_HAIKU = "claude-3-haiku-20240307"


# One instructor-patched async client per model under test; evaluation tests
# are parametrized over every entry in this tuple.  Order: OpenAI models
# first, then the Anthropic Claude 3 family.
_OPENAI_MODELS = (Models.GPT35TURBO, Models.GPT4TURBO)
_ANTHROPIC_MODELS = (Models.CLAUDE3_OPUS, Models.CLAUDE3_SONNET, Models.CLAUDE3_HAIKU)

clients = tuple(
    instructor.from_openai(AsyncOpenAI(), model=model) for model in _OPENAI_MODELS
) + tuple(
    instructor.from_anthropic(AsyncAnthropic(), model=model, max_tokens=4000)
    for model in _ANTHROPIC_MODELS
)

0 comments on commit 82c8fcf

Please sign in to comment.