diff --git a/.coverage b/.coverage new file mode 100644 index 0000000..be628f2 Binary files /dev/null and b/.coverage differ diff --git a/coverage.ini b/coverage.ini new file mode 100644 index 0000000..9dd8a67 --- /dev/null +++ b/coverage.ini @@ -0,0 +1,17 @@ +[run] +dynamic_context = test_function +omit = + */shibokensupport/* + */signature_bootstrap.py +branch = True + +[report] +skip_empty = True +exclude_lines = + pragma: no cover + raise NotImplementedError + if typing.TYPE_CHECKING: + +[html] +show_contexts = True +title = UDS Client Test Coverage Report diff --git a/src/UDSClient.py b/src/UDSClient.py index 85f4aea..bdba7de 100755 --- a/src/UDSClient.py +++ b/src/UDSClient.py @@ -43,7 +43,7 @@ import typing from uds.ui import QtCore, QtWidgets, QtGui, QSettings, Ui_MainWindow # pyright: ignore -from uds.rest import RestApi, RetryException, InvalidVersion +from uds.rest import RestApi # Just to ensure there are available on runtime from uds.tunnel import forward as tunnel_forwards # pyright: ignore[reportUnusedImport] @@ -138,7 +138,7 @@ def stop_animation(self) -> None: def fetch_version(self) -> None: try: self.api.get_version() - except InvalidVersion as e: + except exceptions.InvalidVersion as e: QtWidgets.QMessageBox.critical( self, 'Upgrade required', @@ -174,7 +174,7 @@ def fetch_transport_data(self) -> None: # Execute the waiting tasks... threading.Thread(target=end_script).start() - except RetryException as e: + except exceptions.RetryException as e: self.ui.info.setText(str(e) + ', retrying access...') # Retry operation in ten seconds QtCore.QTimer.singleShot(10000, self.fetch_transport_data) @@ -251,7 +251,7 @@ def verify_host_approval(hostName: str) -> bool: return approved -def ssl_error_processor(hostname: str, serial: str) -> bool: +def ssl_certificate_validator(hostname: str, serial: str) -> bool: settings = QSettings() settings.beginGroup('ssl') @@ -281,7 +281,7 @@ def minimal(api: RestApi, ticket: str, scrambler: str) -> int: logger.debug('Getting version') try: api.get_version() - except InvalidVersion as e: + except exceptions.InvalidVersion as e: QtWidgets.QMessageBox.critical( None, # type: ignore 'Upgrade required', @@ -298,7 +298,7 @@ def minimal(api: RestApi, ticket: str, scrambler: str) -> int: # Execute the waiting task... threading.Thread(target=end_script).start() - except RetryException as e: + except exceptions.RetryException as e: QtWidgets.QMessageBox.warning( None, # type: ignore 'Service not ready', @@ -337,31 +337,31 @@ def parse_arguments(args: typing.List[str]) -> typing.Tuple[str, str, str, bool] use_minimal_interface = False uds_url = args[1] - + if uds_url == '--minimal': use_minimal_interface = True uds_url = args[2] # And get URI - + if uds_url == '--test': - raise exceptions.UDSArgumentException('test') - + raise exceptions.ArgumentException('test') + try: urlinfo = urllib.parse.urlparse(uds_url) ticket, scrambler = urlinfo.path.split('/')[1:3] except Exception: - raise exceptions.UDSMessageException('Invalid UDS URL') + raise exceptions.MessageException('Invalid UDS URL') # Check if minimal interface is requested on the URL if 'minimal' in urllib.parse.parse_qs(urlinfo.query): use_minimal_interface = True - + if urlinfo.scheme == 'uds': if not consts.DEBUG: - raise exceptions.UDSMessageException( + raise exceptions.MessageException( 'UDS Client Version {} does not support HTTP protocol Anymore.'.format(VERSION) ) elif urlinfo.scheme != 'udss': - raise exceptions.UDSMessageException('Not supported protocol') # Just shows "about" dialog + raise exceptions.MessageException('Not supported protocol') # Just shows "about" dialog return ( urlinfo.netloc, @@ -397,7 +397,7 @@ def main(args: typing.List[str]) -> int: # First parameter must be url try: host, ticket, scrambler, _use_minimal_interface = parse_arguments(args) - except exceptions.UDSMessageException as e: + except exceptions.MessageException as e: logger.debug('Detected execution without valid URI, exiting: %s', e) QtWidgets.QMessageBox.critical( None, # type: Ignore @@ -406,7 +406,7 @@ def main(args: typing.List[str]) -> int: QtWidgets.QMessageBox.StandardButton.Ok, ) return 1 - except exceptions.UDSArgumentException as e: + except exceptions.ArgumentException as e: # Currently only test, return 0 return 0 except Exception: @@ -419,15 +419,18 @@ def main(args: typing.List[str]) -> int: ) return 1 - # Setup REST api endpoint - api = RestApi(f'https://{host}/uds/rest/client', ssl_error_processor) + # Setup REST api and ssl certificate validator + api = RestApi.api( + host, + on_invalid_certificate=ssl_certificate_validator, + ) try: logger.debug('Starting execution') # Approbe before going on if verify_host_approval(host) is False: - raise Exception('Host {} was not approved'.format(host)) + raise exceptions.MessageException('Host {} was not approved'.format(host)) win = UDSClient(api, ticket, scrambler) win.show() @@ -438,10 +441,16 @@ def main(args: typing.List[str]) -> int: logger.debug('Main execution finished correctly: %s', exit_code) except Exception as e: - logger.exception('Got an exception executing client:') + if not isinstance(e, exceptions.MessageException): + logger.exception('Got an exception executing client:') + else: + logger.info('Message from error: %s', e) exit_code = 128 QtWidgets.QMessageBox.critical( - None, 'Error', f'Fatal error: {e}', QtWidgets.QMessageBox.StandardButton.Ok # type: ignore + None, + 'Error', + f'Fatal error: {e}', + QtWidgets.QMessageBox.StandardButton.Ok, ) logger.debug('Exiting') diff --git a/src/uds/exceptions.py b/src/uds/exceptions.py index f4d6d6e..ef369fb 100644 --- a/src/uds/exceptions.py +++ b/src/uds/exceptions.py @@ -1,6 +1,22 @@ +class UDSException(Exception): + pass + -class UDSMessageException(Exception): +class MessageException(UDSException): pass -class UDSArgumentException(Exception): + +class ArgumentException(UDSException): pass + + +class RetryException(UDSException): + pass + + +class InvalidVersion(UDSException): + downloadUrl: str + + def __init__(self, downloadUrl: str) -> None: + super().__init__(downloadUrl) + self.downloadUrl = downloadUrl diff --git a/src/uds/rest.py b/src/uds/rest.py index 4bc9186..bd935a7 100644 --- a/src/uds/rest.py +++ b/src/uds/rest.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright (c) 2017-2021 Virtual Cable S.L.U. +# Copyright (c) 2017-2024 Virtual Cable S.L.U. # All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, @@ -43,7 +43,7 @@ from cryptography import x509 from cryptography.hazmat.backends import default_backend -from . import consts, tools +from . import consts, tools, exceptions from .log import logger # Callback for error on cert @@ -51,21 +51,6 @@ # If returns True, ignores error CertCallbackType = typing.Callable[[str, str], bool] -# Exceptions -class UDSException(Exception): - pass - - -class RetryException(UDSException): - pass - - -class InvalidVersion(UDSException): - downloadUrl: str - - def __init__(self, downloadUrl: str) -> None: - super().__init__(downloadUrl) - self.downloadUrl = downloadUrl class RestApi: @@ -84,25 +69,20 @@ def __init__( self._on_invalid_certificate = on_invalid_certificate self._server_version = '' - def get( - self, path: str, params: typing.Optional[typing.Mapping[str, str]] = None - ) -> typing.Any: + def get(self, path: str, params: typing.Optional[typing.Mapping[str, str]] = None) -> typing.Any: if params: path += '?' + '&'.join( - '{}={}'.format(k, urllib.parse.quote(str(v).encode('utf8'))) - for k, v in params.items() + '{}={}'.format(k, urllib.parse.quote(str(v).encode('utf8'))) for k, v in params.items() ) - return json.loads( - RestApi.get_url(self._rest_api_endpoint + path, self._on_invalid_certificate) - ) + return json.loads(RestApi.get_url(self._rest_api_endpoint + path, self._on_invalid_certificate)) def process_error(self, data: typing.Any) -> None: if 'error' in data: if data.get('retryable', '0') == '1': - raise RetryException(data['error']) + raise exceptions.RetryException(data['error']) - raise UDSException(data['error']) + raise exceptions.UDSException(data['error']) def get_version(self) -> str: '''Gets and stores the serverVersion. @@ -118,17 +98,15 @@ def get_version(self) -> str: try: if self._server_version > consts.VERSION: - raise InvalidVersion(downloadUrl) + raise exceptions.InvalidVersion(downloadUrl) return self._server_version - except InvalidVersion: + except exceptions.InvalidVersion: raise except Exception as e: - raise UDSException(e) + raise exceptions.UDSException(e) from e - def get_script_and_parameters( - self, ticket: str, scrambler: str - ) -> typing.Tuple[str, typing.Any]: + def get_script_and_parameters(self, ticket: str, scrambler: str) -> typing.Tuple[str, typing.Any]: '''Gets the transport script, validates it if necesary and returns it''' try: @@ -160,18 +138,14 @@ def get_script_and_parameters( if tools.verify_signature(script, signature) is False: logger.error('Signature is invalid') - raise Exception( - 'Invalid UDS code signature. Please, report to administrator' - ) + raise Exception('Invalid UDS code signature. Please, report to administrator') return script.decode(), params # exec(script.decode("utf-8"), globals(), {'parent': self, 'sp': params}) @staticmethod - def _open( - url: str, certErrorCallback: typing.Optional[CertCallbackType] = None - ) -> typing.Any: + def _open(url: str, certErrorCallback: typing.Optional[CertCallbackType] = None) -> typing.Any: ctx = ssl.create_default_context() ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE @@ -237,9 +211,11 @@ def _open_url(url: str) -> typing.Any: return response @staticmethod - def get_url( - url: str, certErrorCallback: typing.Optional[CertCallbackType] = None - ) -> bytes: + def api(host: str, on_invalid_certificate: CertCallbackType) -> 'RestApi': + return RestApi(f'https://{host}/uds/rest/client', on_invalid_certificate) + + @staticmethod + def get_url(url: str, certErrorCallback: typing.Optional[CertCallbackType] = None) -> bytes: with RestApi._open(url, certErrorCallback) as response: resp = response.read() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_commandline.py b/tests/test_client.py similarity index 76% rename from tests/test_commandline.py rename to tests/test_client.py index f5a9c3d..f5f1159 100644 --- a/tests/test_commandline.py +++ b/tests/test_client.py @@ -33,12 +33,14 @@ from unittest import TestCase import UDSClient -from uds import exceptions, consts +from uds import exceptions, consts, rest + +from .utils import fixtures logger = logging.getLogger(__name__) -class TestTunnel(TestCase): +class TestClient(TestCase): def test_commandline(self): def _check_url(url: str, minimal: typing.Optional[str] = None, with_minimal: bool = False) -> None: host, ticket, scrambler, use_minimal = UDSClient.parse_arguments( @@ -54,16 +56,16 @@ def _check_url(url: str, minimal: typing.Optional[str] = None, with_minimal: boo UDSClient.parse_arguments(['udsclient']) # Valid command line, but not an URI. should return UDSArgumentException - with self.assertRaises(exceptions.UDSArgumentException): + with self.assertRaises(exceptions.ArgumentException): UDSClient.parse_arguments(['udsclient', '--test']) # unkonwn protocol, should return UDSArgumentException - with self.assertRaises(exceptions.UDSMessageException): + with self.assertRaises(exceptions.MessageException): UDSClient.parse_arguments(['udsclient', 'unknown://' + 'a' * 2048]) # uds protocol, but withoout debug mode, should rais exception.UDSMessagException consts.DEBUG = False - with self.assertRaises(exceptions.UDSMessageException): + with self.assertRaises(exceptions.MessageException): _check_url('uds://a/b/c') # Set DEBUG mode (on consts), now should work @@ -77,3 +79,16 @@ def _check_url(url: str, minimal: typing.Optional[str] = None, with_minimal: boo _check_url('udss://a/b/c', '--minimal', with_minimal=True) # No matter what is passed as value of minimal, if present, it will be used _check_url('udss://a/b/c?minimal=11', with_minimal=True) + + def test_rest(self): + # This is a simple test, we will test the rest api is mocked correctly + with fixtures.patch_rest_api() as api: + self.assertEqual(api.get_version(), fixtures.SERVER_VERSION) + self.assertEqual(api.get_script_and_parameters('ticket', 'scrambler'), (fixtures.SCRIPT, fixtures.PARAMETERS)) + + from_api = rest.RestApi.api('host', lambda x, y: True) + # Repeat tests, should return same results + self.assertEqual(from_api.get_version(), fixtures.SERVER_VERSION) + self.assertEqual(from_api.get_script_and_parameters('ticket', 'scrambler'), (fixtures.SCRIPT, fixtures.PARAMETERS)) + # And also, the api is the same + self.assertEqual(from_api, api) diff --git a/tests/utils/__init__,py b/tests/utils/__init__,py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/autospec.py b/tests/utils/autospec.py new file mode 100644 index 0000000..5d4b63b --- /dev/null +++ b/tests/utils/autospec.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +Author: Adolfo Gómez, dkmaster at dkmon dot com +""" +import typing +import dataclasses +from unittest import mock + +@dataclasses.dataclass +class AutoSpecMethodInfo: + name: typing.Union[str, typing.Callable[..., typing.Any]] + return_value: typing.Any = None + method: 'typing.Callable[..., typing.Any]|None' = None + + +def autospec(cls: type, metods_info: typing.Iterable[AutoSpecMethodInfo], **kwargs: typing.Any) -> mock.Mock: + """ + This is a helper function that will create a mock object with the same methods as the class passed as parameter. + This is useful for testing purposes, where you want to mock a class and still have the same methods available. + + The returned value is in fact a mock object, but with the same methods as the class passed as parameter. + """ + obj = mock.create_autospec(cls, **kwargs) + for method_info in metods_info: + # Set the return value for the method or the side_effect + name = method_info.name if isinstance(method_info.name, str) else method_info.name.__name__ + mck = getattr(obj, name) + if method_info.method is not None: + mck.side_effect = method_info.method + else: + mck.return_value = method_info.return_value + + return obj \ No newline at end of file diff --git a/tests/utils/fixtures.py b/tests/utils/fixtures.py new file mode 100644 index 0000000..d5bc2e0 --- /dev/null +++ b/tests/utils/fixtures.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017-2024 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' +import contextlib +import typing +from unittest import mock + +from uds import rest + +from . import autospec + +SERVER_VERSION: str = '4.0.0' +SCRIPT: str = ''' +# TODO: add testing script here +''' +PARAMETERS: typing.MutableMapping[str, typing.Any] = { +# TODO: add parameters here +} + +REST_METHODS_INFO: typing.List[autospec.AutoSpecMethodInfo] = [ + autospec.AutoSpecMethodInfo(rest.RestApi.get_version, return_value=SERVER_VERSION), + autospec.AutoSpecMethodInfo(rest.RestApi.get_script_and_parameters, return_value=(SCRIPT, PARAMETERS)), +] + +def create_client_mock() -> mock.Mock: + """ + Create a mock of ProxmoxClient + """ + return autospec.autospec(rest.RestApi, REST_METHODS_INFO) + + +@contextlib.contextmanager +def patch_rest_api( + **kwargs: typing.Any, +) -> typing.Generator[mock.Mock, None, None]: + client = create_client_mock() + try: + mock.patch('uds.rest.RestApi.api', return_value=client).start() + yield client + finally: + mock.patch.stopall()