Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wolfram Alpha transform support in autolabel #684

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/autolabel/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .serp_api import SerpApi
from .serper_api import SerperApi
from .webpage_transform import WebpageTransform
from .wolfram_alpha import WolframAlpha
from .image import ImageTransform
from typing import Dict
from autolabel.transforms.schema import TransformType
Expand All @@ -18,6 +19,7 @@
TransformType.IMAGE: ImageTransform,
TransformType.WEB_SEARCH_SERP_API: SerpApi,
TransformType.WEB_SEARCH_SERPER: SerperApi,
TransformType.WOLFRAM_ALPHA_API: WolframAlpha,
}


Expand Down
1 change: 1 addition & 0 deletions src/autolabel/transforms/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TransformType(str, Enum):
IMAGE = "image"
WEB_SEARCH_SERP_API = "web_search_serp_api"
WEB_SEARCH_SERPER = "web_search"
WOLFRAM_ALPHA_API = "wolfram_alpha"


class TransformCacheEntry(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion src/autolabel/transforms/serp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
f"Missing query column: {col} in row {row}",
)
query = self.query_template.format(**row)
search_result = self.NULL_TRANSFORM_TOKEN
search_result = {}
if pd.isna(query) or query == self.NULL_TRANSFORM_TOKEN:
raise TransformError(
TransformErrorType.INVALID_INPUT,
Expand Down
2 changes: 1 addition & 1 deletion src/autolabel/transforms/serper_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
f"Missing query column: {col} in row {row}",
)
query = self.query_template.format(**row)
search_result = self.NULL_TRANSFORM_TOKEN
search_result = {}
if pd.isna(query) or query == self.NULL_TRANSFORM_TOKEN:
raise TransformError(
TransformErrorType.INVALID_INPUT,
Expand Down
148 changes: 148 additions & 0 deletions src/autolabel/transforms/wolfram_alpha.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from autolabel.transforms.schema import (
TransformType,
TransformError,
TransformErrorType,
)
from autolabel.transforms import BaseTransform
from typing import Dict, Any, List
import asyncio
import logging
import pandas as pd
import ssl

from autolabel.cache import BaseCache

logger = logging.getLogger(__name__)

MAX_RETRIES = 3
MAX_KEEPALIVE_CONNECTIONS = 20
MAX_CONNECTIONS = 100
BACKOFF = 2
HEADERS = {}
API_BASE_URL = "https://www.wolframalpha.com/api/v1/llm-api"


class WolframAlpha(BaseTransform):
COLUMN_NAMES = [
"result",
]

def __init__(
self,
cache: BaseCache,
output_columns: Dict[str, Any],
query_columns: List[str],
query_template: str,
wolfram_app_id: str,
wolfram_args: Dict[str, Any] = {},
timeout: int = 5,
) -> None:
super().__init__(cache, output_columns)
self.max_retries = MAX_RETRIES
self.query_columns = query_columns
self.query_template = query_template
self.wolfram_app_id = wolfram_app_id
self.wolfram_args = wolfram_args
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.wolfram_app_id}",
}
try:
import httpx

self.httpx = httpx
self.timeout_time = timeout
self.timeout = httpx.Timeout(timeout)
limits = httpx.Limits(
max_keepalive_connections=MAX_KEEPALIVE_CONNECTIONS,
max_connections=MAX_CONNECTIONS,
keepalive_expiry=timeout,
)
self.client = httpx.AsyncClient(
timeout=self.timeout, limits=limits, follow_redirects=True
)
self.client_with_no_verify = httpx.AsyncClient(
timeout=self.timeout, limits=limits, follow_redirects=True, verify=False
)
except ImportError:
raise ImportError(
"httpx is required to use the wolfram alpha transform. Please install them with the following command: pip install httpx"
)

def name(self) -> str:
return TransformType.WOLFRAM_ALPHA_API

async def _get_result(
self, query: str, verify=True, headers=HEADERS, retry_count=0
) -> Dict[str, Any]:
if retry_count >= self.max_retries:
logger.warning(f"Max retries reached for query: {query}")
raise TransformError(
TransformErrorType.MAX_RETRIES_REACHED, "Max retries reached"
)

try:
client = self.client
if not verify:
client = self.client_with_no_verify
params = self.wolfram_args
params["input"] = query
response = await client.get(API_BASE_URL, headers=headers, params=params)
if response.status_code != 200:
logger.debug(
f"Error fetching content. Status code: {response.status_code}"
)
raise TransformError(
TransformErrorType.TRANSFORM_ERROR,
f"Error fetching content. Status code: {response.status_code}",
)
return {
"result": response.text,
}
except self.httpx.ConnectTimeout as e:
logger.error(f"Timeout when fetching content")
raise TransformError(
TransformErrorType.TRANSFORM_TIMEOUT,
"Timeout when fetching content",
)
except ssl.SSLCertVerificationError as e:
logger.warning(
f"SSL verification error when fetching content, retrying with verify=False"
)
await asyncio.sleep(BACKOFF**retry_count)
return await self._get_result(
query, verify=False, headers=headers, retry_count=retry_count + 1
)
except Exception as e:
logger.error(f"Error fetching content. Exception: {e}")
raise e

async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
for col in self.query_columns:
if col not in row:
raise TransformError(
TransformErrorType.INVALID_INPUT,
f"Missing query column: {col} in row {row}",
)
query = self.query_template.format(**row)
result = {}
if pd.isna(query) or query == self.NULL_TRANSFORM_TOKEN:
raise TransformError(
TransformErrorType.INVALID_INPUT,
f"Empty query in row {row}",
)
else:
result = await self._get_result(query, headers=self.headers)

transformed_row = {self.output_columns["result"]: result}

return self._return_output_row(transformed_row)

def params(self):
return {
"query_columns": self.query_columns,
"query_template": self.query_template,
"output_columns": self.output_columns,
"wolfram_app_id": self.wolfram_app_id,
"wolfram_args": self.wolfram_args,
}
Loading