Skip to content

Commit

Permalink
0.8.8, add tyro.extras.overridable_config_cli()
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Aug 15, 2024
1 parent cb76e7b commit 956ab5e
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 25 deletions.
3 changes: 1 addition & 2 deletions docs/source/examples/02_nesting/03_multiple_subcommands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ Multiple unions over nested types are populated using a series of subcommands.
# Train script.
@tyro.conf.configure(tyro.conf.ConsolidateSubcommandArgs)
def train(
dataset: Mnist | ImageNet = Mnist(),
optimizer: Adam | Sgd = Adam(),
Expand All @@ -69,7 +68,7 @@ Multiple unions over nested types are populated using a series of subcommands.
if __name__ == "__main__":
tyro.cli(train)
tyro.cli(train, config=(tyro.conf.ConsolidateSubcommandArgs,))
------------

Expand Down
24 changes: 15 additions & 9 deletions docs/source/examples/03_config_systems/01_base_configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
Base Configurations
==========================================

We can integrate `tyro.cli()` into common configuration patterns: here, we select
We can integrate `tyro` into common configuration patterns: here, we select
one of multiple possible base configurations, create a subcommand for each one, and then
use the CLI to either override (existing) or fill in (missing) values.

The helper function used here, :func:`tyro.extras.overridable_config_cli()`, is a
lightweight wrapper over :func:`tyro.cli()`.


.. code-block:: python
:linenos:
Expand Down Expand Up @@ -56,9 +59,10 @@ use the CLI to either override (existing) or fill in (missing) values.
# Note that we could also define this library using separate YAML files (similar to
# `config_path`/`config_name` in Hydra), but staying in Python enables seamless type
# checking + IDE support.
Configs = tyro.extras.subcommand_type_from_defaults(
{
"small": ExperimentConfig(
default_configs = {
"small": (
"Small experiment.",
ExperimentConfig(
dataset="mnist",
optimizer=AdamOptimizer(),
batch_size=2048,
Expand All @@ -68,7 +72,10 @@ use the CLI to either override (existing) or fill in (missing) values.
seed=0,
activation=nn.ReLU,
),
"big": ExperimentConfig(
),
"big": (
"Big experiment.",
ExperimentConfig(
dataset="imagenet-50",
optimizer=AdamOptimizer(),
batch_size=32,
Expand All @@ -78,11 +85,10 @@ use the CLI to either override (existing) or fill in (missing) values.
seed=0,
activation=nn.GELU,
),
}
)
),
}
if __name__ == "__main__":
config = tyro.cli(Configs)
config = tyro.extras.overridable_config_cli(default_configs)
print(config)
------------
Expand Down
24 changes: 15 additions & 9 deletions examples/03_config_systems/01_base_configs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Base Configurations
We can integrate `tyro.cli()` into common configuration patterns: here, we select
We can integrate `tyro` into common configuration patterns: here, we select
one of multiple possible base configurations, create a subcommand for each one, and then
use the CLI to either override (existing) or fill in (missing) values.
The helper function used here, :func:`tyro.extras.overridable_config_cli()`, is a
lightweight wrapper over :func:`tyro.cli()`.
Usage:
`python ./01_base_configs.py --help`
Expand Down Expand Up @@ -56,9 +59,10 @@ class ExperimentConfig:
# Note that we could also define this library using separate YAML files (similar to
# `config_path`/`config_name` in Hydra), but staying in Python enables seamless type
# checking + IDE support.
Configs = tyro.extras.subcommand_type_from_defaults(
{
"small": ExperimentConfig(
default_configs = {
"small": (
"Small experiment.",
ExperimentConfig(
dataset="mnist",
optimizer=AdamOptimizer(),
batch_size=2048,
Expand All @@ -68,7 +72,10 @@ class ExperimentConfig:
seed=0,
activation=nn.ReLU,
),
"big": ExperimentConfig(
),
"big": (
"Big experiment.",
ExperimentConfig(
dataset="imagenet-50",
optimizer=AdamOptimizer(),
batch_size=32,
Expand All @@ -78,9 +85,8 @@ class ExperimentConfig:
seed=0,
activation=nn.GELU,
),
}
)

),
}
if __name__ == "__main__":
config = tyro.cli(Configs)
config = tyro.extras.overridable_config_cli(default_configs)
print(config)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "tyro"
authors = [
{name = "brentyi", email = "[email protected]"},
]
version = "0.8.7" # TODO: currently needs to be synchronized manually with __init__.py.
version = "0.8.8" # TODO: currently needs to be synchronized manually with __init__.py.
description = "Strongly typed, zero-effort CLI interfaces"
readme = "README.md"
license = { text="MIT" }
Expand Down
2 changes: 1 addition & 1 deletion src/tyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@


# TODO: this should be synchronized automatically with the pyproject.toml.
__version__ = "0.8.7"
__version__ = "0.8.8"
1 change: 1 addition & 0 deletions src/tyro/extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .._argparse_formatter import set_accent_color as set_accent_color
from .._cli import get_parser as get_parser
from ._base_configs import overridable_config_cli as overridable_config_cli
from ._base_configs import (
subcommand_type_from_defaults as subcommand_type_from_defaults,
)
Expand Down
67 changes: 64 additions & 3 deletions src/tyro/extras/_base_configs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,72 @@
from typing import Mapping, TypeVar, Union
from typing import Mapping, Optional, Sequence, Tuple, TypeVar, Union

from typing_extensions import Annotated

from .._typing import TypeForm
from ..conf import subcommand

T = TypeVar("T")


def overridable_config_cli(
configs: Mapping[str, Tuple[str, T]],
*,
args: Optional[Sequence[str]] = None,
) -> T:
"""Helper function for creating a CLI interface that allows us to choose
between default config objects (typically dataclasses) and override values
within it. Turns off subcommand creation for any union types within the
config object.
This is a lightweight wrapper over :func:`tyro.cli()`, with some default
arguments populated. Also see
:func:`tyro.extras.subcommand_type_from_defaults()`.
Example usage:
```python
import dataclasses
import tyro
@dataclasses.dataclass
class Config:
a: int
b: str
default_configs = {
"small": (
"Small config",
Config(1, "small"),
),
"big": (
"Big config",
Config(100, "big"),
),
}
config = tyro.extras.overridable_config_cli(default_configs)
print(config)
```
Args:
configs: A dictionary of config names mapped to a tuple of
(description, config object).
args: Optional arguments to pass to the CLI.
"""
import tyro

return tyro.cli(
tyro.extras.subcommand_type_from_defaults(
defaults={k: v[1] for k, v in configs.items()},
descriptions={k: v[0] for k, v in configs.items()},
),
# Don't create subcommands for union types within the config object.
config=(tyro.conf.AvoidSubcommands,),
args=args,
)


def subcommand_type_from_defaults(
defaults: Mapping[str, T],
descriptions: Mapping[str, str] = {},
Expand Down Expand Up @@ -67,14 +126,16 @@ def subcommand_type_from_defaults(
Returns:
A subcommand type, which can be passed to :func:`tyro.cli`.
"""
import tyro

# We need to form a union type, which requires at least two elements.
assert len(defaults) >= 2, "At least two subcommands are required."
return Union.__getitem__( # type: ignore
tuple(
Annotated.__class_getitem__( # type: ignore
(
type(v),
subcommand(
tyro.conf.subcommand(
k,
default=v,
description=descriptions.get(k, ""),
Expand Down
20 changes: 20 additions & 0 deletions tests/test_base_configs_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,23 @@ def test_pernicious_override():
).data_config.test
== 0
)


def test_overridable_config_helper():
assert tyro.extras.overridable_config_cli(
{
"small-data": (
"Small data",
DataConfig(
test=2221,
),
),
"big-data": (
"Big data",
DataConfig(
test=2,
),
),
},
args=["small-data", "--test", "100"],
) == DataConfig(100)

0 comments on commit 956ab5e

Please sign in to comment.