-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
335 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,13 @@ | ||
# evals | ||
Using various instructor clients evaluating the quality and capabilities of extractions and reasoning. | ||
|
||
Using various instructor clients to evaluate the quality and capabilities of extractions and reasoning. | ||
|
||
We'll run these tests and see what ends up failing often. | ||
|
||
``` | ||
pip install -r requirements.txt | ||
pytest | ||
``` | ||
|
||
When contributing, just make sure everything is async and we'll handle the rest! | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
aiohttp==3.9.3 | ||
aiosignal==1.3.1 | ||
annotated-types==0.6.0 | ||
anthropic==0.24.0 | ||
anyio==4.3.0 | ||
appnope==0.1.4 | ||
argon2-cffi==23.1.0 | ||
argon2-cffi-bindings==21.2.0 | ||
arrow==1.3.0 | ||
asttokens==2.4.1 | ||
async-lru==2.0.4 | ||
attrs==23.2.0 | ||
Babel==2.14.0 | ||
beautifulsoup4==4.12.3 | ||
bleach==6.1.0 | ||
certifi==2024.2.2 | ||
cffi==1.16.0 | ||
charset-normalizer==3.3.2 | ||
click==8.1.7 | ||
comm==0.2.2 | ||
contourpy==1.2.0 | ||
cycler==0.12.1 | ||
debugpy==1.8.1 | ||
decorator==5.1.1 | ||
defusedxml==0.7.1 | ||
diskcache==5.6.3 | ||
distro==1.9.0 | ||
docstring-parser==0.15 | ||
executing==2.0.1 | ||
fastjsonschema==2.19.1 | ||
filelock==3.13.3 | ||
fonttools==4.50.0 | ||
fqdn==1.5.1 | ||
frozenlist==1.4.1 | ||
fsspec==2024.3.1 | ||
h11==0.14.0 | ||
httpcore==1.0.4 | ||
httpx==0.27.0 | ||
huggingface-hub==0.22.2 | ||
idna==3.6 | ||
importlib_metadata==7.1.0 | ||
instructor==1.0.2 | ||
ipykernel==6.29.3 | ||
ipython==8.22.2 | ||
ipywidgets==8.1.2 | ||
isoduration==20.11.0 | ||
jedi==0.19.1 | ||
Jinja2==3.1.3 | ||
joblib==1.3.2 | ||
json5==0.9.24 | ||
jsonpointer==2.4 | ||
jsonschema==4.21.1 | ||
jsonschema-specifications==2023.12.1 | ||
jupyter==1.0.0 | ||
jupyter-console==6.6.3 | ||
jupyter-events==0.10.0 | ||
jupyter-lsp==2.2.4 | ||
jupyter_client==8.6.1 | ||
jupyter_core==5.7.2 | ||
jupyter_server==2.13.0 | ||
jupyter_server_terminals==0.5.3 | ||
jupyterlab==4.1.5 | ||
jupyterlab_pygments==0.3.0 | ||
jupyterlab_server==2.25.4 | ||
jupyterlab_widgets==3.0.10 | ||
kiwisolver==1.4.5 | ||
langdetect==1.0.9 | ||
litellm==1.34.29 | ||
markdown-it-py==3.0.0 | ||
MarkupSafe==2.1.5 | ||
matplotlib==3.8.3 | ||
matplotlib-inline==0.1.6 | ||
mdurl==0.1.2 | ||
mistune==3.0.2 | ||
mlxtend==0.23.1 | ||
mpmath==1.3.0 | ||
multidict==6.0.5 | ||
nbclient==0.10.0 | ||
nbconvert==7.16.3 | ||
nbformat==5.10.3 | ||
nest-asyncio==1.6.0 | ||
networkx==3.2.1 | ||
notebook==7.1.2 | ||
notebook_shim==0.2.4 | ||
numpy==1.26.4 | ||
openai==1.14.3 | ||
overrides==7.7.0 | ||
packaging==24.0 | ||
pandas==2.2.1 | ||
pandocfilters==1.5.1 | ||
parso==0.8.3 | ||
pexpect==4.9.0 | ||
pillow==10.2.0 | ||
platformdirs==4.2.0 | ||
prometheus_client==0.20.0 | ||
prompt-toolkit==3.0.43 | ||
psutil==5.9.8 | ||
ptyprocess==0.7.0 | ||
pure-eval==0.2.2 | ||
pycparser==2.21 | ||
pydantic==2.7.0b1 | ||
pydantic_core==2.18.0 | ||
Pygments==2.17.2 | ||
pyparsing==3.1.2 | ||
python-dateutil==2.9.0.post0 | ||
python-dotenv==1.0.1 | ||
python-json-logger==2.0.7 | ||
pytz==2024.1 | ||
PyYAML==6.0.1 | ||
pyzmq==25.1.2 | ||
qtconsole==5.5.1 | ||
QtPy==2.4.1 | ||
referencing==0.34.0 | ||
regex==2023.12.25 | ||
requests==2.31.0 | ||
rfc3339-validator==0.1.4 | ||
rfc3986-validator==0.1.1 | ||
rich==13.7.1 | ||
rpds-py==0.18.0 | ||
scikit-learn==1.4.1.post1 | ||
scipy==1.12.0 | ||
Send2Trash==1.8.2 | ||
six==1.16.0 | ||
sniffio==1.3.1 | ||
soupsieve==2.5 | ||
stack-data==0.6.3 | ||
sympy==1.12 | ||
tenacity==8.2.3 | ||
terminado==0.18.1 | ||
threadpoolctl==3.4.0 | ||
tiktoken==0.6.0 | ||
tinycss2==1.2.1 | ||
tokenizers==0.15.2 | ||
torch==2.2.2 | ||
tornado==6.4 | ||
tqdm==4.66.2 | ||
traitlets==5.14.2 | ||
typer==0.9.4 | ||
types-python-dateutil==2.9.0.20240316 | ||
typing_extensions==4.10.0 | ||
tzdata==2024.1 | ||
uri-template==1.3.0 | ||
urllib3==2.2.1 | ||
wcwidth==0.2.13 | ||
webcolors==1.13 | ||
webencodings==0.5.1 | ||
websocket-client==1.7.0 | ||
widgetsnbextension==4.0.10 | ||
yarl==1.9.4 | ||
zipp==3.18.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from itertools import product | ||
from typing import Literal | ||
from util import clients | ||
from pydantic import BaseModel | ||
|
||
import pytest | ||
|
||
|
||
class ClassifySpam(BaseModel):
    """Response model for a binary spam classification.

    The LLM is forced to choose exactly one of the two labels.
    """

    # Constrained to the two valid labels; anything else fails validation.
    label: Literal["spam", "not_spam"]
|
||
|
||
# (text, expected_label) fixture pairs for the classification test below.
data = [
    ("I am a spammer who sends many emails every day", "spam"),
    ("I am a responsible person who does not spam", "not_spam"),
]
|
||
|
||
@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("client, data", product(clients, data))
async def test_classification(client, data):
    """Every client must assign the expected label to each fixture text.

    Runs once per (client, example) pair; ``data`` is a (text, label) tuple.
    """
    # Renamed from `input`, which shadowed the builtin of the same name.
    text, expected = data
    prediction = await client.create(
        response_model=ClassifySpam,
        messages=[
            {
                "role": "system",
                "content": "Classify this text as 'spam' or 'not_spam'.",
            },
            {
                "role": "user",
                "content": text,
            },
        ],
    )
    assert prediction.label == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from openai import OpenAI | ||
from io import StringIO | ||
from typing import Annotated, Any, List | ||
from pydantic import ( | ||
BaseModel, | ||
BeforeValidator, | ||
PlainSerializer, | ||
InstanceOf, | ||
WithJsonSchema, | ||
) | ||
import pandas as pd | ||
import pytest | ||
from itertools import product | ||
from util import clients | ||
from instructor import AsyncInstructor | ||
|
||
|
||
def md_to_df(data: Any) -> Any: | ||
if isinstance(data, str): | ||
return ( | ||
pd.read_csv( | ||
StringIO(data), # Get rid of whitespaces | ||
sep="|", | ||
index_col=1, | ||
) | ||
.dropna(axis=1, how="all") | ||
.iloc[1:] | ||
.map(lambda x: x.strip()) | ||
) # type: ignore | ||
return data | ||
|
||
|
||
# Annotated DataFrame type: validates markdown-table strings into DataFrames,
# serializes back to markdown, and advertises itself as a string in the JSON
# schema shown to the model.  Fix: "seperate" -> "separate" in the schema
# description the LLM reads.
MarkdownDataFrame = Annotated[
    InstanceOf[pd.DataFrame],
    BeforeValidator(md_to_df),  # accept raw markdown text and parse it
    PlainSerializer(lambda x: x.to_markdown()),  # emit markdown on dump
    WithJsonSchema(
        {
            "type": "string",
            "description": """
                The markdown representation of the table,
                each one should be tidy, do not try to join tables
                that should be separate""",
        }
    ),
]
|
||
|
||
class Table(BaseModel):
    """One extracted table: a human-readable caption plus its contents."""

    # Short title describing the table.
    caption: str
    # Validated from / serialized to a markdown table string.
    dataframe: MarkdownDataFrame
|
||
|
||
class MultipleTables(BaseModel):
    """Container for every table extracted from a single image."""

    tables: List[Table]
|
||
|
||
# Chart/table screenshots used as vision-extraction inputs below.
urls = [
    "https://a.storyblok.com/f/47007/2400x1260/f816b031cb/uk-ireland-in-three-charts_chart_a.png/m/2880x0",
    "https://a.storyblok.com/f/47007/2400x2000/bf383abc3c/231031_uk-ireland-in-three-charts_table_v01_b.png/m/2880x0",
]
|
||
|
||
@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("client, url", product(clients, urls))
async def test_extract(client: AsyncInstructor, url: str):
    """Extract tables from a chart image and check at least one came back."""

    # Image (vision) input is currently only wired up for OpenAI's gpt-4-turbo.
    if client.kwargs["model"] != "gpt-4-turbo":
        pytest.skip("Only OpenAI supported for now, we need to support images for both")

    text_part = {
        "type": "text",
        "text": """
                        First, analyze the image to determine the most appropriate headers for the tables.
                        Generate a descriptive h1 for the overall image, followed by a brief summary of the data it contains.
                        For each identified table, create an informative h2 title and a concise description of its contents.
                        Finally, output the markdown representation of each table.
                        Make sure to escape the markdown table properly, and make sure to include the caption and the dataframe.
                        including escaping all the newlines and quotes. Only return a markdown table in dataframe, nothing else.
                        """,
    }
    image_part = {"type": "image_url", "image_url": {"url": url}}

    response = await client.create(
        response_model=MultipleTables,
        messages=[{"role": "user", "content": [text_part, image_part]}],
    )
    assert len(response.tables) > 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from openai import AsyncOpenAI | ||
from anthropic import AsyncAnthropic | ||
import instructor | ||
from enum import Enum | ||
|
||
|
||
class Models(str, Enum):
    """Model identifiers for every provider exercised by the eval suite.

    Mixes in ``str`` so members compare equal to their raw API model names.
    """

    # OpenAI
    GPT35TURBO = "gpt-3.5-turbo"
    GPT4TURBO = "gpt-4-turbo"
    # Anthropic — Claude 3 family, dated snapshots
    CLAUDE3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE3_OPUS = "claude-3-opus-20240229"
    CLAUDE3_HAIKU = "claude-3-haiku-20240307"
|
||
|
||
# One configured async client per model under evaluation: the two OpenAI
# models first, then the Claude 3 family (which requires an explicit
# max_tokens).  Order only affects parametrized test-id ordering.
_OPENAI_MODELS = (Models.GPT35TURBO, Models.GPT4TURBO)
_ANTHROPIC_MODELS = (Models.CLAUDE3_OPUS, Models.CLAUDE3_SONNET, Models.CLAUDE3_HAIKU)

clients = tuple(
    instructor.from_openai(AsyncOpenAI(), model=model) for model in _OPENAI_MODELS
) + tuple(
    instructor.from_anthropic(AsyncAnthropic(), model=model, max_tokens=4000)
    for model in _ANTHROPIC_MODELS
)