diff --git a/tests/tests_net_messages/test_benchmark.py b/tests/tests_net_messages/test_benchmark.py index 9387970e..8ca36ffc 100644 --- a/tests/tests_net_messages/test_benchmark.py +++ b/tests/tests_net_messages/test_benchmark.py @@ -1,8 +1,10 @@ +import os import asyncio import random import string import time import threading +import multiprocessing from unittest import IsolatedAsyncioTestCase, skip from lifeblood.logging import get_logger, set_default_loglevel from lifeblood.nethelpers import get_localhost @@ -32,20 +34,40 @@ async def process_message(self, message: Message, client: MessageClient): self.test_messages_count += 1 -class ThreadedFoo(threading.Thread): - def __init__(self, server: TcpMessageProcessor): +class FooRunner: + def start(self): + raise NotImplementedError() + + def stop(self): + raise NotImplementedError() + + def join(self): + raise NotImplementedError() + + def get_message_count(self) -> int: + raise NotImplementedError() + + +class ThreadedFoo(threading.Thread, FooRunner): + def __init__(self, server: NoopMessageServer): super().__init__() - self.__stop = False + self.__stop = threading.Event() + self.__ready = threading.Event() self.__server = server def run(self): asyncio.run(self.async_run()) + def start(self): + super().start() + self.__ready.wait() + async def async_run(self): await self.__server.start() + self.__ready.set() while True: await asyncio.sleep(1) - if self.__stop: + if self.__stop.is_set(): break self.__server.stop() @@ -53,15 +75,72 @@ async def async_run(self): def stop(self): # crude crude crude - self.__stop = True + self.__stop.set() + + def get_message_count(self) -> int: + return self.__server.test_messages_count + + +class ProcessedFoo(FooRunner): + def __init__(self, server: NoopMessageServer): + super().__init__() + self.__server = server + ctx = multiprocessing.get_context('spawn') + self.__stop = ctx.Event() + self.__value = ctx.Value('i', -1) + self.__ready = ctx.Event() + self.__proc = ctx.Process(target=self.body) + + def start(self): + self.__proc.start() + self.__ready.wait() + + def body(self): + asyncio.run(self.async_run()) + + async def async_run(self): + print('another process started') + await self.__server.start() + self.__ready.set() + print('another process server started') + while True: + await asyncio.sleep(1) + if self.__stop.is_set(): + break + + self.__server.stop() + await self.__server.wait_till_stops() + self.__value.value = self.__server.test_messages_count + + def stop(self): + self.__stop.set() + + def join(self): + self.__proc.join() + + def get_message_count(self) -> int: + return self.__value.value class TestBenchmarkSendReceive(IsolatedAsyncioTestCase): - async def test1(self): + @skip("no reason to benchmark on slow machines") + async def test_threaded(self): + await self.helper_test(ThreadedFoo) + + @skip("no reason to benchmark on slow machines") + async def test_proc(self): + await self.helper_test(ProcessedFoo) + + async def helper_test(self, foo_factory: Callable[[NoopMessageServer], FooRunner]): + """ + runs 2 servers + starts X clients asyncio coroutines on one of the servers, each sends Y messages. + average per-message time is then calculated + """ data = ''.join(random.choice(string.ascii_letters) for _ in range(16000)).encode('latin1') server1 = NoopMessageServer((get_localhost(), 28385)) server2 = NoopMessageServer((get_localhost(), 28386)) - server1_runner = ThreadedFoo(server1) + server1_runner = foo_factory(server1) server1_runner.start() await server2.start() pure_send_time = 0.0 @@ -90,5 +169,6 @@ async def test_foo(): server1_runner.stop() await server2.wait_till_stops() server1_runner.join() - print(f'total go {server1.test_messages_count} in {total_time}s (pure send: {pure_send_time}s, avg {server1.test_messages_count/total_time} (pure: {server1.test_messages_count/pure_send_time}) msg/s') - self.assertEqual(total_clients * messages_per_client, server1.test_messages_count) + s1_message_count = server1_runner.get_message_count() + print(f'threaded total go {s1_message_count} in {total_time}s (pure send: {pure_send_time}s, avg {s1_message_count/total_time} (pure: {s1_message_count/pure_send_time}) msg/s') + self.assertEqual(total_clients * messages_per_client, s1_message_count)