Skip to content

Commit

Permalink
Update autogenerated Python>=3.11 tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Aug 8, 2024
1 parent faddd37 commit 29d7d20
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 12 deletions.
1 change: 0 additions & 1 deletion tests/test_py311_generated/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def generate_from_path(test_path: pathlib.Path) -> None:
)
out_path.write_text(content)

subprocess.run(["isort", "--profile=black", str(out_path)], check=True)
subprocess.run(["ruff", "format", str(out_path)], check=True)
subprocess.run(["ruff", "check", "--fix", str(out_path)], check=True)

Expand Down
122 changes: 122 additions & 0 deletions tests/test_py311_generated/test_collections_generated.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import collections
import collections.abc
import contextlib
import dataclasses
import enum
import io
import sys
from typing import (
Any,
Deque,
Expand Down Expand Up @@ -154,6 +156,63 @@ class A:
tyro.cli(A, args=[])


def test_sequences_narrow() -> None:
@dataclasses.dataclass
class A:
x: Sequence = dataclasses.field(default_factory=lambda: [0])

assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3])
assert tyro.cli(A, args=[]) == A(x=[0])
assert tyro.cli(A, args=["--x"]) == A(x=[])


def test_sequences_narrow_any() -> None:
@dataclasses.dataclass
class A:
x: Sequence[Any] = dataclasses.field(default_factory=lambda: [0])

assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3])
assert tyro.cli(A, args=[]) == A(x=[0])
assert tyro.cli(A, args=["--x"]) == A(x=[])


if sys.version_info >= (3, 9):

def test_abc_sequences() -> None:
@dataclasses.dataclass
class A:
x: collections.abc.Sequence[int]

assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3])
assert tyro.cli(A, args=["--x"]) == A(x=[])
with pytest.raises(SystemExit):
tyro.cli(A, args=[])


def test_abc_sequences_narrow() -> None:
@dataclasses.dataclass
class A:
x: collections.abc.Sequence = dataclasses.field(default_factory=lambda: [0])

assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3])
assert tyro.cli(A, args=[]) == A(x=[0])
assert tyro.cli(A, args=["--x"]) == A(x=[])


if sys.version_info >= (3, 9):

def test_abc_sequences_narrow_any() -> None:
@dataclasses.dataclass
class A:
x: collections.abc.Sequence[Any] = dataclasses.field(
default_factory=lambda: [0]
)

assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3])
assert tyro.cli(A, args=[]) == A(x=[0])
assert tyro.cli(A, args=["--x"]) == A(x=[])


def test_lists() -> None:
@dataclasses.dataclass
class A:
Expand Down Expand Up @@ -446,20 +505,83 @@ def main(x: list = [0, 1, 2, "hello"]) -> Any:
assert tyro.cli(main, args="--x hi there 5".split(" ")) == ["hi", "there", 5]


def test_list_narrowing_any() -> None:
def main(x: List[Any] = [0, 1, 2, "hello"]) -> Any:
return x

assert tyro.cli(main, args="--x hi there 5".split(" ")) == ["hi", "there", 5]


def test_list_narrowing_empty() -> None:
def main(x: list = []) -> Any:
return x

assert tyro.cli(main, args="--x hi there 5".split(" ")) == ["hi", "there", "5"]


def test_list_narrowing_empty_any() -> None:
def main(x: List[Any] = []) -> Any:
return x

assert tyro.cli(main, args="--x hi there 5".split(" ")) == ["hi", "there", "5"]


def test_set_narrowing() -> None:
def main(x: set = {0, 1, 2, "hello"}) -> Any:
return x

assert tyro.cli(main, args="--x hi there 5".split(" ")) == {"hi", "there", 5}


def test_set_narrowing_any() -> None:
def main(x: Set[Any] = {0, 1, 2, "hello"}) -> Any:
return x

assert tyro.cli(main, args="--x hi there 5".split(" ")) == {"hi", "there", 5}


def test_set_narrowing_empty() -> None:
def main(x: set = set()) -> Any:
return x

assert tyro.cli(main, args="--x hi there 5".split(" ")) == {"hi", "there", "5"}


def test_set_narrowing_any_empty() -> None:
def main(x: Set[Any] = set()) -> Any:
return x

assert tyro.cli(main, args="--x hi there 5".split(" ")) == {"hi", "there", "5"}


def test_tuple_narrowing() -> None:
def main(x: tuple = (0, 1, 2, "hello")) -> Any:
return x

assert tyro.cli(main, args="--x 0 1 2 3".split(" ")) == (0, 1, 2, "3")


def test_tuple_narrowing_any() -> None:
def main(x: Tuple[Any, ...] = (0, 1, 2, "hello")) -> Any:
return x

assert tyro.cli(main, args="--x 0 1 2 3".split(" ")) == (0, 1, 2, "3")


def test_tuple_narrowing_empty() -> None:
def main(x: tuple = ()) -> Any:
return x

assert tyro.cli(main, args="--x 0 1 2 3".split(" ")) == ("0", "1", "2", "3")


def test_tuple_narrowing_empty_any() -> None:
def main(x: Tuple[Any, ...] = ()) -> Any:
return x

assert tyro.cli(main, args="--x 0 1 2 3".split(" ")) == ("0", "1", "2", "3")


def test_tuple_narrowing_empty_default() -> None:
def main(x: tuple = ()) -> Any:
return x
Expand Down
38 changes: 38 additions & 0 deletions tests/test_py311_generated/test_conf_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,44 @@ def commit(branch: str) -> int:
)


def test_custom_constructor_10() -> None:
def commit(branch: str) -> int:
"""Commit"""
print(f"commit branch={branch}")
return 3

def inner(x: Annotated[Any, tyro.conf.arg(constructor=commit)]) -> None:
return x

def inner_no_prefix(
x: Annotated[Any, tyro.conf.arg(constructor=commit, prefix_name=False)],
) -> None:
return x

def outer(x: Annotated[Any, tyro.conf.arg(constructor=inner)]) -> None:
return x

def outer_no_prefix(
x: Annotated[Any, tyro.conf.arg(constructor=inner_no_prefix)],
) -> None:
return x

assert (
tyro.cli(
outer,
args="--x.x.branch 5".split(" "),
)
== 3
)
assert (
tyro.cli(
outer_no_prefix,
args="--x.branch 5".split(" "),
)
== 3
)


def test_alias() -> None:
"""Arguments with aliases."""

Expand Down
8 changes: 8 additions & 0 deletions tests/test_py311_generated/test_dcargs_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,14 @@ def main(device: torch.device) -> torch.device:
assert tyro.cli(main, args=["--device", "cpu"]) == torch.device("cpu")


def test_supports_inference_mode_decorator() -> None:
@torch.inference_mode()
def main(x: int, device: str) -> Tuple[int, str]:
return x, device

assert tyro.cli(main, args="--x 3 --device cuda".split(" ")) == (3, "cuda")


def test_torch_device_2() -> None:
assert tyro.cli(torch.device, args=["cpu"]) == torch.device("cpu")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,20 +436,23 @@ class Wrapper:
assert wrapper1 == tyro.extras.from_yaml(Wrapper, tyro.extras.to_yaml(wrapper1))


def test_superclass() -> None:
# https://github.com/brentyi/tyro/issues/7
@dataclasses.dataclass
class TypeA:
data: int

@dataclasses.dataclass
class TypeA:
data: int

@dataclasses.dataclass
class TypeASubclass(TypeA):
pass
@dataclasses.dataclass
class TypeASubclass(TypeA):
pass

@dataclasses.dataclass
class Wrapper:
subclass: TypeA

@dataclasses.dataclass
class Wrapper:
subclass: TypeA


def test_superclass() -> None:
# https://github.com/brentyi/tyro/issues/7

wrapper1 = Wrapper(TypeASubclass(3)) # Create Wrapper object.
assert wrapper1 == tyro.extras.from_yaml(Wrapper, tyro.extras.to_yaml(wrapper1))

0 comments on commit 29d7d20

Please sign in to comment.