From ead3c244ae3ed284a3c04478381cf02ab1ebaac6 Mon Sep 17 00:00:00 2001
From: Gal Topper
Date: Thu, 26 Dec 2024 16:04:15 +0800
Subject: [PATCH] Encode Kafka partition key

---
 integration/test_kafka_integration.py | 14 +++++++++-----
 storey/targets.py                     |  8 +++++---
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/integration/test_kafka_integration.py b/integration/test_kafka_integration.py
index a21aa080..7b114f0b 100644
--- a/integration/test_kafka_integration.py
+++ b/integration/test_kafka_integration.py
@@ -92,13 +92,13 @@ def test_kafka_target(kafka_topic_setup_teardown):
         assert record.value.decode("UTF-8") == json.dumps(event.body, default=str)
 
 
-async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown):
+async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key):
     kafka_consumer = kafka_topic_setup_teardown
 
     controller = build_flow(
         [
             AsyncEmitSource(),
-            KafkaTarget(kafka_brokers, topic, sharding_func=lambda _: 0, full_event=True),
+            KafkaTarget(kafka_brokers, topic, sharding_func=lambda _: partition_key, full_event=True),
         ]
     ).run()
     events = []
@@ -115,7 +115,10 @@ async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardo
         record = next(kafka_consumer)
         if event.key is None:
             if event.key is None:
-                assert record.key is None
+                if isinstance(partition_key, int):
+                    assert record.key is None
+                else:
+                    assert record.key.decode("UTF-8") == partition_key
             else:
                 assert record.key.decode("UTF-8") == event.key
         readback_records.append(json.loads(record.value.decode("UTF-8")))
@@ -143,5 +146,6 @@ async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardo
     not kafka_brokers,
     reason="KAFKA_BROKERS must be defined to run kafka tests",
 )
-def test_async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown):
-    asyncio.run(async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown))
+@pytest.mark.parametrize("partition_key", [0, "some_string"])
+def test_async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key):
+    asyncio.run(async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key))
diff --git a/storey/targets.py b/storey/targets.py
index 75b3c367..4bca9fcb 100644
--- a/storey/targets.py
+++ b/storey/targets.py
@@ -1345,9 +1345,7 @@ async def _do(self, event):
             self._producer.close()
             return await self._do_downstream(_termination_obj)
         else:
-            key = None
-            if event.key is not None:
-                key = stringify_key(event.key).encode("UTF-8")
+            key = event.key
             record = self._event_to_writer_entry(event)
             if self._full_event:
                 record = wrap_event_for_serialization(event, record)
@@ -1359,6 +1357,10 @@ async def _do(self, event):
                 partition = sharding_func_result
             else:
                 key = sharding_func_result
+
+            if key is not None:
+                key = stringify_key(key).encode("UTF-8")
+
             future = self._producer.send(self._topic, record, key, partition=partition)
             # Prevent garbage collection of event until persisted to kafka
             future.add_callback(lambda x: event)