Skip to content

Commit

Permalink
add multithreaded and multiprocessed banchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
pedohorse committed Sep 3, 2024
1 parent 40ecc2b commit edf8676
Showing 1 changed file with 89 additions and 9 deletions.
98 changes: 89 additions & 9 deletions tests/tests_net_messages/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import asyncio
import random
import string
import time
import threading
import multiprocessing
from unittest import IsolatedAsyncioTestCase, skip
from lifeblood.logging import get_logger, set_default_loglevel
from lifeblood.nethelpers import get_localhost
Expand Down Expand Up @@ -32,36 +34,113 @@ async def process_message(self, message: Message, client: MessageClient):
self.test_messages_count += 1


class ThreadedFoo(threading.Thread):
def __init__(self, server: TcpMessageProcessor):
class FooRunner:
def start(self):
raise NotImplementedError()

def stop(self):
raise NotImplementedError()

def join(self):
raise NotImplementedError()

def get_message_count(self) -> int:
raise NotImplementedError()


class ThreadedFoo(threading.Thread, FooRunner):
def __init__(self, server: NoopMessageServer):
super().__init__()
self.__stop = False
self.__stop = threading.Event()
self.__ready = threading.Event()
self.__server = server

def run(self):
asyncio.run(self.async_run())

def start(self):
super().start()
self.__ready.wait()

async def async_run(self):
await self.__server.start()
self.__ready.set()
while True:
await asyncio.sleep(1)
if self.__stop:
if self.__stop.is_set():
break

self.__server.stop()
await self.__server.wait_till_stops()

def stop(self):
# crude crude crude
self.__stop = True
self.__stop.set()

def get_message_count(self) -> int:
return self.__server.test_messages_count


class ProcessedFoo(FooRunner):
def __init__(self, server: NoopMessageServer):
super().__init__()
self.__server = server
ctx = multiprocessing.get_context('spawn')
self.__stop = ctx.Event()
self.__value = ctx.Value('i', -1)
self.__ready = ctx.Event()
self.__proc = ctx.Process(target=self.body)

def start(self):
self.__proc.start()
self.__ready.wait()

def body(self):
asyncio.run(self.async_run())

async def async_run(self):
print('another process started')
await self.__server.start()
self.__ready.set()
print('another process server started')
while True:
await asyncio.sleep(1)
if self.__stop.is_set():
break

self.__server.stop()
await self.__server.wait_till_stops()
self.__value.value = self.__server.test_messages_count

def stop(self):
self.__stop.set()

def join(self):
self.__proc.join()

def get_message_count(self) -> int:
return self.__value.value


class TestBenchmarkSendReceive(IsolatedAsyncioTestCase):
async def test1(self):
@skip("no reason to benchmark on slow machines")
async def test_threaded(self):
await self.helper_test(ThreadedFoo)

@skip("no reason to benchmark on slow machines")
async def test_proc(self):
await self.helper_test(ProcessedFoo)

async def helper_test(self, foo_factory: Callable[[NoopMessageServer], FooRunner]):
"""
runs 2 servers
starts X clients asyncio coroutines on one of the servers, each sends Y messages.
average per-message time is then calculated
"""
data = ''.join(random.choice(string.ascii_letters) for _ in range(16000)).encode('latin1')
server1 = NoopMessageServer((get_localhost(), 28385))
server2 = NoopMessageServer((get_localhost(), 28386))
server1_runner = ThreadedFoo(server1)
server1_runner = foo_factory(server1)
server1_runner.start()
await server2.start()
pure_send_time = 0.0
Expand Down Expand Up @@ -90,5 +169,6 @@ async def test_foo():
server1_runner.stop()
await server2.wait_till_stops()
server1_runner.join()
print(f'total go {server1.test_messages_count} in {total_time}s (pure send: {pure_send_time}s, avg {server1.test_messages_count/total_time} (pure: {server1.test_messages_count/pure_send_time}) msg/s')
self.assertEqual(total_clients * messages_per_client, server1.test_messages_count)
s1_message_count = server1_runner.get_message_count()
print(f'threaded total go {s1_message_count} in {total_time}s (pure send: {pure_send_time}s, avg {s1_message_count/total_time} (pure: {s1_message_count/pure_send_time}) msg/s')
self.assertEqual(total_clients * messages_per_client, s1_message_count)

0 comments on commit edf8676

Please sign in to comment.