diff --git a/pyproject.toml b/pyproject.toml index ce9d61afa3..bfd4dba0b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ ] dependencies = [ "click >= 8.0.0", - "typing-extensions >= 3.7.4.3", + "typing-extensions >= 4.6.0", ] readme = "README.md" [project.urls] diff --git a/requirements-tests.txt b/requirements-tests.txt index d58de0d9b4..70ac80fc2c 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -5,7 +5,7 @@ pytest-cov >=2.10.0,<6.0.0 coverage[toml] >=6.2,<8.0 pytest-xdist >=1.32.0,<4.0.0 pytest-sugar >=0.9.4,<1.1.0 -mypy ==1.4.1 +mypy >=1.10.1 ruff ==0.6.3 # Needed explicitly by typer-slim rich >=10.11.0 diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000000..e3518c6c75 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,139 @@ +from datetime import datetime +from enum import Enum +from pathlib import Path +from uuid import UUID + +import click +import pytest +from typer.main import get_click_type +from typer.models import FileBinaryRead, FileTextWrite, ParameterInfo +from typing_extensions import TypeAliasType + + +def test_get_click_type_with_custom_click_type(): + custom_click_type = click.INT + param_info = ParameterInfo(click_type=custom_click_type) + result = get_click_type(annotation=int, parameter_info=param_info) + assert result is custom_click_type + + +def test_get_click_type_with_custom_parser(): + def mock_parser(x): + return 42 + + param_info = ParameterInfo(parser=mock_parser) + result = get_click_type(annotation=int, parameter_info=param_info) + assert isinstance(result, click.types.FuncParamType) + assert result.convert("42", None, None) == 42 + + +def test_get_click_type_with_str_annotation(): + param_info = ParameterInfo() + result = get_click_type(annotation=str, parameter_info=param_info) + assert result is click.STRING + + +def test_get_click_type_with_int_annotation_no_min_max(): + param_info = ParameterInfo() + result = get_click_type(annotation=int, parameter_info=param_info) + assert result is click.INT + + +def test_get_click_type_with_int_annotation_with_min_max(): + param_info = ParameterInfo(min=10, max=100) + result = get_click_type(annotation=int, parameter_info=param_info) + assert isinstance(result, click.IntRange) + assert result.min == 10 + assert result.max == 100 + + +def test_get_click_type_with_float_annotation_no_min_max(): + param_info = ParameterInfo() + result = get_click_type(annotation=float, parameter_info=param_info) + assert result is click.FLOAT + + +def test_get_click_type_with_float_annotation_with_min_max(): + param_info = ParameterInfo(min=0.1, max=10.5) + result = get_click_type(annotation=float, parameter_info=param_info) + assert isinstance(result, click.FloatRange) + assert result.min == 0.1 + assert result.max == 10.5 + + +def test_get_click_type_with_bool_annotation(): + param_info = ParameterInfo() + result = get_click_type(annotation=bool, parameter_info=param_info) + assert result is click.BOOL + + +def test_get_click_type_with_uuid_annotation(): + param_info = ParameterInfo() + result = get_click_type(annotation=UUID, parameter_info=param_info) + assert result is click.UUID + + +def test_get_click_type_with_datetime_annotation(): + param_info = ParameterInfo(formats=["%Y-%m-%d"]) + result = get_click_type(annotation=datetime, parameter_info=param_info) + assert isinstance(result, click.DateTime) + assert result.formats == ["%Y-%m-%d"] + + +def test_get_click_type_with_path_annotation(): + param_info = ParameterInfo(resolve_path=True) + result = get_click_type(annotation=Path, parameter_info=param_info) + assert isinstance(result, click.Path) + assert result.resolve_path is True + + +def test_get_click_type_with_enum_annotation(): + class Color(Enum): + RED = "red" + BLUE = "blue" + + param_info = ParameterInfo() + result = get_click_type(annotation=Color, parameter_info=param_info) + assert isinstance(result, click.Choice) + assert result.choices == ["red", "blue"] + + +def test_get_click_type_with_file_text_write_annotation(): + param_info = ParameterInfo(mode="w", encoding="utf-8") + result = get_click_type(annotation=FileTextWrite, parameter_info=param_info) + assert isinstance(result, click.File) + assert result.mode == "w" + assert result.encoding == "utf-8" + + +def test_get_click_type_with_file_binary_read_annotation(): + param_info = ParameterInfo(mode="rb") + result = get_click_type(annotation=FileBinaryRead, parameter_info=param_info) + assert isinstance(result, click.File) + assert result.mode == "rb" + + +def test_get_click_type_with_type_alias_type(): + # define TypeAliasType + Name = TypeAliasType(name="Name", value=str) + Surname = TypeAliasType(name="Surname", value=Name) + + param_info = ParameterInfo() + result = get_click_type(annotation=Name, parameter_info=param_info) + assert result is click.STRING + + # recursive types + param_info = ParameterInfo() + result = get_click_type(annotation=Surname, parameter_info=param_info) + assert result is click.STRING + + +def test_get_click_type_with_unsupported_type(): + class UnsupportedType: + pass + + param_info = ParameterInfo() + with pytest.raises( + RuntimeError, match="Type not yet supported: " + ): + get_click_type(annotation=UnsupportedType, parameter_info=param_info) diff --git a/typer/main.py b/typer/main.py index a621bda6ad..bcf7535a47 100644 --- a/typer/main.py +++ b/typer/main.py @@ -8,11 +8,21 @@ from pathlib import Path from traceback import FrameSummary, StackSummary from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) from uuid import UUID import click -from typing_extensions import get_args, get_origin +from typing_extensions import TypeAliasType, get_args, get_origin from ._typing import is_union from .completion import get_completion_inspect_parameters @@ -43,7 +53,7 @@ Required, TyperInfo, ) -from .utils import get_params_from_function +from .utils import get_original_type, get_params_from_function try: import rich @@ -710,6 +720,9 @@ def wrapper(**kwargs: Any) -> Any: def get_click_type( *, annotation: Any, parameter_info: ParameterInfo ) -> click.ParamType: + if isinstance(annotation, TypeAliasType): + annotation = get_original_type(annotation) + if parameter_info.click_type is not None: return parameter_info.click_type diff --git a/typer/utils.py b/typer/utils.py index 93c407447e..32c9511ac0 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -3,10 +3,20 @@ from copy import copy from typing import Any, Callable, Dict, List, Tuple, Type, cast -from typing_extensions import Annotated, get_args, get_origin, get_type_hints +from typing_extensions import ( + Annotated, + TypeAliasType, + TypeVar, + get_args, + get_origin, + get_type_hints, +) from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta +T = TypeVar("T") +TypeAliasTypeVar = TypeAliasType("TypeAliasTypeVar", value=T, type_params=(T,)) + def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str: # Render a `ParameterInfo` subclass for use in error messages. @@ -189,3 +199,24 @@ def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: name=param.name, default=default, annotation=annotation ) return params + + +def get_original_type(alias: TypeAliasTypeVar[T]) -> T: + """Return the original type of an alias. + + Examples + -------- + >>> Name = TypeAliasType(name="Name", value=str) + >>> Surname = TypeAliasType(name="Surname", value=Name) + >>> get_original_type(Name) + str + >>> get_original_type(Surname) + str + >>> get_original_type(int) + int + """ + otype = alias + while isinstance(otype, TypeAliasType): + otype = otype.__value__ + + return otype