Showing 264 changed files with 30,061 additions and 0 deletions.
@@ -0,0 +1,16 @@
FROM python:3.10-slim

ARG ROOT_PATH=/
ENV ROOT_PATH=${ROOT_PATH}

WORKDIR /app

COPY . .

RUN pip install --upgrade pip \
    && pip install --no-cache-dir openai python-dotenv uvicorn "litellm>=1.0,<2.0" google-generativeai agenta

EXPOSE 80

CMD ["./entrypoint.sh"]
@@ -0,0 +1,16 @@
FROM python:3.10-slim

ARG ROOT_PATH=/
ENV ROOT_PATH=${ROOT_PATH}

WORKDIR /app

COPY . .

RUN pip install --upgrade pip \
    && pip install --no-cache-dir openai python-dotenv uvicorn "litellm>=1.0,<2.0" google-generativeai \
    && pip install --no-cache-dir --pre agenta

EXPOSE 80

CMD ["./entrypoint.sh"]
@@ -0,0 +1,96 @@
import agenta as ag
import litellm
from supported_llm_models import get_all_supported_llm_models

litellm.drop_params = True


prompts = {
    "system_prompt": "You are an expert in geography.",
    "user_prompt": """What is the capital of {country}?""",
}

GPT_FORMAT_RESPONSE = ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"]


ag.init()
ag.config.default(
    temperature=ag.FloatParam(default=1, minval=0.0, maxval=2.0),
    model=ag.GroupedMultipleChoiceParam(
        default="gpt-3.5-turbo", choices=get_all_supported_llm_models()
    ),
    max_tokens=ag.IntParam(-1, -1, 4000),
    prompt_system=ag.TextParam(prompts["system_prompt"]),
    prompt_user=ag.TextParam(prompts["user_prompt"]),
    top_p=ag.FloatParam(1),
    frequence_penalty=ag.FloatParam(default=0.0, minval=-2.0, maxval=2.0),
    presence_penalty=ag.FloatParam(default=0.0, minval=-2.0, maxval=2.0),
    force_json=ag.BinaryParam(False),
)


@ag.instrument(spankind="llm")
async def llm_call(prompt_system: str, prompt_user: str):
    response_format = (
        {"type": "json_object"}
        if ag.config.force_json and ag.config.model in GPT_FORMAT_RESPONSE
        else {"type": "text"}
    )
    max_tokens = ag.config.max_tokens if ag.config.max_tokens != -1 else None

    # Include frequency_penalty and presence_penalty only if supported
    completion_params = {}
    if ag.config.model in GPT_FORMAT_RESPONSE:
        completion_params["frequency_penalty"] = ag.config.frequence_penalty
        completion_params["presence_penalty"] = ag.config.presence_penalty

    response = await litellm.acompletion(
        **{
            "model": ag.config.model,
            "messages": [
                {"content": prompt_system, "role": "system"},
                {"content": prompt_user, "role": "user"},
            ],
            "temperature": ag.config.temperature,
            "max_tokens": max_tokens,
            "top_p": ag.config.top_p,
            "response_format": response_format,
            **completion_params,
        }
    )
    token_usage = response.usage.dict()
    return {
        "message": response.choices[0].message.content,
        "usage": token_usage,
        "cost": litellm.cost_calculator.completion_cost(
            completion_response=response, model=ag.config.model
        ),
    }


@ag.entrypoint
@ag.instrument()
async def generate(
    inputs: ag.DictInput = ag.DictInput(default_keys=["country"]),
):
    # Fall back to the raw templates if the inputs do not match the placeholders
    try:
        prompt_user = ag.config.prompt_user.format(**inputs)
    except Exception:
        prompt_user = ag.config.prompt_user
    try:
        prompt_system = ag.config.prompt_system.format(**inputs)
    except Exception:
        prompt_system = ag.config.prompt_system

    # Reject JSON mode for models that do not support it
    if ag.config.force_json and ag.config.model not in GPT_FORMAT_RESPONSE:
        raise ValueError(
            "Model {} does not support JSON response format".format(ag.config.model)
        )

    response = await llm_call(prompt_system=prompt_system, prompt_user=prompt_user)
    return {
        "message": response["message"],
        "usage": response.get("usage", None),
        "cost": response.get("cost", None),
    }
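With @ag.entrypoint, the SDK serves this function over HTTP inside the container built from the Dockerfiles above (port 80 is exposed). The following is a hedged smoke-test sketch; the /generate route name and the {"inputs": {...}} payload shape are assumptions inferred from the function signature, not shown in this commit.

# Hypothetical smoke test against a locally running container. The route and
# payload shape below are assumptions, not part of this diff.
import requests

resp = requests.post(
    "http://localhost/generate",
    json={"inputs": {"country": "France"}},
    timeout=30,
)
resp.raise_for_status()
body = resp.json()
print(body["message"])  # e.g. "The capital of France is Paris."
print(body["usage"], body["cost"])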
@@ -0,0 +1,74 @@
from typing import Any, Callable, Optional

from .sdk.utils.preinit import PreInitObject

import agenta.client.backend.types as client_types  # pylint: disable=wrong-import-order

from .sdk.types import (
    DictInput,
    MultipleChoice,
    FloatParam,
    InFile,
    IntParam,
    MultipleChoiceParam,
    GroupedMultipleChoiceParam,
    MessagesInput,
    TextParam,
    FileInputURL,
    BinaryParam,
    Prompt,
)

from .sdk.utils.logging import log as logging
from .sdk.tracing import Tracing, get_tracer
from .sdk.decorators.tracing import instrument
from .sdk.tracing.conventions import Reference
from .sdk.decorators.routing import entrypoint, app, route
from .sdk.agenta_init import Config, AgentaSingleton, init as _init
from .sdk.utils.costs import calculate_token_usage
from .sdk.client import Agenta
from .sdk.litellm import litellm as callbacks
from .sdk.managers.vault import VaultManager
from .sdk.managers.config import ConfigManager
from .sdk.managers.variant import VariantManager
from .sdk.managers.deployment import DeploymentManager
from .sdk import assets as assets
from .sdk import tracer

config = PreInitObject("agenta.config", Config)
DEFAULT_AGENTA_SINGLETON_INSTANCE = AgentaSingleton()

types = client_types

api = None
async_api = None

tracing = DEFAULT_AGENTA_SINGLETON_INSTANCE.tracing  # type: ignore
tracer = get_tracer(tracing)


def init(
    host: Optional[str] = None,
    api_key: Optional[str] = None,
    config_fname: Optional[str] = None,
    redact: Optional[Callable[..., Any]] = None,
    redact_on_error: Optional[bool] = True,
    # DEPRECATING
    app_id: Optional[str] = None,
):
    global api, async_api, tracing, tracer  # pylint: disable=global-statement

    _init(
        host=host,
        api_key=api_key,
        config_fname=config_fname,
        redact=redact,
        redact_on_error=redact_on_error,
        app_id=app_id,
    )

    api = DEFAULT_AGENTA_SINGLETON_INSTANCE.api  # type: ignore
    async_api = DEFAULT_AGENTA_SINGLETON_INSTANCE.async_api  # type: ignore

    tracing = DEFAULT_AGENTA_SINGLETON_INSTANCE.tracing  # type: ignore
    tracer = get_tracer(tracing)
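This package root mostly re-exports the SDK surface (init, instrument, entrypoint, the parameter types, and the managers). For orientation, a minimal usage sketch of that surface follows; the host URL and API key are placeholders, not part of this commit.

# Minimal sketch using only names re-exported above: init() and instrument().
# Host and API key are placeholder values.
import agenta as ag

ag.init(host="https://cloud.agenta.ai", api_key="<AGENTA_API_KEY>")


@ag.instrument()
def answer(question: str) -> str:
    # Application logic; the decorator records a trace span via the SDK tracer.
    return f"You asked: {question}"


print(answer("What is the capital of France?"))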
services/completion-serverless-sdk/agenta/cli/evaluation_commands.py (22 additions, 0 deletions)
@@ -0,0 +1,22 @@
import click
from agenta.client import client


@click.group()
def evaluation():
    """Commands for evaluations."""
    pass


# TODO: Remove hardcoded values
@evaluation.command(name="run")
def run_evaluation_cli():
    """Run an evaluation."""

    try:
        client.run_evaluation(
            app_name="sss",
            host="http://localhost",
        )
    except Exception as ex:
        click.echo(click.style(f"Error while running evaluation: {ex}", fg="red"))
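This module only defines the evaluation command group; the commit does not show the parent CLI it attaches to. A sketch of how the group could be wired into a top-level click CLI (the parent group name here is an assumption):

# Hypothetical wiring of the evaluation group into a parent click CLI.
import click

from agenta.cli.evaluation_commands import evaluation


@click.group()
def cli():
    """Top-level CLI (name is illustrative, not from this commit)."""


cli.add_command(evaluation)  # enables: `... evaluation run`

if __name__ == "__main__":
    cli()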