diff --git a/src/UDSClient.py b/src/UDSClient.py index c3b9b22..60761cb 100755 --- a/src/UDSClient.py +++ b/src/UDSClient.py @@ -263,7 +263,7 @@ def waiting_tasks_processor() -> None: # Removing try: logger.debug('Executing threads before exit') - tools.exec_before_exit() + tools.execute_before_exit() except Exception as e: # pylint: disable=broad-exception-caught logger.debug('execBeforeExit: %s', e) @@ -379,6 +379,10 @@ def parse_arguments(args: typing.List[str]) -> typing.Tuple[str, str, str, bool] ) elif urlinfo.scheme != 'udss': raise exceptions.MessageException('Not supported protocol') # Just shows "about" dialog + + # If ticket length is not valid + if len(ticket) != consts.TICKET_LENGTH: + raise exceptions.MessageException(f'Invalid ticket: {ticket}') return ( urlinfo.netloc, diff --git a/src/uds/consts.py b/src/uds/consts.py index 53b13a9..4f8d99a 100644 --- a/src/uds/consts.py +++ b/src/uds/consts.py @@ -107,6 +107,8 @@ def _feature_requested(env_var: str) -> bool: LISTEN_ADDRESS_V6: typing.Final[str] = '::1' RESPONSE_OK: typing.Final[bytes] = b'OK' +# Ticket length +TICKET_LENGTH: typing.Final[int] = 48 # Constants strings for protocol HANDSHAKE_V1: typing.Final[bytes] = b'\x5AMGB\xA5\x01\x00' diff --git a/src/uds/tools.py b/src/uds/tools.py index e0bf1d4..82b3d2d 100644 --- a/src/uds/tools.py +++ b/src/uds/tools.py @@ -70,7 +70,7 @@ def process_iter(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: # at the same time for the same process, so no need to lock _unlink_files: typing.List[types.RemovableFile] = [] _awaitable_tasks: typing.List[types.AwaitableTask] = [] -_execBeforeExit: typing.List[typing.Callable[[], None]] = [] +_execute_before_exit: typing.List[typing.Callable[[], None]] = [] sys_fs_enc = sys.getfilesystemencoding() or 'mbcs' @@ -158,7 +158,7 @@ def unlink_files(early_stage: bool = False) -> None: logger.debug('File %s not deleted: %s', f[0], e) # Remove all processed files from list - _unlink_files[:] = list(filter(lambda x: x.early_stage != early_stage, _unlink_files)) + _unlink_files[:] = list(set(_unlink_files) - set(files_to_unlink)) def add_task_to_wait(task: typing.Any, wait_subprocesses: bool = False) -> None: @@ -197,18 +197,22 @@ def wait_for_tasks() -> None: logger.error('Waiting for tasks to finish error: %s', e) # Empty the list - _awaitable_tasks[:] = typing.cast(list[types.AwaitableTask], []) + _awaitable_tasks.clear() def register_execute_before_exit(fnc: typing.Callable[[], None]) -> None: logger.debug('Added exec before exit: %s', fnc) - _execBeforeExit.append(fnc) + _execute_before_exit.append(fnc) -def exec_before_exit() -> None: - logger.debug('Esecuting exec before exit: %s', _execBeforeExit) - for fnc in _execBeforeExit: - fnc() +def execute_before_exit() -> None: + logger.debug('Esecuting exec before exit: %s', _execute_before_exit) + for fnc in _execute_before_exit: + try: + fnc() + except Exception as e: + logger.error('Error executing before exit: %s', e) + _execute_before_exit.clear() def verify_signature(script: bytes, signature: bytes) -> bool: @@ -261,9 +265,15 @@ def get_cacerts_file() -> typing.Optional[str]: return None -def is_mac_os() -> bool: +def is_macos() -> bool: return 'darwin' in sys.platform +def is_linux() -> bool: + return 'linux' in sys.platform + +def is_windows() -> bool: + return 'win' in sys.platform + # old compat names, to ensure compatibility with old code # Basically, this will be here until v5.0. On 4.5 (or even later) Broker plugins will update diff --git a/src/uds/tunnel.py b/src/uds/tunnel.py index d5e6210..984b784 100644 --- a/src/uds/tunnel.py +++ b/src/uds/tunnel.py @@ -28,6 +28,7 @@ ''' @author: Adolfo Gómez, dkmaster at dkmon dot com ''' +import contextlib import socket import socketserver import ssl @@ -72,9 +73,10 @@ def __init__( ipv6_listen: bool = False, ipv6_remote: bool = False, ) -> None: - # Negative values for timeout, means "accept always connections" - # "but if no connection is stablished on timeout (positive)" - # "stop the listener" + # Negative values for timeout, means: + # * accept always connections, but if no connection is stablished on timeout + # (positive), stop the listener + # # Note that this is for backwards compatibility, better use "keep_listening" if timeout < 0: keep_listening = True @@ -96,21 +98,21 @@ def __init__( self.keep_listening = keep_listening self.stop_flag = threading.Event() # False initial self.current_connections = 0 - - logger.debug('Remote: %s', remote) - logger.debug('Remote IPv6: %s', self.remote_ipv6) - logger.debug('Ticket: %s', ticket) - logger.debug('Check certificate: %s', check_certificate) - logger.debug('Keep listening: %s', keep_listening) - logger.debug('Timeout: %s', timeout) self.status = types.ForwardState.TUNNEL_LISTENING self.can_stop = False timeout = timeout or 60 - self.timer = threading.Timer(timeout, ForwardServer.__checkStarted, args=(self,)) + self.timer = threading.Timer(timeout, ForwardServer._set_stoppable, args=(self,)) self.timer.start() + logger.debug('Remote: %s', remote) + logger.debug('Remote IPv6: %s', self.remote_ipv6) + logger.debug('Ticket: %s', ticket) + logger.debug('Check certificate: %s', check_certificate) + logger.debug('Keep listening: %s', keep_listening) + logger.debug('Timeout: %s', timeout) + def stop(self) -> None: if not self.stop_flag.is_set(): logger.debug('Stopping servers') @@ -120,35 +122,15 @@ def stop(self) -> None: self.timer = None self.shutdown() - def connect(self) -> ssl.SSLSocket: - with socket.socket( - socket.AF_INET6 if self.remote_ipv6 else socket.AF_INET, socket.SOCK_STREAM - ) as rsocket: - logger.info('CONNECT to %s', self.remote) - - rsocket.connect(self.remote) - - rsocket.sendall(consts.HANDSHAKE_V1) # No response expected, just the handshake - - context = ssl.create_default_context() - - # Do not "recompress" data, use only "base protocol" compression - context.options |= ssl.OP_NO_COMPRESSION - # Macs with default installed python, does not support mininum tls version set to TLSv1.3 - # USe "brew" version instead, or uncomment next line and comment the next one - # context.minimum_version = ssl.TLSVersion.TLSv1_2 if tools.isMac() else ssl.TLSVersion.TLSv1_3 - context.minimum_version = ssl.TLSVersion.TLSv1_3 - - if tools.get_cacerts_file() is not None: - context.load_verify_locations(tools.get_cacerts_file()) # Load certifi certificates - - # If ignore remote certificate - if self.check_certificate is False: - context.check_hostname = False - context.verify_mode = ssl.CERT_NONE - logger.warning('Certificate checking is disabled!') - - return context.wrap_socket(rsocket, server_hostname=self.remote[0]) + @contextlib.contextmanager + def connection(self) -> typing.Generator[ssl.SSLSocket, None, None]: + ssl_sock: typing.Optional[ssl.SSLSocket] = None + try: + ssl_sock = ForwardServer._connect(self.remote, self.remote_ipv6, self.check_certificate) + yield ssl_sock + finally: + if ssl_sock: + ssl_sock.close() def check(self) -> bool: if self.status == types.ForwardState.TUNNEL_ERROR: @@ -156,21 +138,28 @@ def check(self) -> bool: logger.debug('Checking tunnel availability') + with self.connection() as ssl_socket: + return ForwardServer._test(ssl_socket) + + @contextlib.contextmanager + def open_tunnel(self) -> typing.Generator[ssl.SSLSocket, None, None]: + self.current_connections += 1 + # Open remote connection try: - with self.connect() as ssl_socket: - ssl_socket.sendall(consts.CMD_TEST) - resp = ssl_socket.recv(2) - if resp != consts.RESPONSE_OK: - raise Exception({'Invalid tunnelresponse: {resp}'}) - logger.debug('Tunnel is available!') - return True + with self.connection() as ssl_socket: + ForwardServer._open_tunnel(ssl_socket, self.ticket) + + yield ssl_socket except ssl.SSLError as e: - logger.error(f'Certificate error connecting to {self.server_address}: {e!s}') - # will surpas the "check" method on script caller, arriving to the UDSClient error handler - raise Exception(f'Certificate error connecting to {self.server_address}') from e + logger.error(f'Certificate error connecting to {self.remote!s}: {e!s}') + self.status = types.ForwardState.TUNNEL_ERROR + self.stop() except Exception as e: - logger.error('Error connecting to tunnel server %s: %s', self.server_address, e) - return False + logger.error(f'Error connecting to {self.remote!s}: {e!s}') + self.status = types.ForwardState.TUNNEL_ERROR + self.stop() + finally: + self.current_connections -= 1 @property def stoppable(self) -> bool: @@ -178,7 +167,7 @@ def stoppable(self) -> bool: return self.can_stop @staticmethod - def __checkStarted(fs: 'ForwardServer') -> None: + def _set_stoppable(fs: 'ForwardServer') -> None: # As soon as the timer is fired, the server can be stopped # This means that: # * If not connections are stablished, the server will be stopped @@ -186,15 +175,81 @@ def __checkStarted(fs: 'ForwardServer') -> None: logger.debug('New connection limit reached') fs.timer = None fs.can_stop = True + # If timer fired, and no connections are stablished, stop the server if fs.current_connections <= 0: fs.stop() + @staticmethod + def _test(ssl_socket: ssl.SSLSocket) -> bool: + try: + ssl_socket.sendall(consts.CMD_TEST) + resp = ssl_socket.recv(2) + if resp != consts.RESPONSE_OK: + raise Exception({'Invalid tunnelresponse: {resp}'}) + logger.debug('Tunnel is available!') + return True + except ssl.SSLError as e: + logger.error(f'Certificate error connecting to {ssl_socket.getsockname()}: {e!s}') + # will surpas the "check" method on script caller, arriving to the UDSClient error handler + raise Exception(f'Certificate error connecting to {ssl_socket.getsockname()}') from e + except Exception as e: + logger.error('Error connecting to tunnel server %s: %s', ssl_socket.getsockname(), e) + return False + + @staticmethod + def _open_tunnel(ssl_socket: ssl.SSLSocket, ticket: str) -> None: + # Send handhshake + command + ticket + ssl_socket.sendall(consts.CMD_OPEN + ticket.encode()) + # Check response is OK + data = ssl_socket.recv(2) + if data != consts.RESPONSE_OK: + data += ssl_socket.recv(128) + raise Exception(f'Error received: {data.decode(errors="ignore")}') # Notify error + + @staticmethod + def _connect( + remote_addr: typing.Tuple[str, int], + use_ipv6: bool = False, + check_certificate: bool = True, + ) -> ssl.SSLSocket: + with socket.socket(socket.AF_INET6 if use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) as rsocket: + logger.info('CONNECT to %s', remote_addr) + + rsocket.connect(remote_addr) + + rsocket.sendall(consts.HANDSHAKE_V1) # No response expected, just the handshake + + # Now, upgrade to ssl + context = ssl.create_default_context() + + # Do not "recompress" data, use only "base protocol" compression + context.options |= ssl.OP_NO_COMPRESSION + # Macs with default installed python, does not support mininum tls version set to TLSv1.3 + # USe "brew" version instead, or uncomment next line and comment the next one + # context.minimum_version = ssl.TLSVersion.TLSv1_2 if tools.isMac() else ssl.TLSVersion.TLSv1_3 + # Disallow old versions of TLS + # context.minimum_version = ssl.TLSVersion.TLSv1_2 + # Secure ciphers, use this is enabled tls 1.2 + # context.set_ciphers('ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384') + + context.minimum_version = ssl.TLSVersion.TLSv1_3 + + if tools.get_cacerts_file() is not None: + context.load_verify_locations(tools.get_cacerts_file()) # Load certifi certificates + + # If ignore remote certificate + if check_certificate is False: + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + logger.warning('Certificate checking is disabled!') + + return context.wrap_socket(rsocket, server_hostname=remote_addr[0]) + class Handler(socketserver.BaseRequestHandler): # Override Base type server: ForwardServer # pyright: ignore[reportIncompatibleVariableOverride] - # server: ForwardServer def handle(self) -> None: if self.server.status == types.ForwardState.TUNNEL_LISTENING: self.server.status = types.ForwardState.TUNNEL_OPENING # Only update state on first connection @@ -206,23 +261,18 @@ def handle(self) -> None: self.request.close() # End connection without processing it return - self.server.current_connections += 1 + # Open remote connection + self.establish_and_handle_tunnel() + + # If no more connections are stablished, and server is stoppable, do it now + if self.server.current_connections <= 0 and self.server.stoppable: + self.server.stop() + def establish_and_handle_tunnel(self) -> None: # Open remote connection try: - logger.debug('Ticket %s', self.server.ticket) - with self.server.connect() as ssl_socket: - # Send handhshake + command + ticket - ssl_socket.sendall(consts.CMD_OPEN + self.server.ticket.encode()) - # Check response is OK - data = ssl_socket.recv(2) - if data != consts.RESPONSE_OK: - data += ssl_socket.recv(128) - raise Exception(f'Error received: {data.decode(errors="ignore")}') # Notify error - - # All is fine, now we can tunnel data - - self.process(remote=ssl_socket) + with self.server.open_tunnel() as ssl_socket: + self.handle_tunnel(remote=ssl_socket) except ssl.SSLError as e: logger.error(f'Certificate error connecting to {self.server.remote!s}: {e!s}') self.server.status = types.ForwardState.TUNNEL_ERROR @@ -234,23 +284,21 @@ def handle(self) -> None: finally: self.server.current_connections -= 1 - if self.server.current_connections <= 0 and self.server.stoppable: - self.server.stop() - # Processes data forwarding - def process(self, remote: ssl.SSLSocket) -> None: + def handle_tunnel(self, remote: ssl.SSLSocket) -> None: self.server.status = types.ForwardState.TUNNEL_PROCESSING logger.debug('Processing tunnel with ticket %s', self.server.ticket) # Process data until stop requested or connection closed try: while not self.server.stop_flag.is_set(): + # Wait for data from either side r, _w, _x = select.select([self.request, remote], [], [], 1.0) - if self.request in r: + if self.request in r: # If request (local) has data, send to remote data = self.request.recv(consts.BUFFER_SIZE) if not data: break remote.sendall(data) - if remote in r: + if remote in r: # If remote has data, send to request (local) data = remote.recv(consts.BUFFER_SIZE) if not data: break @@ -261,13 +309,27 @@ def process(self, remote: ssl.SSLSocket) -> None: def _run(server: ForwardServer) -> None: - logger.debug( - 'Starting forwarder: %s -> %s', - server.server_address, - server.remote, - ) - server.serve_forever() - logger.debug('Stoped forwarder %s -> %s', server.server_address, server.remote) + """ + Runs the forwarder server. + This method is intended to be run in a separate thread. + + Args: + server (ForwardServer): The forward server instance. + + Returns: + None + """ + + def _runner() -> None: + logger.debug( + 'Starting forwarder: %s -> %s', + server.server_address, + server.remote, + ) + server.serve_forever() + logger.debug('Stopped forwarder %s -> %s', server.server_address, server.remote) + + threading.Thread(target=_runner).start() def forward( @@ -277,16 +339,34 @@ def forward( local_port: int = 0, check_certificate: bool = True, keep_listening: bool = True, + use_ipv6: bool = False, ) -> ForwardServer: + """ + Forward a connection to a remote server. + + Args: + remote (Tuple[str, int]): The address and port of the remote server. + ticket (str): The ticket used for authentication. + timeout (int, optional): When the server will stop listening for new connections (default is 0, which means never). + local_port (int, optional): The local port to bind to (default is 0, which means any available port). + check_certificate (bool, optional): Whether to check the server's SSL certificate (default is True). + keep_listening (bool, optional): Whether to keep listening for new connections (default is True). + + Returns: + ForwardServer: An instance of the ForwardServer class. + + """ fs = ForwardServer( remote=remote, ticket=ticket, timeout=timeout, local_port=local_port, check_certificate=check_certificate, + ipv6_remote=use_ipv6, keep_listening=keep_listening, ) - # Starts a new thread - threading.Thread(target=_run, args=(fs,)).start() + # Starts a new thread for processing the server, + # so the main thread can continue processing other tasks + _run(fs) return fs diff --git a/tests/test_main.py b/tests/test_main.py index 7a3e65f..fd11a5d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -57,9 +57,9 @@ def _check_url(url: str, minimal: typing.Optional[str] = None, with_minimal: boo host, ticket, scrambler, use_minimal = UDSClient.parse_arguments( ['udsclient'] + ([url] if not minimal else [minimal, url]) ) - self.assertEqual(host, 'a') - self.assertEqual(ticket, 'b') - self.assertEqual(scrambler, 'c') + self.assertEqual(host, 'host') + self.assertEqual(ticket, fixtures.TICKET) + self.assertEqual(scrambler, 'scrambler') self.assertEqual(use_minimal, with_minimal) # Invalid command line, should return simeple Exception @@ -76,20 +76,22 @@ def _check_url(url: str, minimal: typing.Optional[str] = None, with_minimal: boo # uds protocol, but withoout debug mode, should rais exception.UDSMessagException consts.DEBUG = False + UDS_URL = f'uds://host/{fixtures.TICKET}/scrambler' + UDSS_URL = f'udss://host/{fixtures.TICKET}/scrambler' with self.assertRaises(exceptions.MessageException): - _check_url('uds://a/b/c') + _check_url(UDS_URL) # Set DEBUG mode (on consts), now should work consts.DEBUG = True - _check_url('uds://a/b/c') + _check_url(UDS_URL) # Now, a valid URI ssl (udss://) for debug in [True, False]: consts.DEBUG = debug - _check_url('udss://a/b/c') - _check_url('udss://a/b/c', '--minimal', with_minimal=True) + _check_url(UDSS_URL) + _check_url(UDSS_URL, '--minimal', with_minimal=True) # No matter what is passed as value of minimal, if present, it will be used - _check_url('udss://a/b/c?minimal=11', with_minimal=True) + _check_url(f'{UDSS_URL}/b/c?minimal=11', with_minimal=True) def test_rest(self) -> None: # This is a simple test, we will test the rest api is mocked correctly diff --git a/tests/test_tools.py b/tests/test_tools.py index 69706ef..5582176 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -32,11 +32,14 @@ import os import socket import sys +import typing from unittest import TestCase, mock from uds import tools, types +from .utils import fixtures + logger = logging.getLogger(__name__) @@ -104,7 +107,10 @@ def test_unlink_files(self) -> None: unlink.reset_mock() tools.unlink_files(early_stage) self.assertEqual(unlink.call_count, 10) - self.assertEqual(unlink.call_args_list, [mock.call(str(x)) for x in range(10)]) + # Assert called, but ensure order is not important + self.assertEqual( + set(i[0][0] for i in unlink.call_args_list), set(str(x) for x in range(10)) + ) # Now, tools._unlink_files should be empty self.assertEqual(len(tools._unlink_files), 0) @@ -134,10 +140,83 @@ def test_wait_tasks_finish(self) -> None: # Now, tools._awaitable_tasks should be empty self.assertEqual(len(tools._awaitable_tasks), 0) - + # And every mock should have been called twice, with no arguments join_mock.join.assert_called_with() wait_mock.wait.assert_called_with() - + # also process_iter should have been called twice, once for each type of wait with wait_subprocesses True value self.assertEqual(_process_iter.call_count, 2) + + def test_register_execute_before_exit(self) -> None: + # Just call it, nothing to test here + tools.register_execute_before_exit(mock.sentinel.first) + tools.register_execute_before_exit(mock.sentinel.second) + + self.assertIn(mock.sentinel.first, tools._execute_before_exit) + self.assertIn(mock.sentinel.second, tools._execute_before_exit) + + def test_execute_before_exit(self) -> None: + # Just call it, nothing to test here + # Mock tools.process_iter + with mock.patch('uds.tools.process_iter') as _process_iter: + for m in (mock.Mock(), mock.Mock()): + tools._execute_before_exit.append(m) + + tools.execute_before_exit() + + # Now, tools._execute_before_exit should be empty + self.assertEqual(len(tools._execute_before_exit), 0) + + # And every mock should have been called once, with no arguments + for m in tools._execute_before_exit: + typing.cast('mock.Mock', m).assert_called_with() + + def test_verify_signature(self) -> None: + # Just call it, nothing to test here + # Mock tools.process_iter + self.assertTrue(tools.verify_signature(fixtures.SIGNED_STRING.encode(), fixtures.SIGNATURE.encode())) + # Padding chars on signature are ignored, so this should also return True + self.assertTrue( + tools.verify_signature(fixtures.SIGNED_STRING.encode(), fixtures.SIGNATURE.encode() + b'xxxx') + ) + # But if the string changes, it should return False + self.assertFalse( + tools.verify_signature(fixtures.SIGNED_STRING.encode() + b'x', fixtures.SIGNATURE.encode()) + ) + # And if the signature changes, it should return False + self.assertFalse( + tools.verify_signature( + fixtures.SIGNED_STRING.encode(), fixtures.SIGNATURE.encode().replace(b'x', b'y') + ) + ) + + def test_get_cacerts_file(self) -> None: + # Just call it, nothing to test here + path = tools.get_cacerts_file() + if path is None: + self.fail('No cacerts file found') + + self.assertTrue(os.path.exists(path)) + + def test_is_macos(self) -> None: + with mock.patch('sys.platform', 'darwin'): + self.assertTrue(tools.is_macos()) + with mock.patch('sys.platform', 'linux'): + self.assertFalse(tools.is_macos()) + with mock.patch('sys.platform', 'win'): + self.assertFalse(tools.is_macos()) + + def test_compat_functions(self) -> None: + # addTaskToWait = add_task_to_wait + # saveTempFile = save_temp_file + # readTempFile = read_temp_file + # testServer = test_server + # findApp = find_application + # addFileToUnlink = register_for_delayed_deletion + self.assertEqual(tools.addTaskToWait, tools.add_task_to_wait) + self.assertEqual(tools.saveTempFile, tools.save_temp_file) + self.assertEqual(tools.readTempFile, tools.read_temp_file) + self.assertEqual(tools.testServer, tools.test_server) + self.assertEqual(tools.findApp, tools.find_application) + self.assertEqual(tools.addFileToUnlink, tools.register_for_delayed_deletion) diff --git a/tests/test_tunnel.py b/tests/test_tunnel.py new file mode 100644 index 0000000..ef44ae3 --- /dev/null +++ b/tests/test_tunnel.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' +import contextlib +import logging +import socket +import os +import ssl +import tempfile +import time +import typing +from unittest import TestCase, mock + +from uds import tunnel + +from .utils import tunnel_server, certs, fixtures + +logger = logging.getLogger(__name__) + + +class TestTunnel(TestCase): + _server: typing.Optional['tunnel_server.TunnelServer'] = None + + def setUp(self) -> None: + super().setUp() + + def tearDown(self) -> None: + super().tearDown() + if self._server and self._server.is_alive(): + self._server.join() + + @property + def server(self) -> 'tunnel_server.TunnelServer': + if self._server is None: + self._server = tunnel_server.TunnelServer() + self._server.start() + self._server.listening.wait() + return self._server + + @contextlib.contextmanager + def connect( + self, check_certificate: bool = False, use_ipv6: bool = False + ) -> typing.Iterator[ssl.SSLSocket]: + yield tunnel.ForwardServer._connect( + ('localhost', self.server.port), use_ipv6=use_ipv6, check_certificate=check_certificate + ) + + @contextlib.contextmanager + def ensure_valid_cert(self) -> typing.Iterator[None]: + """Ensure that the certificate is valid by using a temporary file with the self-signed certificate. + (Note: all self signed certificates are also valid CA certificates, so we can use it as a CA certificate file) + """ + certfile = tempfile.NamedTemporaryFile('w', delete=False) + certfile.write(certs.CERT) + certfile.close() + # mock tools.get_cacerts_file to point to certfile + try: + with mock.patch('uds.tools.get_cacerts_file', return_value=certfile.name): + yield + finally: + if os.path.exists(certfile.name): + os.unlink(certfile.name) + + def test_test_verify_cert_fails(self) -> None: + # Should raise an exception if check_certificate is True, because certificate is self-signed + with self.assertRaises(ssl.CertificateError): + with self.connect(check_certificate=True): + pass # Just to make the test run + + def test_test_verify_cert(self) -> None: + # mock toolsget_cacerts_file to point to + with self.ensure_valid_cert(): + with self.connect(check_certificate=True) as conn: + self.assertTrue(tunnel.ForwardServer._test(conn)) + + self.assertFalse(self.server.error, self.server.error_msg) + + def test_test_no_verify_cert(self) -> None: + with self.connect(check_certificate=False) as conn: + self.assertTrue(tunnel.ForwardServer._test(conn)) + + self.assertFalse(self.server.error, self.server.error_msg) + + def test_open_tunnel(self) -> None: + with self.ensure_valid_cert(): + with self.connect() as conn: + tunnel.ForwardServer._open_tunnel(conn, fixtures.TICKET) + + self.assertFalse(self.server.error, self.server.error_msg) + + def test_forward_fnc(self) -> None: + """Check that forward function works as expected + * Creates a thread that invokes tunnel._run + """ + with mock.patch('uds.tunnel._run') as run: + with mock.patch('uds.tunnel.ForwardServer') as ForwardServer: + fs = tunnel.forward(('localhost', 1234), fixtures.TICKET, 1, 1222, check_certificate=False, use_ipv6=False) + # Ensure that thread is invoked with _run as target, and fs as argument + run.assert_called_once_with(fs) + # And that ForwardServer is called with the correct parameters + ForwardServer.assert_called_once_with( + remote=('localhost', 1234), + ticket=fixtures.TICKET, + timeout=1, + local_port=1222, + check_certificate=False, + ipv6_remote=False, + keep_listening=True, + ) + + # Ensure server is stopped, we have not used it.. send some data to make it fail an stop + + def test_forward_stoppable(self) -> None: + # Patch fs._set_stoppable to check if it is called + with mock.patch('uds.tunnel.ForwardServer._set_stoppable') as _set_stoppable: + fs = tunnel.forward( + ('localhost', self.server.port), fixtures.TICKET, 1, 1222, check_certificate=False + ) + + time.sleep(1.1) # more than forward timeout (1) + self.assertTrue(_set_stoppable.called) + + fs.stop() # Ensure fs is stopped + + def test_forward_connect(self) -> None: + # Must be listening on 1222, so we can connect to it to make the tunnel start + fs = tunnel.forward(('localhost', self.server.port), fixtures.TICKET, 1, 1222, check_certificate=False) + with contextlib.closing(socket.socket()) as s: + s.connect(('localhost', 1222)) + # Do net send anything, will not be read, just to make the tunnel start + s.send(b'') + + self.assertFalse(self.server.error, self.server.error_msg) + # Ensure fs is stopped + fs.stop() diff --git a/tests/utils/certs.py b/tests/utils/certs.py new file mode 100644 index 0000000..ca73a17 --- /dev/null +++ b/tests/utils/certs.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017-2024 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' + +import contextlib +import os +import tempfile +import typing + +import ssl + +# Self-signed certificate and key for testing purposes +# openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 36500 -nodes -subj "/C=ES/ST=Madrid/L=Madrid/O=UDS/OU=Devel/CN=localhost" +CERT: typing.Final[str] = ( + '-----BEGIN CERTIFICATE-----\n' + 'MIIFpTCCA42gAwIBAgIUTolFpGesjW2p6GCV5gTOXjkuKUIwDQYJKoZIhvcNAQEL\n' + 'BQAwYTELMAkGA1UEBhMCRVMxDzANBgNVBAgMBk1hZHJpZDEPMA0GA1UEBwwGTWFk\n' + 'cmlkMQwwCgYDVQQKDANVRFMxDjAMBgNVBAsMBURldmVsMRIwEAYDVQQDDAlsb2Nh\n' + 'bGhvc3QwIBcNMjQwMzA2MjIxMjQ3WhgPMjEyNDAyMTEyMjEyNDdaMGExCzAJBgNV\n' + 'BAYTAkVTMQ8wDQYDVQQIDAZNYWRyaWQxDzANBgNVBAcMBk1hZHJpZDEMMAoGA1UE\n' + 'CgwDVURTMQ4wDAYDVQQLDAVEZXZlbDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjAN\n' + 'BgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA1UKaOP2hetMIyCaB5dRDhPzDEwcD\n' + 'yvDnSykz2yEpERYMF8lSrhFrjSIPL6/fSY+mZI5uRwY+aIZkAcwZos0kF0PudXYQ\n' + 'LWmyFt4vxU4FechKsYlQGhn+quosT6WaJt0BXlrlO7T09r0qi/xzeeUSYUeFikrJ\n' + 'W+0F5byDgYs96OssC6yI/eKGrf7hEwG+zN04i0/+VuyUnndJH/dpuDOK49rcf9fK\n' + 'hwqChj1vkZSukRzCunkgyZ6nW+kUyXRgGx/2xaecDaC39ROAvsq8z1rJCc0caEMG\n' + 'B+kH2r1Ksql8bJDysO/K5NRPtJ5B2ByFnyoOAzOLXgBiTGpi3NVTIakENFBPkqmD\n' + 'DjPZYNcXH5LSoHZn4meR5J1+X8g3dGsWboww6nGy1ASs4ROgz/UZt+qPXDmR8i7L\n' + 'jmfNE0ca9tGyVrT5cECFzEDZzInV6eOEzsgO7iha5s7cpl2ED3h6kd+iLlXWd93T\n' + 'LnIZevVdfN6m5FvnAfnYqngNJCn4h3WOWpP+AoZ8BQTajfQtfHERD/HNtvt6bA22\n' + 'Yj12zOckXeFmhBY0e1OkatrP5vuwrDsGj/tIiGT5ElK07Bqkcqh8g0gjGsDSrC6j\n' + 'VgLmwO0BI99H+W4E/Uv4jlfnpVbzt+WpfBt1ejP/NlCTvQwA2nM3xh05JTYsgovt\n' + 'dHliaI/zrL/CxTUCAwEAAaNTMFEwHQYDVR0OBBYEFLmtufn4Jr3L/kRkNTYkaVVi\n' + 'NXU1MB8GA1UdIwQYMBaAFLmtufn4Jr3L/kRkNTYkaVViNXU1MA8GA1UdEwEB/wQF\n' + 'MAMBAf8wDQYJKoZIhvcNAQELBQADggIBACnQjdX+rqUS3ulzf36GFZ0zNwJcEZI8\n' + 'r8mumRMremwVxy+nRyyFHUc9ysuNdYgCHM0riIwujOTNL81YclJtZnBVKNcBPGnh\n' + 'WRom/94CgAMBq3DqGXXDMBulmd1QaqdOqriwIrtikRK3+yFz3GRRkF7aMGnvIer9\n' + 'DI7bQDWEj+ICcIwvFvJJIPMFARJBZVDuw/fsqeXXJt2Fl1ivGtMpf/3pASFV0WKm\n' + 'zEVst9D6WmQUIaW2oZEQHhzIq3tYbliRY0nF0YhQU5CCD1FAUZsaprBgnQhEvXAg\n' + '14snaKg2S90ESwupcPMH5r9vhCJh0d8aqQ+MbpjbvqFaLPhNsfj/WcuoGpjrxAKv\n' + 'kMPNtmhhQLUGorlw6ERkjMDQbbYz03WYpJFxOITRdzKWB5ZLXf7AiS0UDNu6D7uS\n' + 'BQff6mTv+VT+bRW5AUvLxiMMpB32LVILvpY8OlJhs63ccHKiskFuq0z7eEOAww+q\n' + 'qEPK/uyciMIR+sNTSiWi/pB3hsuv3cx33Pdtg2+KiNN0QNTienhanZ8R+WQKZsZD\n' + 'FcTLfPGFhs4edlfmG1ffbId6sxLGYVRbMJB0cfDZC8Sm1JkwTrtwFHucpOb3CQ7N\n' + 'r730XTasExkgmQ28z7u40ofEBCC59lfqZFp0CC4Ugs18vBg/L/7G5IhF+8M7huE5\n' + '6GgWRVlAbVZ3\n' + '-----END CERTIFICATE-----' +) + +KEY: typing.Final[str] = ( + '-----BEGIN PRIVATE KEY-----\n' + 'MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQDVQpo4/aF60wjI\n' + 'JoHl1EOE/MMTBwPK8OdLKTPbISkRFgwXyVKuEWuNIg8vr99Jj6Zkjm5HBj5ohmQB\n' + 'zBmizSQXQ+51dhAtabIW3i/FTgV5yEqxiVAaGf6q6ixPpZom3QFeWuU7tPT2vSqL\n' + '/HN55RJhR4WKSslb7QXlvIOBiz3o6ywLrIj94oat/uETAb7M3TiLT/5W7JSed0kf\n' + '92m4M4rj2tx/18qHCoKGPW+RlK6RHMK6eSDJnqdb6RTJdGAbH/bFp5wNoLf1E4C+\n' + 'yrzPWskJzRxoQwYH6QfavUqyqXxskPKw78rk1E+0nkHYHIWfKg4DM4teAGJMamLc\n' + '1VMhqQQ0UE+SqYMOM9lg1xcfktKgdmfiZ5HknX5fyDd0axZujDDqcbLUBKzhE6DP\n' + '9Rm36o9cOZHyLsuOZ80TRxr20bJWtPlwQIXMQNnMidXp44TOyA7uKFrmztymXYQP\n' + 'eHqR36IuVdZ33dMuchl69V183qbkW+cB+diqeA0kKfiHdY5ak/4ChnwFBNqN9C18\n' + 'cREP8c22+3psDbZiPXbM5yRd4WaEFjR7U6Rq2s/m+7CsOwaP+0iIZPkSUrTsGqRy\n' + 'qHyDSCMawNKsLqNWAubA7QEj30f5bgT9S/iOV+elVvO35al8G3V6M/82UJO9DADa\n' + 'czfGHTklNiyCi+10eWJoj/Osv8LFNQIDAQABAoICAAX95Q9iMOieh/SiVav7bBo5\n' + 'xSatCnI9P9e2GfriJ6E57q+3Eaz1AwHoHxJxQo4dNx5EJ4e0qTQ5R74KhKiKUvqZ\n' + 'sgLNhQ7W6rtcasMvD1WonFCjUa4qD3mwh/uE5PE1QcCWP905lwMHtZZRStKgmQTw\n' + 'BCnVMsBzxw1OtUiCfQQk92EslnzA9z++6tv52ekvnf3BYEgDmvleKJ55Tm3FJPXZ\n' + '7wVjLrw0k2OU06RSJRrL+rylLUKnmSl/7FGXWiE+Tf9SV5QacTtgMjx/aGaltQL4\n' + 'JwAsQeh0UlWBqVjzt3HlcKwqpey1T7f8v6TZcvenMCru28+RpdwNG9H7PGa0X6XQ\n' + '0EqOeNqdSH2KGG46itvExTVghg5h8gQaDhMVlAwNUvNeuu7cvrcMDMttFIa/2b8m\n' + 'pUA1f0o4kBMMJMm4VgJI2MkB9eMoAOvpzz9G/WTE1teTnqlaqoAaVDu/Uz2VCYbS\n' + 'PEr4LQZe2Pm0dIkYeOMAZ8lhEZTq0k539oJGGMuSsGRKDL5sv/jil/8U1bLg6PXU\n' + '1gVFIWDG+O5SKP7aURO+C0qX7/WMd+OYREpPWsvqF5EiHNgfhetxiJ3riRgUEPRo\n' + 'Y/POaDUffEA6uIen4SaqNltql5pnb4dFLpn3frQoGQkY6inYf1vtBTRSnKUEmCvl\n' + 'eVr1+WyCuOLKSRtsHDIBAoIBAQD2vMn7vw/SJZfJxcdFHb85E8eOu7XdvzzRU6bq\n' + 'f+2fRn/uOuY8vIXvI5pHBubjftv3cppDw8HPw9718ks9hyPn4uSftyHeu4nWaap/\n' + 'HpBMh04awGU5DjqncmvEUK7qYpFP9Jdk6zG+ou+uFsXFaLaQQ0WSkBlzrRXFZj/K\n' + 'mbDJq4JnuA7nXjc3g/nocpWRCQCSlWTqmLTNtIJbK5f41CNJCqLhaKcrNOLODOfK\n' + '22PHslC6OXsWH+HSFoHjKjG9Erb4Hwa5Zav8/rZlwf1SL1QcRg4RWSdjh5xlso4f\n' + '6Fo2SUEBgXiP1zD7YpKsFMCnrZmhqfg8GgQjatLUVbckrIZhAoIBAQDdRBaFnyOH\n' + 'pJAnXee9VdU97yYZbNuEL757f+npoTbFYnD0ClZgSBD6zV0W0NjyO9TrT+KQ0tJc\n' + 'ZcJ3vUP8A+fiWZP1fbANZFfTDY5ruJVMIN7sS8XNn7FLzbCOTYmR0776sp/luIxh\n' + '/WsyBIHBqBmYdbNi+p5+Rey10HCNDgM2NtYqHW4261XS+D7xPw0EGrcWPShv3OcJ\n' + '7+YEAMnlnZXBP6F/aSjXE1CHvzcMSL+igwve7xsrzMu0D/Z/BFSIgrLvw+6JoUZV\n' + 'y+dzTmQ5Mg8SBVXLKhukNfMnrej4o4Vr9V9MXFxu0c5gK+do1TdWbCYjQifJYIhp\n' + '+oalwd6ivYdVAoIBAQCHaCf47mvCSjs40j9/oMmWi1JS9JTkMtUvk5bgzoAbjtca\n' + 'aFx+LH/cM0+xdwozAyW4cL5UPhQY70dm9idwhr+fvJb3R8tgrs8ASlD1HlLWjNLC\n' + 'P5/NZg+uYU7fF+BGZP2WQYbsLV7JXiXnBjxXEBZQqXp+6nHtV6nBAVI0349zvZn9\n' + 'TbdwJfZrkxQNCwUl6SjVSQNu84sV8OAxJIVsWw9aQGoPBh3nykhGCDMU0r25lBRV\n' + 'fsIb7DdD0nJJtphBSQn8tRo9mJyAZVC4G3PoLG0ebxu9TY4eQwgDj7ALtrn7XMw+\n' + 'BU2istgAvaH8qg7odo7/d4Xxhd2Lik5VlQzDJaNBAoIBAQCO8QejtxUa8eL2q6Gk\n' + 'HSkvY6m3Ty3ZDYb+/bm9ZpqdlWTnIy598NCXVchHjxA4HRMGGYuCh8/CRTMGa8zZ\n' + 'qCRLhBcjxtjPLf3WqLFTQeGhVrLs8F6O4hWFpRHkPI8dGDAOgQrvOvPl8fMoUuUI\n' + 'mHJAnfkPflyZss6i/k9XsK++fFqKxoyHCi1dp2XyMAtWlXOl+EiBS7IuJz7vYxsL\n' + 'LWyrdVH9n4/0sdOafpsvYmf6srIeiVWCTEFkx9M0ZzW9IsI6Rtd5LijkEGAri38P\n' + 'vBkkSTINl9xXj0rQXXdd+TWectvn1tsX9I5gbryGawfe2usgaAKQA77cyC3oM4CC\n' + 'nfIpAoIBABUSCmCwyrlz6sS6L4W+fo1c74IB+0p8K78Dc0IYHOYnZ24e4P7LtzJq\n' + '17dd/mEcrlm9kV98TzRU90QcWC38K9WMq/EeIML9lqEycuC+n2p6eUtUh7tp5OL6\n' + '5DyRtQARirFz9686jY5BeCA2motJPLDQOazeMJaR+i7CfQ70npgmPzSKCI7ut29j\n' + 'Nfm78ZkpdUwDBYYCYTMwf7LrVvOnmFqmiNe5i4j0LjnBg4I79s8nIJRV/kIfc1sW\n' + 'KSvR66bSLF/i67HiZWc2dcJskh34WFdbYQ1COSMu+GL6SzN7J0fyEV8TDZ9V0NUP\n' + 'VH6zbSCVcHmtc2P8xA9VxERkeeauY8w=\n' + '-----END PRIVATE KEY-----' +) + + +@contextlib.contextmanager +def server_ssl_context() -> typing.Generator[ssl.SSLContext, None, None]: + certfile = tempfile.mktemp() + '.pem' + keyfile = tempfile.mktemp() + '.pem' + try: + with open(certfile, 'w') as f: + f.write(CERT) + with open(keyfile, 'w') as f: + f.write(KEY) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ctx.load_cert_chain(certfile=certfile, keyfile=keyfile) + yield ctx + finally: + # Remove the temporary files if they were created + for fname in (certfile, keyfile): + if os.path.exists(fname): + os.unlink(fname) diff --git a/tests/utils/fixtures.py b/tests/utils/fixtures.py index fb69e6e..a596c92 100644 --- a/tests/utils/fixtures.py +++ b/tests/utils/fixtures.py @@ -50,6 +50,27 @@ # TODO: add parameters here } +SIGNED_STRING: str = 'This is the string to test signature' +SIGNATURE: str = ( + 'rLL5ykPFWSk2L3RVlCOlCLfngEwMg6cWIZ35DrxRBjt77xvf' + 'OZYcFuDqma5zBJ43RgI4XUvLdMCeU2KZh/gHoQjIiXYBdg76' + 'B0E6h3JvE6Nl71+q9KrTlk/lp8JoroZcLCPwFzMzb/1rxCO+' + 'dplDbVyw5J08fK31oEKTIw0O1JlgkxK+zFhQalfwZvr4n0mS' + 'g9awpPAAAOYn7p9/i9As+QDden62kvU/G4iZ4w6/1YU9LmAW' + 'urxMhrIGejmaPnPzmHtovBzUxFVVr6eK9AdleDqHoxGSqqga' + 'qmiMfKXktdkSKnBfizqpCt2gshzq6QH0iwwHDlfQ+PNe1Fta' + 'xgNAFfBZjlF4Masnpn8vhvLhmZpa9oQVRkFIPvRUAKn9H7pf' + 'gUzWYlHTRpqVthZyo72B2R3bKVpdk0RF8UiQgzM8BYfMuc51' + 'HZSyy4u6P2tPqAI7IT30v2Y/s15Xa+uOKs6lP7yUlME4isDG' + 'OIkiQ3usOYK+kjdissWbjFPLQZ2sLYISe67zBAHvskHSfc5s' + '5jfKkB6+AbqzkIxHc7ZAcwBqnSOh4gMVwUfBKzLKhCTJxeEs' + 'mp85Q3z2ONmoHsDqU6KGXgW2hshA6zYB5c3hukz+zVbfaYhE' + '+EaxYpW7de4XU0EkEuNdbDC+7F3CqJ/aSoCKNVmLUW9WphMq' + 'xZORjLTCR/c=' +) + +TICKET: str = 'x' * consts.TICKET_LENGTH + def check_version() -> str: if REQUIRED_VERSION == 'fail': @@ -105,7 +126,7 @@ def patched_uds_client() -> typing.Generator['UDSClient.UDSClient', None, None]: with patch_rest_api() as client: uds_client = UDSClient.UDSClient(client, 'ticket', 'scrambler') # Now, patch object: - # - process_waiting_tasks so we do not launch any task + # - process_waiting_tasks so we do not wait for processing tasks # - error_message so we do not show any error message # - warning_message so we do not show any warning message # error_message and warning_message are static methods, so we need to patch them on the class diff --git a/tests/utils/tunnel_server.py b/tests/utils/tunnel_server.py new file mode 100644 index 0000000..cbae498 --- /dev/null +++ b/tests/utils/tunnel_server.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017-2024 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' +import socket +import threading +import typing + +from uds import consts + +from . import certs + +MAX_EXEC_TIME: typing.Final[int] = 2 # Max execution time for the connections + + +class TunnelServer(threading.Thread): + listening: threading.Event + port: int + error: bool + error_msg: typing.Optional[str] + ticket: typing.Optional[bytes] + wait_time: int + + def __init__(self, wait_time: int = MAX_EXEC_TIME) -> None: + super().__init__() + self.wait_time = wait_time + self.listening = threading.Event() + self.port = 0 + self.error = False + self.error_msg = None + self.ticket = None + + def listen(self, server: socket.socket) -> socket.socket: + server.settimeout(self.wait_time) # So the task never gets stuck, this is for testing purposes only + server.bind(('localhost', 0)) + self.port = server.getsockname()[1] + server.listen(1) + self.listening.set() + conn, _addr = server.accept() + conn.settimeout(self.wait_time) # So the task never gets stuck, this is for testing purposes only + return conn + + def read_header(self, conn: socket.socket) -> None: + header = conn.recv(len(consts.HANDSHAKE_V1), socket.MSG_WAITALL) + if header != consts.HANDSHAKE_V1: + raise Exception(f'Invalid header: {header}') + + def process_command(self, conn: socket.socket) -> None: + with certs.server_ssl_context() as ssl_context: + # Upgrade connection to SSL + conn = ssl_context.wrap_socket(conn, server_side=True) + + # Read command, 4 bytes (consts.CMD_OPEN or consts.CMD_TEST) + command = conn.recv(4) # conn is now ssl socket, does not allows non-zero flags + if command == consts.CMD_OPEN: + # Read the ticket + self.ticket = conn.recv(consts.TICKET_LENGTH) + conn.send(consts.RESPONSE_OK) + elif command == consts.CMD_TEST: + # Just return OK + conn.send(consts.RESPONSE_OK) + else: + self.error = True + self.error_msg = f'Invalid command: {command}' + + def run(self) -> None: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server: + conn = self.listen(server) + self.read_header(conn) + self.process_command(conn) + except Exception as e: + self.error = True + self.error_msg = f'Exception: {e}' + + def wait_for_listener(self) -> None: + self.listening.wait() + if self.error: + raise Exception(f'Error starting server: {self.error_msg}')