diff --git a/docs/source/examples/02_nesting/03_multiple_subcommands.rst b/docs/source/examples/02_nesting/03_multiple_subcommands.rst index bb62f7a0..5064916a 100644 --- a/docs/source/examples/02_nesting/03_multiple_subcommands.rst +++ b/docs/source/examples/02_nesting/03_multiple_subcommands.rst @@ -50,7 +50,6 @@ Multiple unions over nested types are populated using a series of subcommands. # Train script. - @tyro.conf.configure(tyro.conf.ConsolidateSubcommandArgs) def train( dataset: Mnist | ImageNet = Mnist(), optimizer: Adam | Sgd = Adam(), @@ -69,7 +68,7 @@ Multiple unions over nested types are populated using a series of subcommands. if __name__ == "__main__": - tyro.cli(train) + tyro.cli(train, config=(tyro.conf.ConsolidateSubcommandArgs,)) ------------ diff --git a/docs/source/examples/03_config_systems/01_base_configs.rst b/docs/source/examples/03_config_systems/01_base_configs.rst index d2250b35..0ac47f25 100644 --- a/docs/source/examples/03_config_systems/01_base_configs.rst +++ b/docs/source/examples/03_config_systems/01_base_configs.rst @@ -4,10 +4,13 @@ Base Configurations ========================================== -We can integrate `tyro.cli()` into common configuration patterns: here, we select +We can integrate `tyro` into common configuration patterns: here, we select one of multiple possible base configurations, create a subcommand for each one, and then use the CLI to either override (existing) or fill in (missing) values. +The helper function used here, :func:`tyro.extras.overridable_config_cli()`, is a +lightweight wrapper over :func:`tyro.cli()`. + .. code-block:: python :linenos: @@ -56,9 +59,10 @@ use the CLI to either override (existing) or fill in (missing) values. # Note that we could also define this library using separate YAML files (similar to # `config_path`/`config_name` in Hydra), but staying in Python enables seamless type # checking + IDE support. - Configs = tyro.extras.subcommand_type_from_defaults( - { - "small": ExperimentConfig( + default_configs = { + "small": ( + "Small experiment.", + ExperimentConfig( dataset="mnist", optimizer=AdamOptimizer(), batch_size=2048, @@ -68,7 +72,10 @@ use the CLI to either override (existing) or fill in (missing) values. seed=0, activation=nn.ReLU, ), - "big": ExperimentConfig( + ), + "big": ( + "Big experiment.", + ExperimentConfig( dataset="imagenet-50", optimizer=AdamOptimizer(), batch_size=32, @@ -78,11 +85,10 @@ use the CLI to either override (existing) or fill in (missing) values. seed=0, activation=nn.GELU, ), - } - ) - + ), + } if __name__ == "__main__": - config = tyro.cli(Configs) + config = tyro.extras.overridable_config_cli(default_configs) print(config) ------------ diff --git a/examples/03_config_systems/01_base_configs.py b/examples/03_config_systems/01_base_configs.py index 0d2a59e2..50ee71fa 100644 --- a/examples/03_config_systems/01_base_configs.py +++ b/examples/03_config_systems/01_base_configs.py @@ -1,9 +1,12 @@ """Base Configurations -We can integrate `tyro.cli()` into common configuration patterns: here, we select +We can integrate `tyro` into common configuration patterns: here, we select one of multiple possible base configurations, create a subcommand for each one, and then use the CLI to either override (existing) or fill in (missing) values. +The helper function used here, :func:`tyro.extras.overridable_config_cli()`, is a +lightweight wrapper over :func:`tyro.cli()`. + Usage: `python ./01_base_configs.py --help` @@ -56,9 +59,10 @@ class ExperimentConfig: # Note that we could also define this library using separate YAML files (similar to # `config_path`/`config_name` in Hydra), but staying in Python enables seamless type # checking + IDE support. -Configs = tyro.extras.subcommand_type_from_defaults( - { - "small": ExperimentConfig( +default_configs = { + "small": ( + "Small experiment.", + ExperimentConfig( dataset="mnist", optimizer=AdamOptimizer(), batch_size=2048, @@ -68,7 +72,10 @@ class ExperimentConfig: seed=0, activation=nn.ReLU, ), - "big": ExperimentConfig( + ), + "big": ( + "Big experiment.", + ExperimentConfig( dataset="imagenet-50", optimizer=AdamOptimizer(), batch_size=32, @@ -78,9 +85,8 @@ class ExperimentConfig: seed=0, activation=nn.GELU, ), - } -) - + ), +} if __name__ == "__main__": - config = tyro.cli(Configs) + config = tyro.extras.overridable_config_cli(default_configs) print(config) diff --git a/pyproject.toml b/pyproject.toml index aa3294d0..821dcaf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "tyro" authors = [ {name = "brentyi", email = "brentyi@berkeley.edu"}, ] -version = "0.8.7" # TODO: currently needs to be synchronized manually with __init__.py. +version = "0.8.8" # TODO: currently needs to be synchronized manually with __init__.py. description = "Strongly typed, zero-effort CLI interfaces" readme = "README.md" license = { text="MIT" } diff --git a/src/tyro/__init__.py b/src/tyro/__init__.py index 89528c40..e7bafa09 100644 --- a/src/tyro/__init__.py +++ b/src/tyro/__init__.py @@ -14,4 +14,4 @@ # TODO: this should be synchronized automatically with the pyproject.toml. -__version__ = "0.8.7" +__version__ = "0.8.8" diff --git a/src/tyro/extras/__init__.py b/src/tyro/extras/__init__.py index 567ddaf3..0afb6b85 100644 --- a/src/tyro/extras/__init__.py +++ b/src/tyro/extras/__init__.py @@ -4,6 +4,7 @@ from .._argparse_formatter import set_accent_color as set_accent_color from .._cli import get_parser as get_parser +from ._base_configs import overridable_config_cli as overridable_config_cli from ._base_configs import ( subcommand_type_from_defaults as subcommand_type_from_defaults, ) diff --git a/src/tyro/extras/_base_configs.py b/src/tyro/extras/_base_configs.py index 117a944a..33157e07 100644 --- a/src/tyro/extras/_base_configs.py +++ b/src/tyro/extras/_base_configs.py @@ -1,13 +1,72 @@ -from typing import Mapping, TypeVar, Union +from typing import Mapping, Optional, Sequence, Tuple, TypeVar, Union from typing_extensions import Annotated from .._typing import TypeForm -from ..conf import subcommand T = TypeVar("T") +def overridable_config_cli( + configs: Mapping[str, Tuple[str, T]], + *, + args: Optional[Sequence[str]] = None, +) -> T: + """Helper function for creating a CLI interface that allows us to choose + between default config objects (typically dataclasses) and override values + within it. Turns off subcommand creation for any union types within the + config object. + + This is a lightweight wrapper over :func:`tyro.cli()`, with some default + arguments populated. Also see + :func:`tyro.extras.subcommand_type_from_defaults()`. + + + Example usage: + ```python + import dataclasses + + import tyro + + + @dataclasses.dataclass + class Config: + a: int + b: str + + + default_configs = { + "small": ( + "Small config", + Config(1, "small"), + ), + "big": ( + "Big config", + Config(100, "big"), + ), + } + config = tyro.extras.overridable_config_cli(default_configs) + print(config) + ``` + + Args: + configs: A dictionary of config names mapped to a tuple of + (description, config object). + args: Optional arguments to pass to the CLI. + """ + import tyro + + return tyro.cli( + tyro.extras.subcommand_type_from_defaults( + defaults={k: v[1] for k, v in configs.items()}, + descriptions={k: v[0] for k, v in configs.items()}, + ), + # Don't create subcommands for union types within the config object. + config=(tyro.conf.AvoidSubcommands,), + args=args, + ) + + def subcommand_type_from_defaults( defaults: Mapping[str, T], descriptions: Mapping[str, str] = {}, @@ -67,6 +126,8 @@ def subcommand_type_from_defaults( Returns: A subcommand type, which can be passed to :func:`tyro.cli`. """ + import tyro + # We need to form a union type, which requires at least two elements. assert len(defaults) >= 2, "At least two subcommands are required." return Union.__getitem__( # type: ignore @@ -74,7 +135,7 @@ def subcommand_type_from_defaults( Annotated.__class_getitem__( # type: ignore ( type(v), - subcommand( + tyro.conf.subcommand( k, default=v, description=descriptions.get(k, ""), diff --git a/tests/test_base_configs_nested.py b/tests/test_base_configs_nested.py index 7500df32..f9dca8f0 100644 --- a/tests/test_base_configs_nested.py +++ b/tests/test_base_configs_nested.py @@ -210,3 +210,23 @@ def test_pernicious_override(): ).data_config.test == 0 ) + + +def test_overridable_config_helper(): + assert tyro.extras.overridable_config_cli( + { + "small-data": ( + "Small data", + DataConfig( + test=2221, + ), + ), + "big-data": ( + "Big data", + DataConfig( + test=2, + ), + ), + }, + args=["small-data", "--test", "100"], + ) == DataConfig(100)