diff --git a/.travis.yml b/.travis.yml index 4df453d6..216a988a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -27,8 +27,8 @@ script: - naz-cli --version && naz-cli --help - naz-cli --config examples/example_config.json --dry-run - coverage erase - - export CI_ENVIRONMENT=Yes && coverage run --omit="*tests*,*cli/test_*,*examples/*,*.virtualenvs/*,*virtualenv/*,*.venv/*,*__init__*" -m unittest discover -v -s . && bash <(curl -s https://codecov.io/bash) - - coverage report --show-missing --fail-under=83 + - export CI_ENVIRONMENT=Yes && coverage run --omit="*tests*,*examples/*,*.virtualenvs/*,*virtualenv/*,*.venv/*,*__init__*" -m unittest discover -v -s . && bash <(curl -s https://codecov.io/bash) + - coverage report --show-missing --fail-under=84 - | git remote set-branches --add origin master # https://github.com/travis-ci/travis-ci/issues/6069 git fetch diff --git a/CHANGELOG.md b/CHANGELOG.md index 55aba126..2296d200 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ most recent version is listed first. - cleanly handle termination signals like `SIGTERM`: https://github.com/komuw/naz/pull/106 - validate `naz.Client` arguments: https://github.com/komuw/naz/pull/108 - remove ability to bring your own eventloop: https://github.com/komuw/naz/pull/111 +- make `naz` more fault tolerant: https://github.com/komuw/naz/pull/113 + - `naz` now has a configurable timeout when trying to connect to SMSC + - `naz` will now be able to detect when the connection to SMSC is disconnected and will attempt to re-connect & re-bind + - bugfix; `asyncio.streams.StreamWriter.drain` should not be called concurrently by multiple coroutines + - when shutting down, `naz` now tries to make sure that write buffers are properly flushed. ## **version:** v0.6.0-beta.1 - Bug fix: https://github.com/komuw/naz/pull/98 diff --git a/Makefile b/Makefile index be0a69b7..91fa818f 100644 --- a/Makefile +++ b/Makefile @@ -30,9 +30,9 @@ test: @export PYTHONASYNCIODEBUG='2' @printf "\n removing pyc files::\n" && find . -name '*.pyc' -delete;find . -name '__pycache__' -delete | xargs echo @printf "\n coverage erase::\n" && coverage erase - @printf "\n coverage run::\n" && coverage run --omit="*tests*,*cli/test_*,*examples/*,*.virtualenvs/*,*virtualenv/*,*.venv/*,*__init__*" -m unittest discover -v -s . - @printf "\n coverage report::\n" && coverage report --show-missing --fail-under=83 - @printf "\n coverage report html::\n" && coverage html --fail-under=83 --title=naz_coverage + @printf "\n coverage run::\n" && coverage run --omit="*tests*,*examples/*,*.virtualenvs/*,*virtualenv/*,*.venv/*,*__init__*" -m unittest discover -v -s . + @printf "\n coverage report::\n" && coverage report --show-missing --fail-under=84 + @printf "\n coverage report html::\n" && coverage html --fail-under=84 --title=naz_coverage @printf "\n run flake8::\n" && flake8 . @printf "\n run pylint::\n" && pylint --enable=E --disable=W,R,C examples/ naz/ tests/ cli/ documentation/ @printf "\n run bandit::\n" && bandit -r --exclude .venv -ll . diff --git a/README.md b/README.md index 9409f06b..01bba5ca 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ loop.run_until_complete(cli.tranceiver_bind()) try: # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC - tasks = asyncio.gather(cli.send_forever(), cli.receive_data(), cli.enquire_link()) + tasks = asyncio.gather(cli.dequeue_messages(), cli.receive_data(), cli.enquire_link()) loop.run_until_complete(tasks) loop.run_forever() except Exception as e: @@ -138,7 +138,7 @@ run: {'event': 'naz.SimpleHook.request', 'stage': 'start', 'smpp_command': 'bind_transceiver', 'log_id': None, 'environment': 'production', 'release': 'canary', 'smsc_host': '127.0.0.1', 'system_id': 'smppclient1', 'client_id': '2VU55VT86KHWXTW7X'} {'event': 'naz.Client.send_data', 'stage': 'end', 'smpp_command': 'bind_transceiver', 'log_id': None, 'msg': 'hello', 'environment': 'production', 'release': 'canary', 'smsc_host': '127.0.0.1', 'system_id': 'smppclient1', 'client_id': '2VU55VT86KHWXTW7X'} {'event': 'naz.Client.tranceiver_bind', 'stage': 'end', 'environment': 'production', 'release': 'canary', 'smsc_host': '127.0.0.1', 'system_id': 'smppclient1', 'client_id': '2VU55VT86KHWXTW7X'} -{'event': 'naz.Client.send_forever', 'stage': 'start', 'environment': 'production', 'release': 'canary', 'smsc_host': '127.0.0.1', 'system_id': 'smppclient1', 'client_id': '2VU55VT86KHWXTW7X'} +{'event': 'naz.Client.dequeue_messages', 'stage': 'start', 'environment': 'production', 'release': 'canary', 'smsc_host': '127.0.0.1', 'system_id': 'smppclient1', 'client_id': '2VU55VT86KHWXTW7X'} ``` **NB:** @@ -331,7 +331,7 @@ loop.run_until_complete(cli.tranceiver_bind()) try: # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC - tasks = asyncio.gather(cli.send_forever(), cli.receive_data(), cli.enquire_link()) + tasks = asyncio.gather(cli.dequeue_messages(), cli.receive_data(), cli.enquire_link()) loop.run_until_complete(tasks) loop.run_forever() except Exception as e: @@ -398,7 +398,7 @@ reader, writer = loop.run_until_complete(cli.connect()) loop.run_until_complete(cli.tranceiver_bind()) try: # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC - tasks = asyncio.gather(cli.send_forever(), cli.receive_data(), cli.enquire_link()) + tasks = asyncio.gather(cli.dequeue_messages(), cli.receive_data(), cli.enquire_link()) loop.run_until_complete(tasks) loop.run_forever() except Exception as e: diff --git a/cli/cli.py b/cli/cli.py index 4077be51..1e770e6c 100644 --- a/cli/cli.py +++ b/cli/cli.py @@ -212,21 +212,10 @@ def main(): return # call naz api - cli = naz.Client(**kwargs) - # connect to the SMSC host - _, _ = loop.run_until_complete(cli.connect()) - # bind to SMSC as a tranceiver - loop.run_until_complete(cli.tranceiver_bind()) - - # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC - tasks = asyncio.gather( - cli.send_forever(TESTING=dry_run), - cli.receive_data(TESTING=dry_run), - cli.enquire_link(TESTING=dry_run), - sig._signal_handling(logger=logger, client=cli, loop=loop), - loop=loop, + client = naz.Client(**kwargs) + loop.run_until_complete( + async_main(client=client, logger=logger, loop=loop, dry_run=dry_run) ) - loop.run_until_complete(tasks) except Exception as e: logger.log(logging.ERROR, {"event": "naz.cli.main", "stage": "end", "error": str(e)}) sys.exit(77) @@ -234,5 +223,26 @@ def main(): logger.log(logging.INFO, {"event": "naz.cli.main", "stage": "end"}) +async def async_main( + client: naz.Client, + logger: naz.logger.SimpleLogger, + loop: asyncio.events.AbstractEventLoop, + dry_run: bool, +): + # connect & bind to the SMSC host + await client.connect() + await client.tranceiver_bind() + + # send any queued messages to SMSC, read any data from SMSC and continually check the state of the SMSC + tasks = asyncio.gather( + client.dequeue_messages(TESTING=dry_run), + client.receive_data(TESTING=dry_run), + client.enquire_link(TESTING=dry_run), + sig._signal_handling(logger=logger, client=client, loop=loop), + loop=loop, + ) + await tasks + + if __name__ == "__main__": main() diff --git a/documentation/config.md b/documentation/config.md index 27e5b523..fc09b6c5 100644 --- a/documentation/config.md +++ b/documentation/config.md @@ -53,12 +53,13 @@ loglevel | the level at which to log | INFO log_metadata | metadata that will be included in all log statements | {"smsc_host": smsc_host, "system_id": system_id} codec_class | python class instance to be used to encode/decode messages | naz.nazcodec.SimpleNazCodec codec_errors_level | same meaning as the `errors` argument to pythons' `encode` method as [defined here](https://docs.python.org/3/library/codecs.html#codecs.encode) | strict -enquire_link_interval | time in seconds to wait before sending an `enquire_link` request to SMSC to check on its status | 90 +enquire_link_interval | time in seconds to wait before sending an `enquire_link` request to SMSC to check on its status | 55.0 rateLimiter | python class instance implementing rate limitation | naz.ratelimiter.SimpleRateLimiter hook | python class instance implemeting functionality/hooks to be called by `naz` just before sending request to SMSC and just after getting response from SMSC | naz.hooks.SimpleHook throttle_handler | python class instance implementing functionality of what todo when naz starts getting throttled responses from SMSC | naz.throttle.SimpleThrottleHandler correlation_handler | A python class instance that naz uses to store relations between SMPP sequence numbers and user applications' log_id's and/or hook_metadata. | naz.correlater.SimpleCorrelater -drain_duration | duration in seconds that `naz` will wait for after receiving a termination signal. | 8.00 +drain_duration | duration in seconds that `naz` will wait for after receiving a termination signal. | 8.00 +connect_timeout | duration that `naz` will try to connect to SMSC before timing out | 30.00 `SMSC`: Short Message Service Centre, ie the server `ESME`: External Short Message Entity, ie the client @@ -84,7 +85,7 @@ drain_duration | duration in seconds that `naz` will wait for after receiving a "release": "canary" }, "codec_errors_level": "ignore", - "enquire_link_interval": 30, + "enquire_link_interval": 30.0, "rateLimiter": "dotted.path.to.CustomRateLimiter" } ``` diff --git a/documentation/sphinx-docs/introduction.rst b/documentation/sphinx-docs/introduction.rst index befdfc8f..da4f2164 100644 --- a/documentation/sphinx-docs/introduction.rst +++ b/documentation/sphinx-docs/introduction.rst @@ -63,7 +63,7 @@ naz is in active development and it's API may change in backward incompatible wa try: # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC - tasks = asyncio.gather(cli.send_forever(), cli.receive_data(), cli.enquire_link()) + tasks = asyncio.gather(cli.dequeue_messages(), cli.receive_data(), cli.enquire_link()) loop.run_until_complete(tasks) loop.run_forever() except Exception as e: @@ -373,7 +373,7 @@ An example of using that queue; try: # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC - tasks = asyncio.gather(cli.send_forever(), cli.receive_data(), cli.enquire_link()) + tasks = asyncio.gather(cli.dequeue_messages(), cli.receive_data(), cli.enquire_link()) loop.run_until_complete(tasks) loop.run_forever() except Exception as e: diff --git a/examples/example_config.json b/examples/example_config.json index aff7be4f..2e461598 100644 --- a/examples/example_config.json +++ b/examples/example_config.json @@ -12,6 +12,6 @@ "release": "canary" }, "codec_errors_level": "ignore", - "enquire_link_interval": 70, + "enquire_link_interval": 70.00, "rateLimiter": "examples.example_klasses.ExampleRateLimiter" } \ No newline at end of file diff --git a/examples/in_mem_queue_example.py b/examples/in_mem_queue_example.py index 2f33a53d..c549fc6b 100644 --- a/examples/in_mem_queue_example.py +++ b/examples/in_mem_queue_example.py @@ -43,7 +43,7 @@ try: # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC - tasks = asyncio.gather(cli.send_forever(), cli.receive_data(), cli.enquire_link()) + tasks = asyncio.gather(cli.dequeue_messages(), cli.receive_data(), cli.enquire_link()) loop.run_until_complete(tasks) loop.run_forever() except Exception as e: diff --git a/examples/rabbitmq_queue_example.py b/examples/rabbitmq_queue_example.py index 7ccf407b..ed11f9c3 100644 --- a/examples/rabbitmq_queue_example.py +++ b/examples/rabbitmq_queue_example.py @@ -128,7 +128,7 @@ def blocking_dequeue(self): system_id="smppclient1", password=os.getenv("password", "password"), outboundqueue=outboundqueue, - enquire_link_interval=17, + enquire_link_interval=17.00, ) item_to_enqueue = { @@ -148,7 +148,7 @@ def blocking_dequeue(self): try: # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC - tasks = asyncio.gather(cli.send_forever(), cli.receive_data(), cli.enquire_link()) + tasks = asyncio.gather(cli.dequeue_messages(), cli.receive_data(), cli.enquire_link()) loop.run_until_complete(tasks) loop.run_forever() except Exception as e: diff --git a/examples/redis_queue_example.py b/examples/redis_queue_example.py index 2bd6a599..bf66bee0 100644 --- a/examples/redis_queue_example.py +++ b/examples/redis_queue_example.py @@ -92,7 +92,7 @@ def blocking_dequeue(self): try: # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC - tasks = asyncio.gather(cli.send_forever(), cli.receive_data(), cli.enquire_link()) + tasks = asyncio.gather(cli.dequeue_messages(), cli.receive_data(), cli.enquire_link()) loop.run_until_complete(tasks) loop.run_forever() except Exception as e: diff --git a/naz/client.py b/naz/client.py index ce5088ed..e432de18 100644 --- a/naz/client.py +++ b/naz/client.py @@ -47,7 +47,7 @@ class Client: # read any data from SMSC, send any queued messages to SMSC and continually check the state of the SMSC tasks = asyncio.gather( - client.send_forever(), + client.dequeue_messages(), client.receive_data(), client.enquire_link(), loop=loop, @@ -90,7 +90,7 @@ def __init__( registered_delivery: int = 0b00000001, # see section 5.2.17 replace_if_present_flag: int = 0x00000000, sm_default_msg_id: int = 0x00000000, - enquire_link_interval: int = 300, + enquire_link_interval: float = 55.00, log_handler: typing.Union[None, logger.BaseLogger] = None, loglevel: str = "INFO", log_metadata: typing.Union[None, dict] = None, @@ -102,6 +102,9 @@ def __init__( throttle_handler: typing.Union[None, throttle.BaseThrottleHandler] = None, correlation_handler: typing.Union[None, correlater.BaseCorrelater] = None, drain_duration: float = 8.00, + # connect_timeout value inspired by vumi + # https://github.com/praekeltfoundation/vumi/blob/02518583774bcb4db5472aead02df617e1725997/vumi/transports/smpp/config.py#L124 + connect_timeout: float = 30.0, ) -> None: """ Parameters: @@ -184,6 +187,7 @@ def __init__( throttle_handler=throttle_handler, correlation_handler=correlation_handler, drain_duration=drain_duration, + connect_timeout=connect_timeout, ) self._PID = os.getpid() @@ -222,7 +226,7 @@ def __init__( "smsc_host": self.smsc_host, "system_id": system_id, "client_id": self.client_id, - "process_id": self._PID, + "pid": self._PID, } ) @@ -264,8 +268,8 @@ def __init__( self.data_coding = self._find_data_coding(self.encoding) - self.reader: typing.Any = None - self.writer: typing.Any = None + self.reader: typing.Union[None, asyncio.streams.StreamReader] = None + self.writer: typing.Union[None, asyncio.streams.StreamWriter] = None if log_handler is not None: self.logger = log_handler @@ -306,7 +310,9 @@ def __init__( self.current_session_state = SmppSessionState.CLOSED self.drain_duration = drain_duration + self.connect_timeout = connect_timeout self.SHOULD_SHUT_DOWN: bool = False + self.drain_lock: asyncio.Lock = asyncio.Lock() @staticmethod def _validate_client_args( @@ -335,7 +341,7 @@ def _validate_client_args( registered_delivery: int, replace_if_present_flag: int, sm_default_msg_id: int, - enquire_link_interval: int, + enquire_link_interval: float, log_handler: typing.Union[None, logger.BaseLogger], loglevel: str, log_metadata: typing.Union[None, dict], @@ -347,213 +353,309 @@ def _validate_client_args( throttle_handler: typing.Union[None, throttle.BaseThrottleHandler], correlation_handler: typing.Union[None, correlater.BaseCorrelater], drain_duration: float, + connect_timeout: float, ) -> None: + """ + Checks that the arguments to `naz.Client` are okay. + It raises an Exception that comprises of a list of Exceptions + """ + errors: typing.List[ValueError] = [] if not isinstance(smsc_host, str): - raise ValueError( - "`smsc_host` should be of type:: `str` You entered: {0}".format(type(smsc_host)) + errors.append( + ValueError( + "`smsc_host` should be of type:: `str` You entered: {0}".format(type(smsc_host)) + ) ) if not isinstance(smsc_port, int): - raise ValueError( - "`smsc_port` should be of type:: `int` You entered: {0}".format(type(smsc_port)) + errors.append( + ValueError( + "`smsc_port` should be of type:: `int` You entered: {0}".format(type(smsc_port)) + ) ) if not isinstance(system_id, str): - raise ValueError( - "`system_id` should be of type:: `str` You entered: {0}".format(type(system_id)) + errors.append( + ValueError( + "`system_id` should be of type:: `str` You entered: {0}".format(type(system_id)) + ) ) if not isinstance(password, str): - raise ValueError( - "`password` should be of type:: `str` You entered: {0}".format(type(password)) + errors.append( + ValueError( + "`password` should be of type:: `str` You entered: {0}".format(type(password)) + ) ) if not isinstance(outboundqueue, q.BaseOutboundQueue): - raise ValueError( - "`outboundqueue` should be of type:: `naz.q.BaseOutboundQueue` You entered: {0}".format( - type(outboundqueue) + errors.append( + ValueError( + "`outboundqueue` should be of type:: `naz.q.BaseOutboundQueue` You entered: {0}".format( + type(outboundqueue) + ) ) ) if not isinstance(client_id, (type(None), str)): - raise ValueError( - "`client_id` should be of type:: `None` or `str` You entered: {0}".format( - type(client_id) + errors.append( + ValueError( + "`client_id` should be of type:: `None` or `str` You entered: {0}".format( + type(client_id) + ) ) ) if not isinstance(system_type, str): - raise ValueError( - "`system_type` should be of type:: `str` You entered: {0}".format(type(system_type)) + errors.append( + ValueError( + "`system_type` should be of type:: `str` You entered: {0}".format( + type(system_type) + ) + ) ) if not isinstance(addr_ton, int): - raise ValueError( - "`addr_ton` should be of type:: `int` You entered: {0}".format(type(addr_ton)) + errors.append( + ValueError( + "`addr_ton` should be of type:: `int` You entered: {0}".format(type(addr_ton)) + ) ) if not isinstance(addr_npi, int): - raise ValueError( - "`addr_npi` should be of type:: `int` You entered: {0}".format(type(addr_npi)) + errors.append( + ValueError( + "`addr_npi` should be of type:: `int` You entered: {0}".format(type(addr_npi)) + ) ) if not isinstance(address_range, str): - raise ValueError( - "`address_range` should be of type:: `str` You entered: {0}".format( - type(address_range) + errors.append( + ValueError( + "`address_range` should be of type:: `str` You entered: {0}".format( + type(address_range) + ) ) ) if not isinstance(encoding, str): - raise ValueError( - "`encoding` should be of type:: `str` You entered: {0}".format(type(encoding)) + errors.append( + ValueError( + "`encoding` should be of type:: `str` You entered: {0}".format(type(encoding)) + ) ) if not isinstance(interface_version, int): - raise ValueError( - "`interface_version` should be of type:: `int` You entered: {0}".format( - type(interface_version) + errors.append( + ValueError( + "`interface_version` should be of type:: `int` You entered: {0}".format( + type(interface_version) + ) ) ) if not isinstance(service_type, str): - raise ValueError( - "`service_type` should be of type:: `str` You entered: {0}".format( - type(service_type) + errors.append( + ValueError( + "`service_type` should be of type:: `str` You entered: {0}".format( + type(service_type) + ) ) ) if not isinstance(source_addr_ton, int): - raise ValueError( - "`source_addr_ton` should be of type:: `int` You entered: {0}".format( - type(source_addr_ton) + errors.append( + ValueError( + "`source_addr_ton` should be of type:: `int` You entered: {0}".format( + type(source_addr_ton) + ) ) ) if not isinstance(source_addr_npi, int): - raise ValueError( - "`source_addr_npi` should be of type:: `int` You entered: {0}".format( - type(source_addr_npi) + errors.append( + ValueError( + "`source_addr_npi` should be of type:: `int` You entered: {0}".format( + type(source_addr_npi) + ) ) ) if not isinstance(dest_addr_ton, int): - raise ValueError( - "`dest_addr_ton` should be of type:: `int` You entered: {0}".format( - type(dest_addr_ton) + errors.append( + ValueError( + "`dest_addr_ton` should be of type:: `int` You entered: {0}".format( + type(dest_addr_ton) + ) ) ) if not isinstance(dest_addr_npi, int): - raise ValueError( - "`dest_addr_npi` should be of type:: `int` You entered: {0}".format( - type(dest_addr_npi) + errors.append( + ValueError( + "`dest_addr_npi` should be of type:: `int` You entered: {0}".format( + type(dest_addr_npi) + ) ) ) if not isinstance(esm_class, int): - raise ValueError( - "`esm_class` should be of type:: `int` You entered: {0}".format(type(esm_class)) + errors.append( + ValueError( + "`esm_class` should be of type:: `int` You entered: {0}".format(type(esm_class)) + ) ) if not isinstance(protocol_id, int): - raise ValueError( - "`protocol_id` should be of type:: `int` You entered: {0}".format(type(protocol_id)) + errors.append( + ValueError( + "`protocol_id` should be of type:: `int` You entered: {0}".format( + type(protocol_id) + ) + ) ) if not isinstance(priority_flag, int): - raise ValueError( - "`priority_flag` should be of type:: `int` You entered: {0}".format( - type(priority_flag) + errors.append( + ValueError( + "`priority_flag` should be of type:: `int` You entered: {0}".format( + type(priority_flag) + ) ) ) if not isinstance(schedule_delivery_time, str): - raise ValueError( - "`schedule_delivery_time` should be of type:: `str` You entered: {0}".format( - type(schedule_delivery_time) + errors.append( + ValueError( + "`schedule_delivery_time` should be of type:: `str` You entered: {0}".format( + type(schedule_delivery_time) + ) ) ) if not isinstance(validity_period, str): - raise ValueError( - "`validity_period` should be of type:: `str` You entered: {0}".format( - type(validity_period) + errors.append( + ValueError( + "`validity_period` should be of type:: `str` You entered: {0}".format( + type(validity_period) + ) ) ) if not isinstance(registered_delivery, int): - raise ValueError( - "`registered_delivery` should be of type:: `int` You entered: {0}".format( - type(registered_delivery) + errors.append( + ValueError( + "`registered_delivery` should be of type:: `int` You entered: {0}".format( + type(registered_delivery) + ) ) ) if not isinstance(replace_if_present_flag, int): - raise ValueError( - "`replace_if_present_flag` should be of type:: `int` You entered: {0}".format( - type(replace_if_present_flag) + errors.append( + ValueError( + "`replace_if_present_flag` should be of type:: `int` You entered: {0}".format( + type(replace_if_present_flag) + ) ) ) if not isinstance(sm_default_msg_id, int): - raise ValueError( - "`sm_default_msg_id` should be of type:: `int` You entered: {0}".format( - type(sm_default_msg_id) + errors.append( + ValueError( + "`sm_default_msg_id` should be of type:: `int` You entered: {0}".format( + type(sm_default_msg_id) + ) ) ) - if not isinstance(enquire_link_interval, int): - raise ValueError( - "`enquire_link_interval` should be of type:: `int` You entered: {0}".format( - type(enquire_link_interval) + if not isinstance(enquire_link_interval, float): + errors.append( + ValueError( + "`enquire_link_interval` should be of type:: `float` You entered: {0}".format( + type(enquire_link_interval) + ) ) ) if not isinstance(log_handler, (type(None), logger.BaseLogger)): - raise ValueError( - "`log_handler` should be of type:: `None` or `naz.logger.BaseLogger` You entered: {0}".format( - type(log_handler) + errors.append( + ValueError( + "`log_handler` should be of type:: `None` or `naz.logger.BaseLogger` You entered: {0}".format( + type(log_handler) + ) ) ) if not isinstance(loglevel, str): - raise ValueError( - "`loglevel` should be of type:: `str` You entered: {0}".format(type(loglevel)) + errors.append( + ValueError( + "`loglevel` should be of type:: `str` You entered: {0}".format(type(loglevel)) + ) ) if loglevel.upper() not in ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: - raise ValueError( - """`loglevel` should be one of; 'NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR' or 'CRITICAL'. You entered: {0}""".format( - loglevel + errors.append( + ValueError( + """`loglevel` should be one of; 'NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR' or 'CRITICAL'. You entered: {0}""".format( + loglevel + ) ) ) if not isinstance(log_metadata, (type(None), dict)): - raise ValueError( - "`log_metadata` should be of type:: `None` or `dict` You entered: {0}".format( - type(log_metadata) + errors.append( + ValueError( + "`log_metadata` should be of type:: `None` or `dict` You entered: {0}".format( + type(log_metadata) + ) ) ) if not isinstance(codec_class, (type(None), nazcodec.BaseNazCodec)): - raise ValueError( - "`codec_class` should be of type:: `None` or `naz.nazcodec.BaseNazCodec` You entered: {0}".format( - type(codec_class) + errors.append( + ValueError( + "`codec_class` should be of type:: `None` or `naz.nazcodec.BaseNazCodec` You entered: {0}".format( + type(codec_class) + ) ) ) if not isinstance(codec_errors_level, str): - raise ValueError( - "`codec_errors_level` should be of type:: `str` You entered: {0}".format( - type(codec_errors_level) + errors.append( + ValueError( + "`codec_errors_level` should be of type:: `str` You entered: {0}".format( + type(codec_errors_level) + ) ) ) if not isinstance(rateLimiter, (type(None), ratelimiter.BaseRateLimiter)): - raise ValueError( - "`rateLimiter` should be of type:: `None` or `naz.ratelimiter.BaseRateLimiter` You entered: {0}".format( - type(rateLimiter) + errors.append( + ValueError( + "`rateLimiter` should be of type:: `None` or `naz.ratelimiter.BaseRateLimiter` You entered: {0}".format( + type(rateLimiter) + ) ) ) if not isinstance(hook, (type(None), hooks.BaseHook)): - raise ValueError( - "`hook` should be of type:: `None` or `naz.hooks.BaseHook` You entered: {0}".format( - type(hook) + errors.append( + ValueError( + "`hook` should be of type:: `None` or `naz.hooks.BaseHook` You entered: {0}".format( + type(hook) + ) ) ) if not isinstance(sequence_generator, (type(None), sequence.BaseSequenceGenerator)): - raise ValueError( - "`sequence_generator` should be of type:: `None` or `naz.sequence.BaseSequenceGenerator` You entered: {0}".format( - type(sequence_generator) + errors.append( + ValueError( + "`sequence_generator` should be of type:: `None` or `naz.sequence.BaseSequenceGenerator` You entered: {0}".format( + type(sequence_generator) + ) ) ) if not isinstance(throttle_handler, (type(None), throttle.BaseThrottleHandler)): - raise ValueError( - "`throttle_handler` should be of type:: `None` or `naz.throttle.BaseThrottleHandler` You entered: {0}".format( - type(throttle_handler) + errors.append( + ValueError( + "`throttle_handler` should be of type:: `None` or `naz.throttle.BaseThrottleHandler` You entered: {0}".format( + type(throttle_handler) + ) ) ) if not isinstance(correlation_handler, (type(None), correlater.BaseCorrelater)): - raise ValueError( - "`correlation_handler` should be of type:: `None` or `naz.correlater.BaseCorrelater` You entered: {0}".format( - type(correlation_handler) + errors.append( + ValueError( + "`correlation_handler` should be of type:: `None` or `naz.correlater.BaseCorrelater` You entered: {0}".format( + type(correlation_handler) + ) ) ) if not isinstance(drain_duration, float): - raise ValueError( - "`drain_duration` should be of type:: `float` You entered: {0}".format( - type(drain_duration) + errors.append( + ValueError( + "`drain_duration` should be of type:: `float` You entered: {0}".format( + type(drain_duration) + ) ) ) + if not isinstance(connect_timeout, float): + errors.append( + ValueError( + "`connect_timeout` should be of type:: `float` You entered: {0}".format( + type(connect_timeout) + ) + ) + ) + if len(errors): + raise NazClientError(errors) def _sanity_check_logger(self): """ @@ -619,7 +721,9 @@ async def connect( make a network connection to SMSC server. """ self._log(logging.INFO, {"event": "naz.Client.connect", "stage": "start"}) - reader, writer = await asyncio.open_connection(self.smsc_host, self.smsc_port) + reader, writer = await asyncio.wait_for( + asyncio.open_connection(self.smsc_host, self.smsc_port), timeout=self.connect_timeout + ) self.reader: asyncio.streams.StreamReader = reader self.writer: asyncio.streams.StreamWriter = writer self._log(logging.INFO, {"event": "naz.Client.connect", "stage": "end"}) @@ -726,12 +830,11 @@ async def enquire_link(self, TESTING: bool = False) -> typing.Union[bytes, None] Parameters: TESTING: indicates whether this method is been called while running tests. """ + # sleep during startup so that `naz` can have had time to connect & bind + await asyncio.sleep(self.enquire_link_interval) + smpp_command = SmppCommand.ENQUIRE_LINK while True: - if self.current_session_state != SmppSessionState.BOUND_TRX: - # you can only send enquire_link request when session state is BOUND_TRX - await asyncio.sleep(self.enquire_link_interval) - log_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=17)) self._log( logging.DEBUG, @@ -1203,6 +1306,59 @@ async def build_submit_sm_pdu( ) return full_pdu + async def re_establish_conn_bind( + self, smpp_command: str, log_id: str, TESTING: bool = False + ) -> None: + """ + Called if connection is lost. It reconnects & rebinds to SMSC. + + Parameters: + TESTING: indicates whether this method is been called while running tests. + """ + self._log( + logging.INFO, + { + "event": "naz.Client.re_establish_conn_bind", + "stage": "start", + "smpp_command": smpp_command, + "log_id": log_id, + "connection_lost": self.writer.transport.is_closing() if self.writer else True, + }, + ) + if self.SHOULD_SHUT_DOWN: + self._log( + logging.DEBUG, + { + "event": "naz.Client.re_establish_conn_bind", + "stage": "end", + "smpp_command": smpp_command, + "log_id": log_id, + "state": "cleanly shutting down client.", + }, + ) + return None + + try: + # 1. re-connect + # 2. re-bind + await self.connect() + await self.tranceiver_bind() + except (ConnectionError, asyncio.TimeoutError) as e: + self._log( + logging.ERROR, + { + "event": "naz.Client.re_establish_conn_bind", + "stage": "end", + "smpp_command": smpp_command, + "log_id": log_id, + "state": "unable to re-connect & re-bind to SMSC", + "error": str(e), + }, + ) + if TESTING: + # offer escape hatch for tests to come out of endless loop + return None + async def send_data( self, smpp_command: str, msg: bytes, log_id: str, hook_metadata: str = "" ) -> None: @@ -1220,6 +1376,8 @@ async def send_data( # todo: look at `set_write_buffer_limits` and `get_write_buffer_limits` methods # print("get_write_buffer_limits:", writer.transport.get_write_buffer_limits()) + if isinstance(msg, str): + msg = self.codec_class.encode(msg, self.encoding, self.codec_errors_level) log_msg = "" try: log_msg = self.codec_class.decode(msg, self.encoding, self.codec_errors_level) @@ -1236,17 +1394,18 @@ async def send_data( "smpp_command": smpp_command, "log_id": log_id, "msg": log_msg, + "connection_lost": self.writer.transport.is_closing() if self.writer else True, }, ) # check session state to see if we can send messages. # see section 2.3 of SMPP spec document v3.4 if self.current_session_state == SmppSessionState.CLOSED: - error_msg = "smpp_command: {0} cannot be sent to SMSC when the client session state is: {1}".format( + error_msg = "smpp_command `{0}` cannot be sent to SMSC when the client session state is `{1}`".format( smpp_command, self.current_session_state ) self._log( - logging.INFO, + logging.ERROR, { "event": "naz.Client.send_data", "stage": "end", @@ -1257,7 +1416,7 @@ async def send_data( "error": error_msg, }, ) - raise ValueError(error_msg) + return None elif self.current_session_state == SmppSessionState.OPEN and smpp_command not in [ "bind_transmitter", "bind_receiver", @@ -1265,11 +1424,11 @@ async def send_data( ]: # only the smpp_command's listed above are allowed by SMPP spec to be sent # if current_session_state == SmppSessionState.OPEN - error_msg = "smpp_command: {0} cannot be sent to SMSC when the client session state is: {1}".format( + error_msg = "smpp_command `{0}` cannot be sent to SMSC when the client session state is `{1}`".format( smpp_command, self.current_session_state ) self._log( - logging.INFO, + logging.ERROR, { "event": "naz.Client.send_data", "stage": "end", @@ -1280,13 +1439,18 @@ async def send_data( "error": error_msg, }, ) - raise ValueError(error_msg) + # do not raise, we do not want naz-cli to exit + return None - if isinstance(msg, str): - msg = self.codec_class.encode(msg, self.encoding, self.codec_errors_level) + if (self.current_session_state != SmppSessionState.OPEN) and ( + (self.writer is None) or self.writer.transport.is_closing() + ): + # do not re-establish connection if session state is `OPEN` + # ie we have not even connected the first time yet + await self.re_establish_conn_bind(smpp_command=smpp_command, log_id=log_id) - # call user's hook for requests try: + # call user's hook for requests await self.hook.request( smpp_command=smpp_command, log_id=log_id, hook_metadata=hook_metadata ) @@ -1303,13 +1467,33 @@ async def send_data( }, ) - # We use writer.drain() which is a flow control method that interacts with the IO write buffer. - # When the size of the buffer reaches the high watermark, - # drain blocks until the size of the buffer is drained down to the low watermark and writing can be resumed. - # When there is nothing to wait for, the drain() returns immediately. - # ref: https://docs.python.org/3/library/asyncio-stream.html#asyncio.StreamWriter.drain - self.writer.write(msg) - await self.writer.drain() + try: + if typing.TYPE_CHECKING: + # make mypy happy; https://github.com/python/mypy/issues/4805 + assert isinstance(self.writer, asyncio.streams.StreamWriter) + + # We use writer.drain() which is a flow control method that interacts with the IO write buffer. + # When the size of the buffer reaches the high watermark, + # drain blocks until the size of the buffer is drained down to the low watermark and writing can be resumed. + # When there is nothing to wait for, the drain() returns immediately. + # ref: https://docs.python.org/3/library/asyncio-stream.html#asyncio.StreamWriter.drain + self.writer.write(msg) + async with self.drain_lock: + # see: https://github.com/komuw/naz/issues/114 + await self.writer.drain() + except (ConnectionError, asyncio.TimeoutError) as e: + self._log( + logging.ERROR, + { + "event": "naz.Client.send_data", + "stage": "end", + "smpp_command": smpp_command, + "log_id": log_id, + "state": "unable to write to SMSC", + "error": str(e), + }, + ) + self._log( logging.INFO, { @@ -1321,7 +1505,7 @@ async def send_data( }, ) - async def send_forever( + async def dequeue_messages( self, TESTING: bool = False ) -> typing.Union[str, typing.Dict[typing.Any, typing.Any]]: """ @@ -1332,12 +1516,12 @@ async def send_forever( """ retry_count = 0 while True: - self._log(logging.INFO, {"event": "naz.Client.send_forever", "stage": "start"}) + self._log(logging.INFO, {"event": "naz.Client.dequeue_messages", "stage": "start"}) if self.SHOULD_SHUT_DOWN: self._log( logging.INFO, { - "event": "naz.Client.send_forever", + "event": "naz.Client.dequeue_messages", "stage": "end", "state": "cleanly shutting down client.", }, @@ -1353,9 +1537,9 @@ async def send_forever( self._log( logging.ERROR, { - "event": "naz.Client.send_forever", + "event": "naz.Client.dequeue_messages", "stage": "end", - "state": "send_forever error", + "state": "dequeue_messages error", "error": str(e), }, ) @@ -1368,9 +1552,9 @@ async def send_forever( self._log( logging.ERROR, { - "event": "naz.Client.send_forever", + "event": "naz.Client.dequeue_messages", "stage": "end", - "state": "send_forever error", + "state": "dequeue_messages error", "error": str(e), }, ) @@ -1384,9 +1568,9 @@ async def send_forever( self._log( logging.ERROR, { - "event": "naz.Client.send_forever", + "event": "naz.Client.dequeue_messages", "stage": "end", - "state": "send_forever error. sleeping for {0}minutes".format( + "state": "dequeue_messages error. sleeping for {0}minutes".format( poll_queue_interval / 60 ), "retry_count": retry_count, @@ -1395,6 +1579,9 @@ async def send_forever( ) if self.SHOULD_SHUT_DOWN: return {"shutdown": "shutdown"} + if TESTING: + # offer escape hatch for tests to come out of endless loop + return {"broker_error": "broker_error"} await asyncio.sleep(poll_queue_interval) continue # we didn't fail to dequeue a message @@ -1420,9 +1607,9 @@ async def send_forever( self._log( logging.ERROR, { - "event": "naz.Client.send_forever", + "event": "naz.Client.dequeue_messages", "stage": "end", - "state": "send_forever error", + "state": "dequeue_messages error", "error": str(e), }, ) @@ -1437,7 +1624,7 @@ async def send_forever( self._log( logging.INFO, { - "event": "naz.Client.send_forever", + "event": "naz.Client.dequeue_messages", "stage": "end", "log_id": log_id, "smpp_command": smpp_command, @@ -1452,7 +1639,7 @@ async def send_forever( self._log( logging.INFO, { - "event": "naz.Client.send_forever", + "event": "naz.Client.dequeue_messages", "stage": "end", "send_request": send_request, }, @@ -1463,9 +1650,9 @@ async def send_forever( self._log( logging.ERROR, { - "event": "naz.Client.send_forever", + "event": "naz.Client.dequeue_messages", "stage": "end", - "state": "send_forever error", + "state": "dequeue_messages error", "error": str(e), }, ) @@ -1496,8 +1683,27 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[bytes, None] ) return None - # todo: look at `pause_reading` and `resume_reading` methods - command_length_header_data = await self.reader.read(4) + command_length_header_data = b"" + try: + if typing.TYPE_CHECKING: + # make mypy happy; https://github.com/python/mypy/issues/4805 + assert isinstance(self.reader, asyncio.streams.StreamReader) + + # todo: look at `pause_reading` and `resume_reading` methods + # `client.reader` and `client.writer` should not have timeouts since they are non-blocking + # https://github.com/komuw/naz/issues/116 + command_length_header_data = await self.reader.read(4) + except (ConnectionError, asyncio.TimeoutError) as e: + self._log( + logging.ERROR, + { + "event": "naz.Client.receive_data", + "stage": "end", + "state": "unable to read from SMSC", + "error": str(e), + }, + ) + if command_length_header_data == b"": retry_count += 1 poll_read_interval = self._retry_after(retry_count) @@ -1526,7 +1732,23 @@ async def receive_data(self, TESTING: bool = False) -> typing.Union[bytes, None] chunks = [] bytes_recd = 0 while bytes_recd < MSGLEN: - chunk = await self.reader.read(min(MSGLEN - bytes_recd, 2048)) + chunk = b"" + try: + if typing.TYPE_CHECKING: + # make mypy happy; https://github.com/python/mypy/issues/4805 + assert isinstance(self.reader, asyncio.streams.StreamReader) + + chunk = await self.reader.read(min(MSGLEN - bytes_recd, 2048)) + except (ConnectionError, asyncio.TimeoutError) as e: + self._log( + logging.ERROR, + { + "event": "naz.Client.receive_data", + "stage": "end", + "state": "unable to read from SMSC", + "error": str(e), + }, + ) if chunk == b"": err = RuntimeError("socket connection broken") self._log( @@ -1949,6 +2171,15 @@ async def shutdown(self) -> None: # we need to unbind first before closing writer await self.unbind() + + if typing.TYPE_CHECKING: + # make mypy happy; https://github.com/python/mypy/issues/4805 + assert isinstance(self.writer, asyncio.streams.StreamWriter) + assert isinstance(self.writer.transport, asyncio.transports.Transport) + + # see: https://github.com/komuw/naz/issues/117 + self.writer.transport.set_write_buffer_limits(0) + await self.writer.drain() self.writer.close() # sleep so that client can: @@ -1957,3 +2188,11 @@ async def shutdown(self) -> None: # - stop sending `enquire_link` requests # - send unbind to SMSC await asyncio.sleep(self.drain_duration) # asyncio.sleep so that we do not block eventloop + + +class NazClientError(Exception): + """ + Error raised when there's an error instanciating a naz Client. + """ + + pass diff --git a/setup.py b/setup.py index ef452000..da3a148f 100644 --- a/setup.py +++ b/setup.py @@ -92,12 +92,12 @@ "pypandoc", "twine", "wheel", - "Sphinx==1.8.3", + "Sphinx==2.0.1", "sphinx-autodoc-typehints==1.6.0", - "redis==2.10.6", - "pika==0.12.0", + "redis==3.2.1", + "pika==1.0.1", ], - "test": ["flake8", "mock", "pylint", "black", "bandit", "docker==3.4.0", "mypy"], + "test": ["flake8", "pylint", "black", "bandit", "docker==4.0.1", "mypy"], }, # If there are data files included in your packages that need to be # installed, specify them here. If using Python 2.6 or less, then these diff --git a/tests/test_cli.py b/tests/test_cli.py index 0f4c9935..339866f2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,11 +1,9 @@ import os import io -import sys import copy import json import signal import asyncio -import logging import argparse from unittest import TestCase, mock @@ -14,9 +12,6 @@ import docker -logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.WARNING) - - def AsyncMock(*args, **kwargs): """ see: https://blog.miguelgrinberg.com/post/unit-testing-asyncio-code @@ -81,7 +76,7 @@ def setUp(self): "loglevel": "INFO", "log_metadata": {"environment": "production", "release": "canary"}, "codec_errors_level": "ignore", - "enquire_link_interval": 30, + "enquire_link_interval": 30.00, "rateLimiter": "examples.example_klasses.ExampleRateLimiter", } @@ -161,8 +156,24 @@ def test_success_handle_termination_signal(self): def test_termination_call_client_shutdown(self): with mock.patch("naz.Client.unbind", new=AsyncMock()) as mock_naz_unbind: + class MockStreamWriterTransport: + @staticmethod + def set_write_buffer_limits(value): + return + class MockStreamWriter: - def close(self): + transport = MockStreamWriterTransport() + + @staticmethod + def close(): + return + + @staticmethod + def write(stuff): + return + + @staticmethod + async def drain(): return self.client.writer = MockStreamWriter() diff --git a/tests/test_client.py b/tests/test_client.py index d660fe55..8c3f6339 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,21 +2,15 @@ # see: https://python-packaging.readthedocs.io/en/latest/testing.html import os -import sys import json import struct import asyncio -import logging -from unittest import TestCase +from unittest import TestCase, mock import naz -import mock import docker -logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.DEBUG) - - def AsyncMock(*args, **kwargs): """ see: https://blog.miguelgrinberg.com/post/unit-testing-asyncio-code @@ -117,10 +111,85 @@ def mock_create_client(): outboundqueue=self.outboundqueue, ) + self.assertRaises(naz.client.NazClientError, mock_create_client) + with self.assertRaises(naz.client.NazClientError) as raised_exception: + mock_create_client() + self.assertIsInstance(raised_exception.exception.args[0][0], ValueError) + self.assertIn( + "`log_metadata` should be of type", str(raised_exception.exception.args[0][0]) + ) + + def test_all_bad_args(self): + class DummyClientArg: + pass + + client_args = { + "loglevel": "SomeBadLogLevel", + "smsc_host": DummyClientArg, + "smsc_port": DummyClientArg, + "system_id": DummyClientArg, + "password": DummyClientArg, + "outboundqueue": DummyClientArg, + "system_type": DummyClientArg, + "interface_version": DummyClientArg, + "addr_ton": DummyClientArg, + "addr_npi": DummyClientArg, + "address_range": DummyClientArg, + "encoding": DummyClientArg, + "sequence_generator": DummyClientArg, + "log_metadata": DummyClientArg, + "codec_errors_level": DummyClientArg, + "codec_class": DummyClientArg, + "service_type": DummyClientArg, + "source_addr_ton": DummyClientArg, + "source_addr_npi": DummyClientArg, + "dest_addr_ton": DummyClientArg, + "dest_addr_npi": DummyClientArg, + "esm_class": DummyClientArg, + "protocol_id": DummyClientArg, + "priority_flag": DummyClientArg, + "schedule_delivery_time": DummyClientArg, + "validity_period": DummyClientArg, + "registered_delivery": DummyClientArg, + "replace_if_present_flag": DummyClientArg, + "sm_default_msg_id": DummyClientArg, + "enquire_link_interval": DummyClientArg, + "rateLimiter": DummyClientArg, + "hook": DummyClientArg, + "throttle_handler": DummyClientArg, + "correlation_handler": DummyClientArg, + "drain_duration": DummyClientArg, + "connect_timeout": DummyClientArg, + } + + def mock_create_client(): + naz.Client(**client_args) + + self.assertRaises(naz.client.NazClientError, mock_create_client) + with self.assertRaises(naz.client.NazClientError) as raised_exception: + mock_create_client() + for exc in raised_exception.exception.args[0]: + self.assertIsInstance(exc, ValueError) + + def test_instantiate_bad_encoding(self): + encoding = "unknownEncoding" + + def mock_create_client(): + naz.Client( + smsc_host="127.0.0.1", + smsc_port=2775, + system_id="smppclient1", + password=os.getenv("password", "password"), + encoding=encoding, + outboundqueue=self.outboundqueue, + ) + self.assertRaises(ValueError, mock_create_client) with self.assertRaises(ValueError) as raised_exception: mock_create_client() - self.assertIn("`log_metadata` should be of type", str(raised_exception.exception)) + self.assertIn( + "That encoding:{0} is not recognised.".format(encoding), str(raised_exception.exception) + ) def test_can_connect(self): reader, writer = self._run(self.cli.connect()) @@ -173,7 +242,7 @@ def test_submit_sm_sending(self): self._run(self.cli.connect()) # hack to allow sending submit_sm even when state is wrong self.cli.current_session_state = "BOUND_TRX" - self._run(self.cli.send_forever(TESTING=True)) + self._run(self.cli.dequeue_messages(TESTING=True)) self.assertTrue(mock_naz_dequeue.mock.called) @@ -301,7 +370,7 @@ def test_no_sending_if_throttler(self): for _ in range(0, int(sample_size) * 2): self._run(cli.throttle_handler.throttled()) - self._run(cli.send_forever(TESTING=True)) + self._run(cli.dequeue_messages(TESTING=True)) self.assertFalse(mock_naz_dequeue.mock.called) @@ -383,7 +452,7 @@ def test_hook_called_with_metadata(self): self._run(self.cli.connect()) # hack to allow sending submit_sm even when state is wrong self.cli.current_session_state = "BOUND_TRX" - self._run(self.cli.send_forever(TESTING=True)) + self._run(self.cli.dequeue_messages(TESTING=True)) self.assertTrue(mock_hook_request.mock.called) self.assertEqual( @@ -443,11 +512,50 @@ def test__retry_after(self): self.assertEqual(self.cli._retry_after(current_retries=7) / 60, 16) self.assertEqual(self.cli._retry_after(current_retries=5432) / 60, 16) + def test_session_state_ok(self): + """ + send a `submit_sm` request when session state is `BOUND_TRX` + """ + with mock.patch( + "naz.q.SimpleOutboundQueue.dequeue", new=AsyncMock() + ) as mock_naz_dequeue, mock.patch("asyncio.streams.StreamWriter.write") as mock_naz_writer: + log_id = "12345" + short_message = "hello smpp" + mock_naz_dequeue.mock.return_value = { + "version": "1", + "log_id": log_id, + "short_message": short_message, + "smpp_command": naz.SmppCommand.SUBMIT_SM, + "source_addr": "2547000000", + "destination_addr": "254711999999", + } + + self._run(self.cli.connect()) + self._run(self.cli.tranceiver_bind()) + self.cli.current_session_state = naz.SmppSessionState.BOUND_TRX + self._run(self.cli.dequeue_messages(TESTING=True)) + + self.assertTrue(mock_naz_writer.called) + self.assertEqual(mock_naz_writer.call_count, 2) + self.assertIn(short_message, mock_naz_writer.call_args[0][0].decode()) + + def test_broken_broker(self): + with mock.patch( + "naz.q.SimpleOutboundQueue.dequeue", new=AsyncMock() + ) as mock_naz_dequeue, mock.patch("asyncio.streams.StreamWriter.write") as mock_naz_writer: + mock_naz_dequeue.mock.side_effect = ValueError("This test broker has 99 Problems") + self._run(self.cli.connect()) + res = self._run(self.cli.dequeue_messages(TESTING=True)) + self.assertEqual(res, {"broker_error": "broker_error"}) + self.assertFalse(mock_naz_writer.called) + def test_session_state(self): """ try sending a `submit_sm` request when session state is `OPEN` """ - with mock.patch("naz.q.SimpleOutboundQueue.dequeue", new=AsyncMock()) as mock_naz_dequeue: + with mock.patch( + "naz.q.SimpleOutboundQueue.dequeue", new=AsyncMock() + ) as mock_naz_dequeue, mock.patch("asyncio.streams.StreamWriter.write") as mock_naz_writer: log_id = "12345" short_message = "hello smpp" mock_naz_dequeue.mock.return_value = { @@ -460,18 +568,16 @@ def test_session_state(self): } self._run(self.cli.connect()) - with self.assertRaises(ValueError) as raised_exception: - self._run(self.cli.send_forever(TESTING=True)) - self.assertIn( - "smpp_command: submit_sm cannot be sent to SMSC when the client session state is: OPEN", - str(raised_exception.exception), - ) + self._run(self.cli.dequeue_messages(TESTING=True)) + self.assertFalse(mock_naz_writer.called) def test_submit_with_session_state_closed(self): """ try sending a `submit_sm` request when session state is `CLOSED` """ - with mock.patch("naz.q.SimpleOutboundQueue.dequeue", new=AsyncMock()) as mock_naz_dequeue: + with mock.patch( + "naz.q.SimpleOutboundQueue.dequeue", new=AsyncMock() + ) as mock_naz_dequeue, mock.patch("asyncio.streams.StreamWriter.write") as mock_naz_writer: log_id = "12345" short_message = "hello smpp" mock_naz_dequeue.mock.return_value = { @@ -482,12 +588,8 @@ def test_submit_with_session_state_closed(self): "source_addr": "2547000000", "destination_addr": "254711999999", } - with self.assertRaises(ValueError) as raised_exception: - self._run(self.cli.send_forever(TESTING=True)) - self.assertIn( - "smpp_command: submit_sm cannot be sent to SMSC when the client session state is: CLOSED", - str(raised_exception.exception), - ) + self._run(self.cli.dequeue_messages(TESTING=True)) + self.assertFalse(mock_naz_writer.called) def test_correlater_put_called(self): with mock.patch( @@ -512,7 +614,7 @@ def test_correlater_put_called(self): self._run(self.cli.connect()) # hack to allow sending submit_sm even when state is wrong self.cli.current_session_state = "BOUND_TRX" - self._run(self.cli.send_forever(TESTING=True)) + self._run(self.cli.dequeue_messages(TESTING=True)) self.assertTrue(mock_correlater_put.mock.called) self.assertEqual(mock_correlater_put.mock.call_args[1]["log_id"], log_id) @@ -534,26 +636,6 @@ def test_correlater_get_called(self): self.assertTrue(mock_correlater_get.mock.called) self.assertTrue(mock_correlater_get.mock.call_args[1]["sequence_number"]) - def test_instantiate_bad_encoding(self): - encoding = "unknownEncoding" - - def mock_create_client(): - naz.Client( - smsc_host="127.0.0.1", - smsc_port=2775, - system_id="smppclient1", - password=os.getenv("password", "password"), - encoding=encoding, - outboundqueue=self.outboundqueue, - ) - - self.assertRaises(ValueError, mock_create_client) - with self.assertRaises(ValueError) as raised_exception: - mock_create_client() - self.assertIn( - "That encoding:{0} is not recognised.".format(encoding), str(raised_exception.exception) - ) - def test_logger_called(self): with mock.patch("naz.logger.SimpleLogger.log") as mock_logger_log: mock_logger_log.return_value = None @@ -615,7 +697,7 @@ def test_submit_sm_AND_deliver_sm_correlation(self): self._run(self.cli.connect()) # hack to allow sending submit_sm even when state is wrong self.cli.current_session_state = "BOUND_TRX" - self._run(self.cli.send_forever(TESTING=True)) + self._run(self.cli.dequeue_messages(TESTING=True)) self.assertTrue(self.cli.correlation_handler.store[mock_sequence_number]) self.assertEqual( self.cli.correlation_handler.store[mock_sequence_number]["log_id"], log_id @@ -673,3 +755,50 @@ def test_submit_sm_AND_deliver_sm_correlation(self): self.assertEqual( mock_hook_response.mock.call_args[1]["hook_metadata"], hook_metadata ) + + def test_re_establish_conn_bind(self): + """ + test that `Client.re_establish_conn_bind` calls `Client.connect` & `Client.tranceiver_bind` + """ + with mock.patch("naz.Client.connect", new=AsyncMock()) as mock_naz_connect, mock.patch( + "naz.Client.tranceiver_bind", new=AsyncMock() + ) as mock_naz_tranceiver_bind: + self._run( + self.cli.re_establish_conn_bind( + smpp_command=naz.SmppCommand.SUBMIT_SM, log_id="log_id", TESTING=True + ) + ) + self.assertTrue(mock_naz_connect.mock.called) + self.assertTrue(mock_naz_tranceiver_bind.mock.called) + + def test_send_data_under_disconnection(self): + """ + test that if sockect is disconnected, `naz` will try to re-connect & re-bind + """ + with mock.patch("naz.Client.tranceiver_bind", new=AsyncMock()) as mock_naz_tranceiver_bind: + # do not connect or bind + self.cli.current_session_state = naz.SmppSessionState.BOUND_TRX + self._run( + self.cli.send_data( + smpp_command=naz.SmppCommand.SUBMIT_SM, msg=b"someMessage", log_id="log_id" + ) + ) + self.assertTrue(mock_naz_tranceiver_bind.mock.called) + + def test_issues_67(self): + """ + test to prove we have fixed: https://github.com/komuw/naz/issues/67 + 1. start broker + 2. start naz and run a naz operation like `Client..enquire_link` + 3. kill broker + 4. run a naz operation like `Client..enquire_link` + """ + with mock.patch("naz.Client.tranceiver_bind", new=AsyncMock()) as mock_naz_tranceiver_bind: + self.cli.current_session_state = naz.SmppSessionState.BOUND_TRX + self.cli.writer = None # simulate a connection loss + self._run( + self.cli.send_data( + smpp_command=naz.SmppCommand.SUBMIT_SM, msg=b"someMessage", log_id="log_id" + ) + ) + self.assertTrue(mock_naz_tranceiver_bind.mock.called) diff --git a/tests/test_correlater.py b/tests/test_correlater.py index 77622392..c93df24f 100644 --- a/tests/test_correlater.py +++ b/tests/test_correlater.py @@ -1,20 +1,14 @@ # do not to pollute the global namespace. # see: https://python-packaging.readthedocs.io/en/latest/testing.html -import sys import json import time -import mock import asyncio -import logging -from unittest import TestCase +from unittest import TestCase, mock import naz -logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.DEBUG) - - def AsyncMock(*args, **kwargs): """ see: https://blog.miguelgrinberg.com/post/unit-testing-asyncio-code @@ -37,14 +31,6 @@ class TestCorrelater(TestCase): """ def setUp(self): - self.logger = logging.getLogger("naz.test") - handler = logging.StreamHandler() - formatter = logging.Formatter("%(message)s") - handler.setFormatter(formatter) - if not self.logger.handlers: - self.logger.addHandler(handler) - self.logger.setLevel("DEBUG") - self.max_ttl = 0.2 # sec self.correlater = naz.correlater.SimpleCorrelater(max_ttl=self.max_ttl) @@ -235,14 +221,6 @@ class TestBenchmarkCorrelater(TestCase): """ def setUp(self): - self.logger = logging.getLogger("naz.test") - handler = logging.StreamHandler() - formatter = logging.Formatter("%(message)s") - handler.setFormatter(formatter) - if not self.logger.handlers: - self.logger.addHandler(handler) - self.logger.setLevel("DEBUG") - self.max_ttl = 0.2 # sec self.correlater = naz.correlater.SimpleCorrelater(max_ttl=self.max_ttl) diff --git a/tests/test_nazcodec.py b/tests/test_nazcodec.py index 77a51ec1..957b784c 100644 --- a/tests/test_nazcodec.py +++ b/tests/test_nazcodec.py @@ -30,14 +30,10 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -import sys -import logging from unittest import TestCase import naz -logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.DEBUG) - class TestNazCodec(TestCase): """ diff --git a/tests/test_ratelimit.py b/tests/test_ratelimit.py index 887e1cda..7bf0b799 100644 --- a/tests/test_ratelimit.py +++ b/tests/test_ratelimit.py @@ -1,13 +1,8 @@ -import sys import time import asyncio -import logging -from unittest import TestCase +from unittest import TestCase, mock import naz -import mock - -logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.DEBUG) def AsyncMock(*args, **kwargs): diff --git a/tests/test_throttle.py b/tests/test_throttle.py index 6417255a..779e5dba 100644 --- a/tests/test_throttle.py +++ b/tests/test_throttle.py @@ -1,17 +1,12 @@ # do not to pollute the global namespace. # see: https://python-packaging.readthedocs.io/en/latest/testing.html -import sys import asyncio -import logging from unittest import TestCase import naz -logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.DEBUG) - - class TestThrottle(TestCase): """ run tests as: @@ -22,7 +17,7 @@ class TestThrottle(TestCase): def setUp(self): self.throttle_handler = naz.throttle.SimpleThrottleHandler( - sampling_period=10.00, sample_size=12.00, deny_request_at=1.00 + sampling_period=0.50, sample_size=12.00, deny_request_at=1.00 ) def tearDown(self):