diff --git a/tests/io_expectation.py b/tests/io_expectation.py index 0b9b4ff..48a4896 100644 --- a/tests/io_expectation.py +++ b/tests/io_expectation.py @@ -57,7 +57,7 @@ import io import re import sys -from typing import Callable, Optional +from typing import Callable, List, Optional, Sequence, Set, Tuple, Union CAMEL_CASE_RE = re.compile(r'(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])') @@ -71,7 +71,15 @@ class MalformedExpectationError(Error): """Error raised for malformed expectations.""" -def default_expectation(expected_io): +Expectation = Union[ + 'ExpectBase', + List['Expectation'], + Tuple['Expectation', ...], + Set['Expectation'], + str] + + +def default_expectation(expected_io: Expectation) -> 'ExpectBase': """Defines the behavior of python standard types when used as expectation. This is used to allow syntactic sugar, where expectation can be specified @@ -100,7 +108,7 @@ def default_expectation(expected_io): class ExpectBase(): """Base class for all expected response string matchers.""" - def __init__(self): + def __init__(self) -> None: self._consumed = False self._fulfilled = False self._saturated = False @@ -210,7 +218,7 @@ def times(self, max_repetition = max_repetition or min_repetition return Repeatedly(self, min_repetition, max_repetition) - def apply_transform(self, callback: Callable[[str], str]): + def apply_transform(self, callback: Callable[[str], str]) -> None: """Apply a transformation on the expectation definition. The ExpectedInputOutput class calls this function on every expectations @@ -250,7 +258,7 @@ class MultiLine(ExpectBase): instance: MultiLine(Contains('foo\nbar')) """ - def __init__(self, expected): + def __init__(self, expected: Expectation) -> None: super().__init__() print(f'Created {type(self).__name__}') @@ -285,7 +293,7 @@ def description(self, saturated: bool) -> str: class ExpectStringBase(ExpectBase): """Base class for all string expectations.""" - def __init__(self, expected: str): + def __init__(self, expected: str) -> None: super().__init__() if not isinstance(expected, str): raise MalformedExpectationError( @@ -352,7 +360,7 @@ def test_consume(self, string: str) -> bool: class LogicalOpExpectationBase(ExpectBase): """Base class for And and Or expectations""" - def __init__(self, args): + def __init__(self, args: Sequence[Expectation]) -> None: super().__init__() if any(isinstance(a, MultiLine) for a in args): @@ -365,7 +373,7 @@ def __init__(self, args): class And(LogicalOpExpectationBase): """Matcher succeeding when all its sub matcher succeed.""" - def __init__(self, *args): + def __init__(self, *args: Expectation): super().__init__(args) @property @@ -398,29 +406,29 @@ def description(self, saturated: bool) -> str: class Or(LogicalOpExpectationBase): """Matcher succeeding when one of its sub matcher succeed.""" - def __init__(self, *args): + def __init__(self, *args: Expectation): super().__init__(args) @property - def fulfilled(self): + def fulfilled(self) -> bool: return any(e.fulfilled for e in self._expected_list) @property - def saturated(self): + def saturated(self) -> bool: return all(e.saturated for e in self._expected_list) - def consume(self, string): + def consume(self, string: str) -> bool: self._consumed = True return any(e.consume(string) for e in self._expected_list) - def test_consume(self, string): + def test_consume(self, string: str) -> bool: return any(e.test_consume(string) for e in self._expected_list) - def apply_transform(self, callback: Callable[[str], str]): + def apply_transform(self, callback: Callable[[str], str]) -> None: for expected in self._expected_list: expected.apply_transform(callback) - def description(self, saturated): + def description(self, saturated: bool) -> str: parts = [a.description(saturated) for a in self._expected_list if not a.consumed or a.saturated == saturated] if len(parts) == 1: @@ -431,7 +439,7 @@ def description(self, saturated): class Not(ExpectBase): """Matcher succeeding when it's sub-matcher fails.""" - def __init__(self, expected): + def __init__(self, expected: Expectation) -> None: super().__init__() if isinstance(expected, MultiLine): raise MalformedExpectationError( @@ -439,26 +447,29 @@ def __init__(self, expected): self._expected = default_expectation(expected) self._thrifty = True - def consume(self, string): + def consume(self, string: str) -> bool: self._consumed = True self._fulfilled = not self._expected.consume(string) self._saturated = not self._expected.saturated return self._fulfilled - def test_consume(self, string): + def test_consume(self, string: str) -> bool: return not self._expected.test_consume(string) def apply_transform(self, callback: Callable[[str], str]): self._expected.apply_transform(callback) - def description(self, saturated): + def description(self, saturated: bool) -> str: return f'Not({self._expected.description(not saturated)})' class Repeatedly(ExpectBase): """Wraps an expectation to make it repeat a given number of times.""" - def __init__(self, sub_expectation, min_repetition=0, max_repetition=None): + def __init__(self, + sub_expectation: Expectation, + min_repetition: int = 0, + max_repetition: Optional[int] = None): super().__init__() if isinstance(sub_expectation, MultiLine): raise MalformedExpectationError( @@ -471,15 +482,15 @@ def __init__(self, sub_expectation, min_repetition=0, max_repetition=None): self._thrifty = self._sub_expectation.thrifty @property - def fulfilled(self): + def fulfilled(self) -> bool: return self._current_repetition >= self._min_repetition @property - def saturated(self): + def saturated(self) -> bool: return (self._max_repetition is not None and self._current_repetition >= self._max_repetition) - def consume(self, string): + def consume(self, string: str) -> bool: self._consumed = True result = self._current_expectation.consume(string) if self._current_expectation.fulfilled: @@ -487,20 +498,20 @@ def consume(self, string): self._current_expectation = copy.deepcopy(self._sub_expectation) return result - def test_consume(self, string): + def test_consume(self, string: str) -> bool: return self._current_expectation.test_consume(string) - def produce(self): + def produce(self) -> Optional[str]: result = self._current_expectation.produce() if self._current_expectation.saturated: self._current_repetition += 1 self._current_expectation = copy.deepcopy(self._sub_expectation) return result - def apply_transform(self, callback: Callable[[str], str]): + def apply_transform(self, callback: Callable[[str], str]) -> None: self._sub_expectation.apply_transform(callback) - def description(self, saturated): + def description(self, saturated: bool) -> str: arg1 = max(self._min_repetition - self._current_repetition, 0) arg2 = (self._max_repetition - self._current_repetition if self._max_repetition is not None else None) @@ -514,7 +525,7 @@ def description(self, saturated): class ExpectSequenceBase(ExpectBase): """Base class for all sequence-based expectations.""" - def __init__(self, *args): + def __init__(self, *args: Expectation) -> None: super().__init__() if any(isinstance(a, MultiLine) for a in args): raise MalformedExpectationError( @@ -522,18 +533,18 @@ def __init__(self, *args): self._expected_list = [default_expectation(expected) for expected in args] @property - def fulfilled(self): + def fulfilled(self) -> bool: return all(e.fulfilled for e in self._expected_list) @property - def saturated(self): + def saturated(self) -> bool: return all(e.saturated for e in self._expected_list) - def apply_transform(self, callback: Callable[[str], str]): + def apply_transform(self, callback: Callable[[str], str]) -> None: for expected in self._expected_list: expected.apply_transform(callback) - def description(self, saturated): + def description(self, saturated: bool) -> str: parts = [a.description(saturated) for a in self._expected_list if not a.consumed or a.saturated == saturated] if len(parts) == 1: @@ -546,7 +557,7 @@ def description(self, saturated): class InOrder(ExpectSequenceBase): """Sequence of expectations that must match in right order.""" - def consume(self, string): + def consume(self, string: str) -> bool: self._consumed = True to_consume = None for i, expected in enumerate(self._expected_list): @@ -575,7 +586,7 @@ def consume(self, string): return False - def test_consume(self, string): + def test_consume(self, string: str) -> bool: for expected in self._expected_list: if expected.test_consume(string): return True @@ -583,7 +594,7 @@ def test_consume(self, string): return False return False - def produce(self): + def produce(self) -> Optional[str]: for i, expected in enumerate(self._expected_list): result = expected.produce() if result: @@ -601,7 +612,7 @@ def produce(self): class AnyOrder(ExpectSequenceBase): """Sequence of expectation that can match in any order.""" - def consume(self, string): + def consume(self, string: str) -> bool: self._consumed = True to_consume = None for expected in self._expected_list: @@ -619,11 +630,11 @@ def consume(self, string): return False - def test_consume(self, string): + def test_consume(self, string: str) -> bool: return any(expected.test_consume(string) for expected in self._expected_list) - def produce(self): + def produce(self) -> Optional[str]: for expected in self._expected_list: result = expected.produce() if result: @@ -633,7 +644,7 @@ def produce(self): return None -def somewhere(expectation): +def somewhere(expectation: Expectation) -> ExpectBase: """Match an expectation anywhere in a document.""" return InOrder(Anything().repeatedly(), expectation, @@ -643,44 +654,44 @@ def somewhere(expectation): class Url(ExpectBase): """Matches a URL. This matcher won't replace '/' to '\' on Windows.""" - def __init__(self, sub_expectation): + def __init__(self, sub_expectation: Expectation) -> None: super().__init__() if isinstance(sub_expectation, MultiLine): raise MalformedExpectationError( 'MultiLine cannot be used as a sub-expectation.') self._expected = default_expectation(sub_expectation) - def consume(self, string): + def consume(self, string: str) -> bool: self._consumed = True self._fulfilled = self._expected.consume(string) self._saturated = self._expected.saturated return self._fulfilled - def test_consume(self, string): + def test_consume(self, string: str) -> bool: return self._expected.test_consume(string) def apply_transform(self, callback: Callable[[str], str]) -> None: del callback # Unused. - def description(self, saturated): + def description(self, saturated: bool) -> str: return f'Url({self._expected.description(saturated)})' class Reply(ExpectBase): """Expects a read to the input pipe and replies with a specific string.""" - def __init__(self, reply_string): + def __init__(self, reply_string: str) -> None: super().__init__() self._reply_string = reply_string - def produce(self): + def produce(self) -> Optional[str]: self._fulfilled = self._saturated = True return self._reply_string - def _consume(self, line): - raise AssertionError(f'Expecting user input but got output line: {line}') + def consume(self, string: str) -> bool: + raise AssertionError(f'Expecting user input but got output line: {string}') - def description(self, saturated): + def description(self, saturated: bool) -> str: del saturated # Unused. name = type(self).__name__ args = self._reply_string @@ -690,7 +701,7 @@ def description(self, saturated): class ExpectedInputOutput(): """File object for overriding stdin/out, mocking inputs & checking outputs.""" - def __init__(self): + def __init__(self) -> None: self.set_transform_fn(None) self.set_expected_io(None) self._original_stdin = sys.stdin @@ -705,7 +716,8 @@ def close(self) -> None: self._expected_io = None self._cmd_output = io.StringIO() - def set_transform_fn(self, transform_fn: Optional[Callable[[str], str]]): + def set_transform_fn(self, + transform_fn: Optional[Callable[[str], str]]) -> None: """Callback to transform all expectations passed in set_expected_io. Useful for 'patching' all expectations with the same transformation. A @@ -719,7 +731,7 @@ def set_transform_fn(self, transform_fn: Optional[Callable[[str], str]]): """ self._transform_fn = transform_fn - def set_expected_io(self, expected_io): + def set_expected_io(self, expected_io: Optional[Expectation]) -> None: """Set an expectation for the next sequence of IOs. The expected IO can be specified as an instance of an ExpectBase child @@ -736,7 +748,7 @@ def set_expected_io(self, expected_io): self._expected_io = self._patch_expected_io(expected_io) self._cmd_output = io.StringIO() - def write(self, string): + def write(self, string: str) -> None: """File object 'write' method, matched against the next expected output. Args: @@ -745,11 +757,11 @@ def write(self, string): self._original_stdout.write(string) self._cmd_output.write(string) - def flush(self): + def flush(self) -> None: """File object 'flush' method.""" self._original_stdout.flush() - def readline(self): + def readline(self) -> str: """File object 'readline' method, replied using the next expected input. Returns: @@ -791,7 +803,7 @@ def assert_expectations_fulfilled(self) -> io.StringIO: self.set_expected_io(None) return cmd_output - def assert_output_was(self, expected_output): + def assert_output_was(self, expected_output: Expectation) -> None: """Asserts that the previous outputs matched the specified expectation. Args: @@ -804,7 +816,8 @@ def assert_output_was(self, expected_output): self._expected_io = self._patch_expected_io(expected_output) self.assert_expectations_fulfilled() - def _patch_expected_io(self, expected_io): + def _patch_expected_io( + self, expected_io: Optional[Expectation]) -> Optional[ExpectBase]: """Patch the specified expectation, applying defaults and transforms. Args: @@ -814,14 +827,16 @@ def _patch_expected_io(self, expected_io): Returns: Instance of an ExpectBase subclass. """ - patched_expected_io = default_expectation(expected_io) + patched_expected_io = ( + default_expectation(expected_io) + if expected_io is not None else None) if patched_expected_io and self._transform_fn: patched_expected_io.apply_transform(self._transform_fn) return patched_expected_io - def _match_pending_outputs(self): + def _match_pending_outputs(self) -> None: """Match any pending IO against the expectations. Raises: