Skip to content

Commit

Permalink
refactor(agent/core): Refactor autogpt.core.configuration.schema an…
Browse files Browse the repository at this point in the history
…d update docstrings

- Refactor the `schema.py` file in the `autogpt.core.configuration` module.
- Added docstring to `SystemConfiguration.from_env()`
- Updated docstrings for functions `_get_user_config_values`, `_get_non_default_user_config_values`, `_recursive_init_model`, `_recurse_user_config_fields`, and `_recurse_user_config_values`.
  • Loading branch information
Pwuts committed Dec 5, 2023
1 parent 5796734 commit 21a36f0
Showing 1 changed file with 52 additions and 56 deletions.
108 changes: 52 additions & 56 deletions autogpts/autogpt/autogpt/core/configuration/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@ def UserConfigurable(

class SystemConfiguration(BaseModel):
def get_user_config(self) -> dict[str, Any]:
return _get_user_config_values(self)
return _recurse_user_config_values(self)

@classmethod
def from_env(cls):
"""
Initializes the config object from environment variables.
Environment variables are mapped to UserConfigurable fields using the from_env
attribute that can be passed to UserConfigurable.
"""

def infer_field_value(field: ModelField):
field_info = field.field_info
default_value = (
Expand Down Expand Up @@ -86,7 +93,7 @@ class Configurable(abc.ABC, Generic[S]):

@classmethod
def get_user_config(cls) -> dict[str, Any]:
return _get_user_config_values(cls.default_settings)
return _recurse_user_config_values(cls.default_settings)

@classmethod
def build_agent_configuration(cls, overrides: dict = {}) -> S:
Expand All @@ -98,38 +105,6 @@ def build_agent_configuration(cls, overrides: dict = {}) -> S:
return cls.default_settings.__class__.parse_obj(final_configuration)


def _get_user_config_values(instance: BaseModel) -> dict[str, Any]:
"""
Get the user config fields of a Pydantic model instance.
Args:
instance: The Pydantic model instance.
Returns:
The user config fields of the instance.
"""
return _recurse_user_config_values(instance)


def _get_non_default_user_config_values(instance: BaseModel) -> dict[str, Any]:
"""
Get the non-default user config fields of a Pydantic model instance.
Args:
instance: The Pydantic model instance.
Returns:
The non-default user config fields of the instance.
"""

def infer_field_value(field: ModelField, value):
default = field.default_factory if field.default_factory else field.default
if value != default:
return value

return remove_none_items(_recurse_user_config_values(instance, infer_field_value))


def _update_user_config_from_env(instance: BaseModel) -> dict[str, Any]:
"""
Update config fields of a Pydantic model instance from environment variables.
Expand All @@ -139,7 +114,7 @@ def _update_user_config_from_env(instance: BaseModel) -> dict[str, Any]:
2. Value returned by `from_env()`
3. Default value for the field
Args:
Params:
instance: The Pydantic model instance.
Returns:
Expand Down Expand Up @@ -176,17 +151,16 @@ def _recursive_init_model(
infer_field_value: Callable[[ModelField], Any],
) -> M:
"""
Recurse through the user config fields of a Pydantic model instance.
Recursively initialize the user configuration fields of a Pydantic model.
Args:
instance: The Pydantic model instance.
process_field: The callback function to process each field.
Params:
Parameters:
model: The Pydantic model type.
infer_field_value: A callback function to infer the value of each field.
Parameters:
ModelField: The Pydantic ModelField object describing the field.
Any: The current value of the field.
Returns:
The processed user config fields of the instance.
BaseModel: An instance of the model with the initialized configuration.
"""
user_config_fields = {}
for name, field in model.__fields__.items():
Expand Down Expand Up @@ -219,17 +193,20 @@ def _recurse_user_config_fields(
] = None,
) -> dict[str, Any]:
"""
Recurse through the user config fields of a Pydantic model instance.
Recursively process the user configuration fields of a Pydantic model instance.
Args:
instance: The Pydantic model instance.
process_field: The callback function to process each field.
Params:
model: The Pydantic model to iterate over.
infer_field_value: A callback function to process each field.
Params:
ModelField: The Pydantic ModelField object describing the field.
Any: The current value of the field.
init_sub_config: An optional callback function to initialize a sub-config.
Params:
Type[SystemConfiguration]: The type of the sub-config to initialize.
Returns:
The processed user config fields of the instance.
dict[str, Any]: The processed user configuration fields of the instance.
"""
user_config_fields = {}

Expand Down Expand Up @@ -285,17 +262,17 @@ def _recurse_user_config_values(
get_field_value: Callable[[ModelField, T], T] = lambda _, v: v,
) -> dict[str, Any]:
"""
Recurse through the user config values of a Pydantic model instance.
This function recursively traverses the user configuration values in a Pydantic
model instance.
Args:
instance: The Pydantic model instance.
process_field: The callback function to process each field.
Params:
ModelField: The Pydantic ModelField object describing the field.
Any: The current value of the field.
Params:
instance: A Pydantic model instance.
get_field_value: A callback function to process each field. Parameters:
ModelField: The Pydantic ModelField object that describes the field.
Any: The current value of the field.
Returns:
The processed user config fields of the instance.
A dictionary containing the processed user configuration fields of the instance.
"""
user_config_values = {}

Expand Down Expand Up @@ -324,11 +301,30 @@ def _recurse_user_config_values(
return user_config_values


def _get_non_default_user_config_values(instance: BaseModel) -> dict[str, Any]:
"""
Get the non-default user config fields of a Pydantic model instance.
Params:
instance: The Pydantic model instance.
Returns:
dict[str, Any]: The non-default user config values on the instance.
"""

def get_field_value(field: ModelField, value):
default = field.default_factory() if field.default_factory else field.default
if value != default:
return value

return remove_none_items(_recurse_user_config_values(instance, get_field_value))


def deep_update(original_dict: dict, update_dict: dict) -> dict:
"""
Recursively update a dictionary.
Args:
Params:
original_dict (dict): The dictionary to be updated.
update_dict (dict): The dictionary to update with.
Expand Down

0 comments on commit 21a36f0

Please sign in to comment.