Skip to content

Commit

Permalink
Allow custom client subclasses to be used in SimpleClient and AsyncSi…
Browse files Browse the repository at this point in the history
…mpleClient (Fixes #1432)
  • Loading branch information
miguelgrinberg committed Feb 5, 2025
1 parent a598a55 commit 7605630
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 47 deletions.
5 changes: 4 additions & 1 deletion src/socketio/async_simple_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class AsyncSimpleClient:
The positional and keyword arguments given in the constructor are passed
to the underlying :func:`socketio.AsyncClient` object.
"""
client_class = AsyncClient

def __init__(self, *args, **kwargs):
self.client_args = args
self.client_kwargs = kwargs
Expand Down Expand Up @@ -60,7 +62,8 @@ async def connect(self, url, headers={}, auth=None, transports=None,
self.namespace = namespace
self.input_buffer = []
self.input_event.clear()
self.client = AsyncClient(*self.client_args, **self.client_kwargs)
self.client = self.client_class(
*self.client_args, **self.client_kwargs)

@self.client.event(namespace=self.namespace)
def connect(): # pragma: no cover
Expand Down
5 changes: 4 additions & 1 deletion src/socketio/simple_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class SimpleClient:
The positional and keyword arguments given in the constructor are passed
to the underlying :func:`socketio.Client` object.
"""
client_class = Client

def __init__(self, *args, **kwargs):
self.client_args = args
self.client_kwargs = kwargs
Expand Down Expand Up @@ -58,7 +60,8 @@ def connect(self, url, headers={}, auth=None, transports=None,
self.namespace = namespace
self.input_buffer = []
self.input_event.clear()
self.client = Client(*self.client_args, **self.client_kwargs)
self.client = self.client_class(
*self.client_args, **self.client_kwargs)

@self.client.event(namespace=self.namespace)
def connect(): # pragma: no cover
Expand Down
61 changes: 33 additions & 28 deletions tests/async/test_simple_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,51 @@ async def test_constructor(self):
assert not client.connected

async def test_connect(self):
mock_client = mock.MagicMock()
original_client_class = AsyncSimpleClient.client_class
AsyncSimpleClient.client_class = mock_client

client = AsyncSimpleClient(123, a='b')
with mock.patch('socketio.async_simple_client.AsyncClient') \
as mock_client:
mock_client.return_value.connect = mock.AsyncMock()

await client.connect('url', headers='h', auth='a', transports='t',
namespace='n', socketio_path='s',
wait_timeout='w')
mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client()
mock_client().connect.assert_awaited_once_with(
'url', headers='h', auth='a', transports='t',
namespaces=['n'], socketio_path='s', wait_timeout='w')
mock_client().event.call_count == 3
mock_client().on.assert_called_once_with('*', namespace='n')
assert client.namespace == 'n'
assert not client.input_event.is_set()

AsyncSimpleClient.client_class = original_client_class

async def test_connect_context_manager(self):
mock_client = mock.MagicMock()
original_client_class = AsyncSimpleClient.client_class
AsyncSimpleClient.client_class = mock_client

async with AsyncSimpleClient(123, a='b') as client:
mock_client.return_value.connect = mock.AsyncMock()

await client.connect('url', headers='h', auth='a', transports='t',
namespace='n', socketio_path='s',
wait_timeout='w')
await client.connect('url', headers='h', auth='a',
transports='t', namespace='n',
socketio_path='s', wait_timeout='w')
mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client()
mock_client().connect.assert_awaited_once_with(
'url', headers='h', auth='a', transports='t',
namespaces=['n'], socketio_path='s', wait_timeout='w')
mock_client().event.call_count == 3
mock_client().on.assert_called_once_with('*', namespace='n')
mock_client().on.assert_called_once_with(
'*', namespace='n')
assert client.namespace == 'n'
assert not client.input_event.is_set()

async def test_connect_context_manager(self):
async def _t():
async with AsyncSimpleClient(123, a='b') as client:
with mock.patch('socketio.async_simple_client.AsyncClient') \
as mock_client:
mock_client.return_value.connect = mock.AsyncMock()

await client.connect('url', headers='h', auth='a',
transports='t', namespace='n',
socketio_path='s', wait_timeout='w')
mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client()
mock_client().connect.assert_awaited_once_with(
'url', headers='h', auth='a', transports='t',
namespaces=['n'], socketio_path='s', wait_timeout='w')
mock_client().event.call_count == 3
mock_client().on.assert_called_once_with(
'*', namespace='n')
assert client.namespace == 'n'
assert not client.input_event.is_set()

await _t()
AsyncSimpleClient.client_class = original_client_class

async def test_connect_twice(self):
client = AsyncSimpleClient(123, a='b')
Expand Down
44 changes: 27 additions & 17 deletions tests/common/test_simple_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,34 @@ def test_constructor(self):
assert not client.connected

def test_connect(self):
mock_client = mock.MagicMock()
original_client_class = SimpleClient.client_class
SimpleClient.client_class = mock_client

client = SimpleClient(123, a='b')
with mock.patch('socketio.simple_client.Client') as mock_client:
client.connect('url', headers='h', auth='a', transports='t',
namespace='n', socketio_path='s', wait_timeout='w')
mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client()
mock_client().connect.assert_called_once_with(
'url', headers='h', auth='a', transports='t',
namespaces=['n'], socketio_path='s', wait_timeout='w')
mock_client().event.call_count == 3
mock_client().on.assert_called_once_with('*', namespace='n')
assert client.namespace == 'n'
assert not client.input_event.is_set()

SimpleClient.client_class = original_client_class

def test_connect_context_manager(self):
mock_client = mock.MagicMock()
original_client_class = SimpleClient.client_class
SimpleClient.client_class = mock_client

with SimpleClient(123, a='b') as client:
client.connect('url', headers='h', auth='a', transports='t',
namespace='n', socketio_path='s', wait_timeout='w')
namespace='n', socketio_path='s',
wait_timeout='w')
mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client()
mock_client().connect.assert_called_once_with(
Expand All @@ -28,21 +52,7 @@ def test_connect(self):
assert client.namespace == 'n'
assert not client.input_event.is_set()

def test_connect_context_manager(self):
with SimpleClient(123, a='b') as client:
with mock.patch('socketio.simple_client.Client') as mock_client:
client.connect('url', headers='h', auth='a', transports='t',
namespace='n', socketio_path='s',
wait_timeout='w')
mock_client.assert_called_once_with(123, a='b')
assert client.client == mock_client()
mock_client().connect.assert_called_once_with(
'url', headers='h', auth='a', transports='t',
namespaces=['n'], socketio_path='s', wait_timeout='w')
mock_client().event.call_count == 3
mock_client().on.assert_called_once_with('*', namespace='n')
assert client.namespace == 'n'
assert not client.input_event.is_set()
SimpleClient.client_class = original_client_class

def test_connect_twice(self):
client = SimpleClient(123, a='b')
Expand Down

0 comments on commit 7605630

Please sign in to comment.