diff --git a/src/socketio/async_simple_client.py b/src/socketio/async_simple_client.py index c6cd4fc1..adac6ead 100644 --- a/src/socketio/async_simple_client.py +++ b/src/socketio/async_simple_client.py @@ -12,6 +12,8 @@ class AsyncSimpleClient: The positional and keyword arguments given in the constructor are passed to the underlying :func:`socketio.AsyncClient` object. """ + client_class = AsyncClient + def __init__(self, *args, **kwargs): self.client_args = args self.client_kwargs = kwargs @@ -60,7 +62,8 @@ async def connect(self, url, headers={}, auth=None, transports=None, self.namespace = namespace self.input_buffer = [] self.input_event.clear() - self.client = AsyncClient(*self.client_args, **self.client_kwargs) + self.client = self.client_class( + *self.client_args, **self.client_kwargs) @self.client.event(namespace=self.namespace) def connect(): # pragma: no cover diff --git a/src/socketio/simple_client.py b/src/socketio/simple_client.py index 67791477..3f046b4b 100644 --- a/src/socketio/simple_client.py +++ b/src/socketio/simple_client.py @@ -12,6 +12,8 @@ class SimpleClient: The positional and keyword arguments given in the constructor are passed to the underlying :func:`socketio.Client` object. """ + client_class = Client + def __init__(self, *args, **kwargs): self.client_args = args self.client_kwargs = kwargs @@ -58,7 +60,8 @@ def connect(self, url, headers={}, auth=None, transports=None, self.namespace = namespace self.input_buffer = [] self.input_event.clear() - self.client = Client(*self.client_args, **self.client_kwargs) + self.client = self.client_class( + *self.client_args, **self.client_kwargs) @self.client.event(namespace=self.namespace) def connect(): # pragma: no cover diff --git a/tests/async/test_simple_client.py b/tests/async/test_simple_client.py index 08926922..bfe2a90f 100644 --- a/tests/async/test_simple_client.py +++ b/tests/async/test_simple_client.py @@ -16,46 +16,51 @@ async def test_constructor(self): assert not client.connected async def test_connect(self): + mock_client = mock.MagicMock() + original_client_class = AsyncSimpleClient.client_class + AsyncSimpleClient.client_class = mock_client + client = AsyncSimpleClient(123, a='b') - with mock.patch('socketio.async_simple_client.AsyncClient') \ - as mock_client: + mock_client.return_value.connect = mock.AsyncMock() + + await client.connect('url', headers='h', auth='a', transports='t', + namespace='n', socketio_path='s', + wait_timeout='w') + mock_client.assert_called_once_with(123, a='b') + assert client.client == mock_client() + mock_client().connect.assert_awaited_once_with( + 'url', headers='h', auth='a', transports='t', + namespaces=['n'], socketio_path='s', wait_timeout='w') + mock_client().event.call_count == 3 + mock_client().on.assert_called_once_with('*', namespace='n') + assert client.namespace == 'n' + assert not client.input_event.is_set() + + AsyncSimpleClient.client_class = original_client_class + + async def test_connect_context_manager(self): + mock_client = mock.MagicMock() + original_client_class = AsyncSimpleClient.client_class + AsyncSimpleClient.client_class = mock_client + + async with AsyncSimpleClient(123, a='b') as client: mock_client.return_value.connect = mock.AsyncMock() - await client.connect('url', headers='h', auth='a', transports='t', - namespace='n', socketio_path='s', - wait_timeout='w') + await client.connect('url', headers='h', auth='a', + transports='t', namespace='n', + socketio_path='s', wait_timeout='w') mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.assert_awaited_once_with( 'url', headers='h', auth='a', transports='t', namespaces=['n'], socketio_path='s', wait_timeout='w') mock_client().event.call_count == 3 - mock_client().on.assert_called_once_with('*', namespace='n') + mock_client().on.assert_called_once_with( + '*', namespace='n') assert client.namespace == 'n' assert not client.input_event.is_set() - async def test_connect_context_manager(self): - async def _t(): - async with AsyncSimpleClient(123, a='b') as client: - with mock.patch('socketio.async_simple_client.AsyncClient') \ - as mock_client: - mock_client.return_value.connect = mock.AsyncMock() - - await client.connect('url', headers='h', auth='a', - transports='t', namespace='n', - socketio_path='s', wait_timeout='w') - mock_client.assert_called_once_with(123, a='b') - assert client.client == mock_client() - mock_client().connect.assert_awaited_once_with( - 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s', wait_timeout='w') - mock_client().event.call_count == 3 - mock_client().on.assert_called_once_with( - '*', namespace='n') - assert client.namespace == 'n' - assert not client.input_event.is_set() - - await _t() + AsyncSimpleClient.client_class = original_client_class async def test_connect_twice(self): client = AsyncSimpleClient(123, a='b') diff --git a/tests/common/test_simple_client.py b/tests/common/test_simple_client.py index 42790573..b17afbcc 100644 --- a/tests/common/test_simple_client.py +++ b/tests/common/test_simple_client.py @@ -14,10 +14,34 @@ def test_constructor(self): assert not client.connected def test_connect(self): + mock_client = mock.MagicMock() + original_client_class = SimpleClient.client_class + SimpleClient.client_class = mock_client + client = SimpleClient(123, a='b') - with mock.patch('socketio.simple_client.Client') as mock_client: + client.connect('url', headers='h', auth='a', transports='t', + namespace='n', socketio_path='s', wait_timeout='w') + mock_client.assert_called_once_with(123, a='b') + assert client.client == mock_client() + mock_client().connect.assert_called_once_with( + 'url', headers='h', auth='a', transports='t', + namespaces=['n'], socketio_path='s', wait_timeout='w') + mock_client().event.call_count == 3 + mock_client().on.assert_called_once_with('*', namespace='n') + assert client.namespace == 'n' + assert not client.input_event.is_set() + + SimpleClient.client_class = original_client_class + + def test_connect_context_manager(self): + mock_client = mock.MagicMock() + original_client_class = SimpleClient.client_class + SimpleClient.client_class = mock_client + + with SimpleClient(123, a='b') as client: client.connect('url', headers='h', auth='a', transports='t', - namespace='n', socketio_path='s', wait_timeout='w') + namespace='n', socketio_path='s', + wait_timeout='w') mock_client.assert_called_once_with(123, a='b') assert client.client == mock_client() mock_client().connect.assert_called_once_with( @@ -28,21 +52,7 @@ def test_connect(self): assert client.namespace == 'n' assert not client.input_event.is_set() - def test_connect_context_manager(self): - with SimpleClient(123, a='b') as client: - with mock.patch('socketio.simple_client.Client') as mock_client: - client.connect('url', headers='h', auth='a', transports='t', - namespace='n', socketio_path='s', - wait_timeout='w') - mock_client.assert_called_once_with(123, a='b') - assert client.client == mock_client() - mock_client().connect.assert_called_once_with( - 'url', headers='h', auth='a', transports='t', - namespaces=['n'], socketio_path='s', wait_timeout='w') - mock_client().event.call_count == 3 - mock_client().on.assert_called_once_with('*', namespace='n') - assert client.namespace == 'n' - assert not client.input_event.is_set() + SimpleClient.client_class = original_client_class def test_connect_twice(self): client = SimpleClient(123, a='b')