Skip to content

Commit

Permalink
Enable strict type checking for complex data types in check_contracts (
Browse files Browse the repository at this point in the history
  • Loading branch information
CulmoneY authored Nov 5, 2024
1 parent 5d50aa3 commit ba8cfd8
Show file tree
Hide file tree
Showing 3 changed files with 371 additions and 25 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Fixed issue where parallel assignment statements and assignment to multiple targets were not checked by `redundant_assignment_checker`
- Fixed issue where annotated assignment statements were not checked by `redundant_assignment_checker`
- Fixed issue where empty preconditions were preventing CFGs from being generated
- Added strict numeric type checking to enforce type distinctions across the entire numeric hierarchy, including complex numbers.
- Added strict type checking support for nested and union types (e.g., `list[int]`, `dict[float, int]`, `Union[int, float]`)

### 🔧 Internal changes

Expand Down
78 changes: 70 additions & 8 deletions python_ta/contracts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,17 @@
import sys
import typing
from types import CodeType, FunctionType, ModuleType
from typing import Any, Callable, Optional, TypeVar, Union, overload
from typing import (
Any,
Callable,
Collection,
Optional,
TypeVar,
Union,
get_args,
get_origin,
overload,
)

import wrapt
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
Expand Down Expand Up @@ -321,17 +331,69 @@ def _check_function_contracts(wrapped, instance, args, kwargs):


def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
"""Ensure that ``value`` matches ``expected_type``.
"""Ensure that `value` matches ``expected_type`` with strict type checking.
Differentiates between:
- float vs. int
- bool vs. int
This function enforces strict type distinctions within the numeric hierarchy (bool, int, float,
complex), ensuring that the type of value is exactly the same as expected_type.
"""
if ENABLE_CONTRACT_CHECKING:
if (type(value) is int and expected_type is float) or (
type(value) is bool and expected_type is int
if not ENABLE_CONTRACT_CHECKING:
return
try:
_check_inner_type(argname, value, expected_type)
except (TypeError, TypeCheckError):
raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")


def _check_inner_type(argname: str, value: Any, expected_type: type) -> None:
"""Recursively checks if `value` matches `expected_type` for strict type validation, specifically supports checking
collections (list[int], dicts[float]) and Union types (bool | int).
"""
inner_types = get_args(expected_type)
outer_type = get_origin(expected_type)
if outer_type is None:
if (
(type(value) is bool and expected_type in {int, float, complex})
or (type(value) is int and expected_type in {float, complex})
or (type(value) is float and expected_type is complex)
):
raise TypeError(
f"type of {argname} must be {expected_type}; got {type(value).__name__} instead"
)
else:
check_type(
value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
)
elif outer_type is typing.Union:
for inner_type in inner_types:
try:
_check_inner_type(argname, value, inner_type)
return
except (TypeError, TypeCheckError):
pass
raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
elif outer_type in {list, set}:
if isinstance(value, outer_type):
for item in value:
_check_inner_type(argname, item, inner_types[0])
else:
raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
elif outer_type is dict:
if isinstance(value, dict):
for key, item in value.items():
_check_inner_type(argname, key, inner_types[0])
_check_inner_type(argname, item, inner_types[1])
else:
raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
elif outer_type is tuple:
if isinstance(value, tuple) and len(inner_types) == 2 and inner_types[1] is Ellipsis:
for item in value:
_check_inner_type(argname, item, inner_types[0])
elif isinstance(value, tuple) and len(value) == len(inner_types):
for item, inner_type in zip(value, inner_types):
_check_inner_type(argname, item, inner_type)
else:
raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
else:
check_type(
value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
)
Expand Down
Loading

0 comments on commit ba8cfd8

Please sign in to comment.