From e9d1adea051e8cf07936b1fa4db69a3c5cd0ec38 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 7 Sep 2022 21:10:42 -0700 Subject: [PATCH] now works on descriptors too! --- pyproject.toml | 2 +- src/trouting/core.py | 91 ++++++++++++++++++++++++++++------------ tests/test_decorators.py | 84 +++++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 27 deletions(-) create mode 100644 tests/test_decorators.py diff --git a/pyproject.toml b/pyproject.toml index 57b8230..d30567e 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "trouting" -version = "0.2.2" +version = "0.3.0" description = "Trouting (short for Type Routing) is a simple class decorator that allows to define multiple interfaces for a method that behave differently depending on input types." authors = [ {name = "Luca Soldaini", email = "luca@soldaini.net" } diff --git a/src/trouting/core.py b/src/trouting/core.py index d34e88c..0387e3c 100644 --- a/src/trouting/core.py +++ b/src/trouting/core.py @@ -6,12 +6,11 @@ Callable, Dict, Generic, - Optional, Sequence, Tuple, - Type, TypeVar, Union, + cast, ) from typing_extensions import Concatenate, ParamSpec @@ -47,7 +46,7 @@ def add_one_str(self, a: str) -> str: interfaces: Dict[Tuple[type, ...], Callable[Concatenate[Any, P], R]] def __init__( - self, interfaced_method: Callable[Concatenate[Any, P], R] + self, fallback_method: Callable[Concatenate[Any, P], R] ) -> None: """Create an Interface object. @@ -57,9 +56,8 @@ def __init__( """ self.interfaces = {} self.bounded_args = None - self._interfaced_method = interfaced_method - self._method_signature = inspect.signature(interfaced_method) - self._obj = None + self.fallback_method = fallback_method + self.is_descriptor = inspect.ismethoddescriptor(fallback_method) def _expand_interface_combinations( self, nested_interface_spec: Dict[str, Union[type, Tuple[type, ...]]] @@ -98,7 +96,7 @@ def add_interface( if self.bounded_args is None: self.bounded_args = current_interface_args elif self.bounded_args != current_interface_args: - raise ValueError( + raise TypeError( "All interfaces must have the same arguments; the current " f"interface has arguments {current_interface_args}, but the " f"previous interface has arguments {self.bounded_args}" @@ -107,45 +105,86 @@ def add_interface( def _add_interface( method: Callable[Concatenate[Any, P], R] ) -> "trouting": + if self.is_descriptor: + if not inspect.ismethoddescriptor(method): + raise TypeError( + "All interfaces must be descriptors; the current " + "interface is a function." + ) + elif not isinstance(self.fallback_method, type(method)): + raise TypeError( + "All interfaces must be of the same type; the current " + f"interface is a {type(method)}, but the previous " + f"interface is {type(self.fallback_method)}." + ) + for interface_spec in interface_specs: # register the same method for all types in the interface spec - self.interfaces[tuple(interface_spec.values())] = method + # have to add an ignore because pyright is being a bit too + # clever here. + self.interfaces[ # pyright: ignore + tuple(interface_spec.values()) + ] = method + return self return _add_interface - def __get__( - self, obj: Any, type: Optional[Type] = None - ) -> Callable[Concatenate[P], R]: + def __get__(self, obj: Any, type: Any) -> Callable[Concatenate[P], R]: """Return a bound method that calls the correct interface.""" - return partial(self.__call__, __obj__=obj) + return partial( + self.__call__, __trouting_obj__=obj, __trouting_type__=type + ) + + def _bound_method(self, method: Any, obj: Any, cls: Any) -> Callable: + if self.is_descriptor: + bound_method = method.__get__(obj, cast(type, cls)) + else: + # populate the first argument with the object or class here + bound_method = partial(method, obj or cls) + return bound_method def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: """Call the interfaced method with the correct interface.""" - if (obj := kwargs.pop("__obj__", MISSING)) is MISSING: + if ( + __trouting_obj__ := kwargs.pop("__trouting_obj__", MISSING) + ) is MISSING: raise ValueError( - "__obj__ is required; `Interface._run_interface` " + "__trouting_obj__ is required; `Interface._run_interface` " + "was improperly called; You might have called a trouted " + "method in an invalid way; If you think you are using this " + "library correctly, please file a bug report." + ) + if ( + __trouting_type__ := kwargs.pop("__trouting_type__", MISSING) + ) is MISSING: + raise ValueError( + "__trouting_type__ is required; `Interface._run_interface` " "was improperly called; You might have called a trouted " "method in an invalid way; If you think you are using this " "library correctly, please file a bug report." ) - if self.bounded_args is None: - # no interfaces have been added, so we fall back to the default - return self._interfaced_method(obj, *args, **kwargs) + bounded_fallback_method = self._bound_method( + self.fallback_method, __trouting_obj__, __trouting_type__ + ) - sig_vals = self._method_signature.bind(self, *args, **kwargs) - method_to_call = None + sig_vals = inspect.signature(bounded_fallback_method).bind( + *args, **kwargs + ) - current_types = ( + current_types = tuple( type(sig_vals.arguments[arg_name]) - for arg_name in self.bounded_args + for arg_name in (self.bounded_args or tuple()) ) - # fall back to the default method if we didn't find anything - method_to_call = self.interfaces.get( - tuple(current_types), self._interfaced_method - ) + method_to_call = self.interfaces.get(current_types, None) + if method_to_call is None: + method_to_call = bounded_fallback_method + else: + method_to_call = self._bound_method( + method_to_call, __trouting_obj__, __trouting_type__ + ) - return method_to_call(obj, *args, **kwargs) + return method_to_call(*args, **kwargs) diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 0000000..d83d05e --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,84 @@ +from typing import Any +from unittest import TestCase + +from trouting import trouting + + +class TroutedClass: + @trouting + @classmethod + def add_one(cls, a: Any) -> Any: + raise TypeError(f"Type {type(a)} not supported for +1") + + @add_one.add_interface(a=int) + @classmethod + def add_one_int(cls, a: int) -> int: + return a + 1 + + @add_one.add_interface(a=str) + @classmethod + def add_one_str(cls, a: str) -> str: + return a + "1" + + @trouting + def add_two(self, a: Any) -> Any: + raise TypeError(f"Type {type(a)} not supported for +2") + + @add_two.add_interface(a=int) + def add_two_int(self, a: int) -> int: + return a + 2 + + @add_two.add_interface(a=str) + def add_two_str(self, a: str) -> str: + return a + "2" + + @trouting + @staticmethod + def add_three(a: Any) -> Any: + raise TypeError(f"Type {type(a)} not supported for +3") + + @add_three.add_interface(a=int) + @staticmethod + def add_three_int(a: int) -> int: + return a + 3 + + @add_three.add_interface(a=str) + @staticmethod + def add_three_str(a: str) -> str: + return a + "3" + + +class TestDecorators(TestCase): + def test_classmethod(self): + self.assertEqual(TroutedClass.add_one(1), 2) + self.assertEqual(TroutedClass.add_one("1"), "11") + + def test_instance_method(self): + self.assertEqual(TroutedClass().add_two(1), 3) + self.assertEqual(TroutedClass().add_two("1"), "12") + + def test_staticmethod(self): + # TODO[soldni]: need to fix typing annotation for trouting + # so that pylance doesn't freak out when using staticmethod + self.assertEqual(TroutedClass.add_three(1), 4) # pyright: ignore + self.assertEqual(TroutedClass.add_three("1"), "13") # pyright: ignore + + def test_raise_error_uneven_interfaces(self): + class _: + @trouting + @classmethod + def add_one(cls, a: Any) -> Any: + raise TypeError(f"Type {type(a)} not supported for +1") + + with self.assertRaises(TypeError): + + @add_one.add_interface(a=int) + def add_one_int(cls, a: int) -> int: + return a + 1 + + with self.assertRaises(TypeError): + # Type ignore because this is intentionally wrong + @add_one.add_interface(a=str) # type: ignore + @staticmethod + def add_one_str(a: str) -> str: + return a + "1"