Skip to content

Commit

Permalink
Merge pull request #16 from Neverbolt/parameter-descriptions
Browse files Browse the repository at this point in the history
Adds the possibility to define help text for parameters
  • Loading branch information
andreashappe authored Apr 19, 2024
2 parents 7436a5d + 544f3b0 commit 0820133
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 22 deletions.
4 changes: 2 additions & 2 deletions usecases/usecase/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import Dict, Type

from utils.configurable import get_parameters, ParameterDefinitions, build_parser, get_arguments
from utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters


class UseCase(abc.ABC):
Expand Down Expand Up @@ -66,7 +66,7 @@ def use_case(name: str, desc: str):
def inner(cls: Type[UseCase]):
if name in use_cases:
raise IndexError(f"Use case with name {name} already exists")
use_cases[name] = _WrappedUseCase(name, desc, cls, get_parameters(cls.__init__, name))
use_cases[name] = _WrappedUseCase(name, desc, cls, get_class_parameters(cls, name))

return cls

Expand Down
56 changes: 47 additions & 9 deletions utils/configurable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import dataclasses
import inspect
import os
from dataclasses import dataclass
Expand All @@ -12,6 +13,16 @@
load_dotenv()


def parameter(*, desc: str, default=dataclasses.MISSING, init: bool = True, repr: bool = True, hash=None,
compare: bool = True, metadata: Dict = None, kw_only: bool = dataclasses.MISSING) -> dataclasses.Field:
if metadata is None:
metadata = dict()
metadata["desc"] = desc

return dataclasses.field(default=default, default_factory=dataclasses.MISSING, init=init, repr=repr, hash=hash,
compare=compare, metadata=metadata, kw_only=kw_only)


def get_default(key, default):
return os.getenv(key, os.getenv(key.upper(), os.getenv(key.replace(".", "_"), os.getenv(key.replace(".", "_").upper(), default))))

Expand All @@ -24,12 +35,14 @@ class ParameterDefinition:
name: str
type: Type
default: Any
description: str

def parser(self, basename: str, parser: argparse.ArgumentParser):
name = f"{basename}{self.name}"
default = get_default(name, self.default)

parser.add_argument(f"--{name}", type=self.type, default=default, required=default is None)
parser.add_argument(f"--{name}", type=self.type, default=default, required=default is None,
help=self.description)

def get(self, basename: str, args: argparse.Namespace):
return getattr(args, f"{basename}{self.name}")
Expand Down Expand Up @@ -62,7 +75,18 @@ def get(self, basename: str, args: argparse.Namespace):
return parameter


def get_parameters(fun, basename: str) -> ParameterDefinitions:
def get_class_parameters(cls, name: str = None, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions:
if name is None:
name = cls.__name__
if fields is None and hasattr(cls, "__dataclass_fields__"):
fields = cls.__dataclass_fields__
return get_parameters(cls.__init__, name, fields)


def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions:
if fields is None:
fields = dict()

sig = inspect.signature(fun)
params: ParameterDefinitions = {}
for name, param in sig.parameters.items():
Expand All @@ -73,13 +97,27 @@ def get_parameters(fun, basename: str) -> ParameterDefinitions:
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have a type annotation")

default = param.default if param.default != inspect.Parameter.empty else None

if hasattr(param.annotation, "__parameters__"):
params[name] = ComplexParameterDefinition(name, param.annotation, default, get_parameters(param.annotation, f"{basename}.{fun.__name__}"))
elif param.annotation in (str, int, bool):
params[name] = ParameterDefinition(name, param.annotation, default)
description = None
type = param.annotation

field = None
if isinstance(default, dataclasses.Field):
field = default
default = field.default
elif name in fields:
field = fields[name]

if field is not None:
description = field.metadata.get("desc", None)
if field.type is not None:
type = field.type

if hasattr(type, "__parameters__"):
params[name] = ComplexParameterDefinition(name, type, default, description, get_class_parameters(type, f"{basename}.{fun.__name__}"))
elif type in (str, int, bool):
params[name] = ParameterDefinition(name, type, default, description)
else:
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {param.annotation}")
raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {type}")

return params

Expand All @@ -106,7 +144,7 @@ def inner(cls) -> Configurable:
cls.name = service_name
cls.description = service_desc
cls.__service__ = True
cls.__parameters__ = get_parameters(cls.__init__, cls.__name__)
cls.__parameters__ = get_class_parameters(cls)

return cls

Expand Down
4 changes: 2 additions & 2 deletions utils/db_storage/db_storage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import sqlite3

from utils.configurable import configurable
from utils.configurable import configurable, parameter


@configurable("db_storage", "Stores the results of the experiments in a SQLite database")
class DbStorage:
def __init__(self, connection_string: str = ":memory:"):
def __init__(self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default=":memory:")):
self.connection_string = connection_string

def init(self):
Expand Down
16 changes: 8 additions & 8 deletions utils/openai/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import tiktoken

from utils.configurable import configurable
from utils.configurable import configurable, parameter
from utils.llm_util import LLMResult, LLM


Expand All @@ -20,13 +20,13 @@ class OpenAIConnection(LLM):
If you really must use it, you can import it directly from the utils.openai.openai_llm module, which will later on
show you, that you did not specialize yet.
"""
api_key: str
model: str
context_size: int
api_url: str = "https://api.openai.com"
api_timeout: int = 240
api_backoff: int = 60
api_retries: int = 3
api_key: str = parameter(desc="OpenAI API Key")
model: str = parameter(desc="OpenAI model name")
context_size: int = parameter(desc="Maximum context size for the model, only used internally for things like trimming to the context size")
api_url: str = parameter(desc="URL of the OpenAI API", default="https://api.openai.com")
api_timeout: int = parameter(desc="Timeout for the API request", default=240)
api_backoff: int = parameter(desc="Backoff time in seconds when running into rate-limits", default=60)
api_retries: int = parameter(desc="Number of retries when running into rate-limits", default=3)

def get_response(self, prompt, *, retry: int = 0, **kwargs) -> LLMResult:
if retry >= self.api_retries:
Expand Down
2 changes: 1 addition & 1 deletion wintermute.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def main():
for name, use_case in use_cases.items():
use_case.build_parser(subparser.add_parser(
name=use_case.name,
description=use_case.description
help=use_case.description
))

parsed = parser.parse_args(sys.argv[1:])
Expand Down

0 comments on commit 0820133

Please sign in to comment.