Skip to content

Commit

Permalink
Add typing
Browse files Browse the repository at this point in the history
  • Loading branch information
kaste committed Feb 3, 2025
1 parent ff54160 commit ee80823
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 106 deletions.
120 changes: 69 additions & 51 deletions mockito/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from __future__ import annotations
from abc import ABC
import os
import inspect
Expand All @@ -28,7 +29,12 @@
from . import verification as verificationModule
from .utils import contains_strict

from typing import Any, Callable, Deque, Dict, Tuple
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Callable, NoReturn, Self, TypeVar, TYPE_CHECKING
from .mocking import Mock
T = TypeVar('T')


class InvocationError(AttributeError):
Expand All @@ -45,15 +51,15 @@ class AnswerError(AttributeError):


class Invocation(object):
def __init__(self, mock, method_name):
def __init__(self, mock: Mock, method_name: str) -> None:
self.mock = mock
self.method_name = method_name
self.strict = mock.strict

self.params: Tuple[Any, ...] = ()
self.named_params: Dict[str, Any] = {}
self.params: tuple[Any, ...] = ()
self.named_params: dict[str, Any] = {}

def _remember_params(self, params, named_params):
def _remember_params(self, params: tuple, named_params: dict) -> None:
self.params = params
self.named_params = named_params

Expand All @@ -68,27 +74,29 @@ def __repr__(self):


class RealInvocation(Invocation, ABC):
def __init__(self, mock, method_name):
def __init__(self, mock: Mock, method_name: str) -> None:
super(RealInvocation, self).__init__(mock, method_name)
self.verified = False
self.verified_inorder = False


class RememberedInvocation(RealInvocation):
def ensure_mocked_object_has_method(self, method_name):
def ensure_mocked_object_has_method(self, method_name: str) -> None:
if not self.mock.has_method(method_name):
raise InvocationError(
"You tried to call a method '%s' the object (%s) doesn't "
"have." % (method_name, self.mock.mocked_obj))

def ensure_signature_matches(self, method_name, args, kwargs):
def ensure_signature_matches(
self, method_name: str, args: tuple, kwargs: dict
) -> None:
sig = self.mock.get_signature(method_name)
if not sig:
return

signature.match_signature(sig, args, kwargs)

def __call__(self, *params, **named_params):
def __call__(self, *params: Any, **named_params: Any) -> Any | None:
if self.mock.eat_self(self.method_name):
params_without_first_arg = params[1:]
else:
Expand Down Expand Up @@ -141,7 +149,7 @@ class RememberedProxyInvocation(RealInvocation):
Calls method on original object and returns it's return value.
"""
def __call__(self, *params, **named_params):
def __call__(self, *params: Any, **named_params: Any) -> Any:
self._remember_params(params, named_params)
self.mock.remember(self)
obj = self.mock.spec
Expand Down Expand Up @@ -174,7 +182,7 @@ def compare(p1, p2):
return False
return True

def capture_arguments(self, invocation):
def capture_arguments(self, invocation: RealInvocation) -> None:
"""Capture arguments of `invocation` into "capturing" matchers of self.
This is used in conjunction with "capturing" matchers like
Expand Down Expand Up @@ -204,7 +212,7 @@ def capture_arguments(self, invocation):
p1.capture_value(p2)


def _remember_params(self, params, named_params):
def _remember_params(self, params: tuple, named_params: dict) -> None:
if (
contains_strict(params, Ellipsis)
and (params[-1] is not Ellipsis or named_params)
Expand All @@ -231,7 +239,7 @@ def wrap(p):
# Note: matches(a, b) does not imply matches(b, a) because
# the left side might contain wildcards (like Ellipsis) or matchers.
# In its current form the right side is a concrete call signature.
def matches(self, invocation): # noqa: C901 (too complex)
def matches(self, invocation: Invocation) -> bool: # noqa: C901, E501 (too complex)
if self.method_name != invocation.method_name:
return False

Expand Down Expand Up @@ -294,11 +302,16 @@ class VerifiableInvocation(MatchingInvocation):
call. But the `__call__` is essentially virtual and can contain
placeholders and matchers.
"""
def __init__(self, mock, method_name, verification):
def __init__(
self,
mock: Mock,
method_name: str,
verification: verificationModule.VerificationMode
) -> None:
super(VerifiableInvocation, self).__init__(mock, method_name)
self.verification = verification

def __call__(self, *params, **named_params):
def __call__(self, *params: Any, **named_params: Any) -> None:
self._remember_params(params, named_params)
matched_invocations = []
for invocation in self.mock.invocations:
Expand All @@ -321,7 +334,9 @@ def __call__(self, *params, **named_params):
stub.allow_zero_invocations = True


def verification_has_lower_bound_of_zero(verification):
def verification_has_lower_bound_of_zero(
verification: verificationModule.VerificationMode | None
) -> bool:
if (
isinstance(verification, verificationModule.Times)
and verification.wanted_count == 0
Expand Down Expand Up @@ -372,7 +387,13 @@ class StubbedInvocation(MatchingInvocation):
there is no "new" keyword in Python.)
"""
def __init__(self, mock, method_name, verification=None, strict=None):
def __init__(
self,
mock: Mock,
method_name: str,
verification: verificationModule.VerificationMode | None = None,
strict: bool | None = None
) -> None:
super(StubbedInvocation, self).__init__(mock, method_name)

#: Holds the verification set up via `expect`.
Expand All @@ -391,26 +412,25 @@ def __init__(self, mock, method_name, verification=None, strict=None):

#: Set if `verifyStubbedInvocationsAreUsed` should pass, regardless
#: of any factual invocation. E.g. set by `expect(..., times=0)`
if verification_has_lower_bound_of_zero(verification):
self.allow_zero_invocations = True
else:
self.allow_zero_invocations = False

self.allow_zero_invocations: bool = \
verification_has_lower_bound_of_zero(verification)

def ensure_mocked_object_has_method(self, method_name):
def ensure_mocked_object_has_method(self, method_name: str) -> None:
if not self.mock.has_method(method_name):
raise InvocationError(
"You tried to stub a method '%s' the object (%s) doesn't "
"have." % (method_name, self.mock.mocked_obj))

def ensure_signature_matches(self, method_name, args, kwargs):
def ensure_signature_matches(
self, method_name: str, args: tuple, kwargs: dict
) -> None:
sig = self.mock.get_signature(method_name)
if not sig:
return

signature.match_signature_allowing_placeholders(sig, args, kwargs)

def __call__(self, *params, **named_params):
def __call__(self, *params: Any, **named_params: Any) -> AnswerSelector:
if self.strict:
self.ensure_mocked_object_has_method(self.method_name)
self.ensure_signature_matches(
Expand All @@ -421,13 +441,13 @@ def __call__(self, *params, **named_params):
self.mock.finish_stubbing(self)
return AnswerSelector(self)

def forget_self(self):
def forget_self(self) -> None:
self.mock.forget_stubbed_invocation(self)

def add_answer(self, answer):
def add_answer(self, answer: Callable) -> None:
self.answers.add(answer)

def answer_first(self, *args, **kwargs):
def answer_first(self, *args: Any, **kwargs: Any) -> Any:
self.used += 1
return self.answers.answer(*args, **kwargs)

Expand Down Expand Up @@ -466,54 +486,52 @@ def should_answer(self, invocation: RememberedInvocation) -> None:
# to get verified 'implicitly', on-the-go, so we set this flag here.
invocation.verified = True

def verify(self):
def verify(self) -> None:
if self.verification:
self.verification.verify(self, self.used)

def check_used(self):
def check_used(self) -> None:
if not self.allow_zero_invocations and self.used < len(self.answers):
raise verificationModule.VerificationError(
"\nUnused stub: %s" % self)


def return_(value):
def answer(*args, **kwargs):
def return_(value: T) -> Callable[..., T]:
def answer(*args, **kwargs) -> T:
return value
return answer

def raise_(exception):
def answer(*args, **kwargs):
def raise_(exception: Exception | type[Exception]) -> Callable[..., NoReturn]:
def answer(*args, **kwargs) -> NoReturn:
raise exception
return answer


def discard_self(function):
def function_without_self(*args, **kwargs):
def discard_self(function: Callable[..., T]) -> Callable[..., T]:
def function_without_self(*args, **kwargs) -> T:
args = args[1:]
return function(*args, **kwargs)

return function_without_self


class AnswerSelector(object):
def __init__(self, invocation):
def __init__(self, invocation: StubbedInvocation) -> None:
self.invocation = invocation
self.discard_first_arg = \
invocation.mock.eat_self(invocation.method_name)

def thenReturn(self, *return_values):
def thenReturn(self, *return_values: Any) -> Self:
for return_value in return_values or (None,):
answer = return_(return_value)
self.__then(answer)
return self

def thenRaise(self, *exceptions):
def thenRaise(self, *exceptions: Exception | type[Exception]) -> Self:
for exception in exceptions or (Exception,):
answer = raise_(exception)
self.__then(answer)
return self

def thenAnswer(self, *callables):
def thenAnswer(self, *callables: Callable) -> Self:
if not callables:
raise TypeError("No answer function provided")
for callable in callables:
Expand All @@ -523,7 +541,7 @@ def thenAnswer(self, *callables):
self.__then(answer)
return self

def thenCallOriginalImplementation(self):
def thenCallOriginalImplementation(self) -> Self:
answer = self.invocation.mock.get_original_method(
self.invocation.method_name
)
Expand All @@ -545,36 +563,36 @@ def thenCallOriginalImplementation(self):
self.__then(answer)
return self

def __then(self, answer):
def __then(self, answer: Callable) -> None:
self.invocation.add_answer(answer)

def __enter__(self):
def __enter__(self) -> None:
pass

def __exit__(self, *exc_info):
def __exit__(self, *exc_info) -> None:
self.invocation.verify()
if os.environ.get("MOCKITO_CONTEXT_MANAGERS_CHECK_USAGE", "1") == "1":
self.invocation.check_used()
self.invocation.forget_self()


class CompositeAnswer(object):
def __init__(self):
def __init__(self) -> None:
#: Container for answers, which are just ordinary callables
self.answers: Deque[Callable] = deque()
self.answers: deque[Callable] = deque()

#: Counter for the maximum answers we ever had
self.answer_count = 0

def __len__(self):
def __len__(self) -> int:
# The minimum is '1' bc we always have a default answer of 'None'
return max(1, self.answer_count)

def add(self, answer):
def add(self, answer: Callable) -> None:
self.answer_count += 1
self.answers.append(answer)

def answer(self, *args, **kwargs):
def answer(self, *args: Any, **kwargs: Any) -> Any:
if len(self.answers) == 0:
return None

Expand Down
15 changes: 10 additions & 5 deletions mockito/mock_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .mocking import Mock


class MockRegistry:
Expand All @@ -30,26 +35,26 @@ class MockRegistry:
def __init__(self):
self.mocks = IdentityMap()

def register(self, obj, mock):
def register(self, obj: object, mock: Mock) -> None:
self.mocks[obj] = mock

def mock_for(self, obj):
def mock_for(self, obj: object) -> Mock | None:
return self.mocks.get(obj, None)

def unstub(self, obj):
def unstub(self, obj: object) -> None:
try:
mock = self.mocks.pop(obj)
except KeyError:
pass
else:
mock.unstub()

def unstub_all(self):
def unstub_all(self) -> None:
for mock in self.get_registered_mocks():
mock.unstub()
self.mocks.clear()

def get_registered_mocks(self):
def get_registered_mocks(self) -> list[Mock]:
return self.mocks.values()


Expand Down
Loading

0 comments on commit ee80823

Please sign in to comment.