diff --git a/exasol_advanced_analytics_framework/udf_communication/messages.py b/exasol_advanced_analytics_framework/udf_communication/messages.py index 6285faf0..f83593cf 100644 --- a/exasol_advanced_analytics_framework/udf_communication/messages.py +++ b/exasol_advanced_analytics_framework/udf_communication/messages.py @@ -37,6 +37,11 @@ class Ping(BaseModel, frozen=True): class Stop(BaseModel, frozen=True): message_type: Literal["Stop"] = "Stop" +class PrepareToStop(BaseModel, frozen=True): + message_type: Literal["PrepareToStop"] = "PrepareToStop" + +class IsReadyToStop(BaseModel, frozen=True): + message_type: Literal["IsReadyToStop"] = "IsReadyToStop" class Payload(BaseModel, frozen=True): message_type: Literal["Payload"] = "Payload" @@ -69,6 +74,8 @@ class Message(BaseModel, frozen=True): AcknowledgeRegisterPeer, RegisterPeerComplete, Stop, + PrepareToStop, + IsReadyToStop, Payload, MyConnectionInfo, PeerIsReadyToReceive, diff --git a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/acknowledge_register_peer_sender.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/acknowledge_register_peer_sender.py index ca069fc2..d0a79db6 100644 --- a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/acknowledge_register_peer_sender.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/acknowledge_register_peer_sender.py @@ -56,6 +56,11 @@ def _should_we_send(self): result = is_time and not self._finished return result + def is_ready_to_stop(self): + result = (self._finished and self._needs_to_send_for_peer) or not self._needs_to_send_for_peer + self._logger.debug("is_ready_to_stop", finished=self._finished, is_ready_to_stop=result) + return result + class AcknowledgeRegisterPeerSenderFactory(): def create(self, diff --git a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_listener_interface.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_listener_interface.py index bc2e73dd..ac05a1ed 100644 --- a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_listener_interface.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_listener_interface.py @@ -7,6 +7,7 @@ from exasol_advanced_analytics_framework.udf_communication.connection_info import ConnectionInfo from exasol_advanced_analytics_framework.udf_communication.ip_address import IPAddress +from exasol_advanced_analytics_framework.udf_communication.messages import Message, IsReadyToStop, Stop, PrepareToStop from exasol_advanced_analytics_framework.udf_communication.peer import Peer from exasol_advanced_analytics_framework.udf_communication.peer_communicator.background_listener_thread import \ BackgroundListenerThread @@ -25,6 +26,7 @@ class BackgroundListenerInterface: def __init__(self, name: str, + number_of_peers: int, socket_factory: SocketFactory, listen_ip: IPAddress, group_identifier: str, @@ -42,8 +44,10 @@ def __init__(self, out_control_socket_address = self._create_out_control_socket(socket_factory) in_control_socket_address = self._create_in_control_socket(socket_factory) self._my_connection_info: Optional[ConnectionInfo] = None + self._is_ready_to_stop = False self._background_listener_run = BackgroundListenerThread( name=self._name, + number_of_peers=number_of_peers, socket_factory=socket_factory, listen_ip=listen_ip, group_identifier=group_identifier, @@ -95,15 +99,24 @@ def receive_messages(self, 
timeout_in_milliseconds: Optional[int] = 0) -> Iterator[messages.Message]: while self._out_control_socket.poll(flags=PollerFlag.POLLIN, timeout_in_ms=timeout_in_milliseconds): message = None try: - message = self._out_control_socket.receive() - message_obj: messages.Message = deserialize_message(message, messages.Message) - specific_message_obj = message_obj.__root__ timeout_in_milliseconds = 0 - yield specific_message_obj + message = self._out_control_socket.receive() + message_obj: Message = deserialize_message(message, Message) + yield from self._handle_message(message_obj) except Exception as e: self._logger.exception("Exception", raw_message=message) - def close(self): + def _handle_message(self, message_obj: Message) -> Iterator[Message]: + specific_message_obj = message_obj.__root__ + if isinstance(specific_message_obj, IsReadyToStop): + self._is_ready_to_stop = True + else: + yield message_obj + + def is_ready_to_stop(self): + return self._is_ready_to_stop + + def stop(self): self._logger.info("start") self._send_stop() self._thread.join() @@ -112,6 +125,12 @@ def close(self): self._logger.info("end") def _send_stop(self): - self._logger.info("_send_stop") - stop_message = messages.Stop() - self._in_control_socket.send(serialize_message(stop_message)) + self._in_control_socket.send(serialize_message(Stop())) + + def prepare_to_stop(self): + self._logger.info("start") + self._send_prepare_to_stop() + self._logger.info("end") + + def _send_prepare_to_stop(self): + self._in_control_socket.send(serialize_message(PrepareToStop())) diff --git a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_listener_thread.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_listener_thread.py index a8448aab..e51f8c60 100644 --- a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_listener_thread.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_listener_thread.py @@ -8,6 +8,7 @@ from exasol_advanced_analytics_framework.udf_communication import messages from exasol_advanced_analytics_framework.udf_communication.connection_info import ConnectionInfo from exasol_advanced_analytics_framework.udf_communication.ip_address import IPAddress, Port +from exasol_advanced_analytics_framework.udf_communication.messages import IsReadyToStop, PrepareToStop from exasol_advanced_analytics_framework.udf_communication.peer import Peer from exasol_advanced_analytics_framework.udf_communication.peer_communicator.abort_timeout_sender import \ AbortTimeoutSenderFactory @@ -71,10 +72,12 @@ def create_background_peer_state_builder() -> BackgroundPeerStateBuilder: class BackgroundListenerThread: class Status(enum.Enum): RUNNING = enum.auto() + PREPARE_TO_STOP = enum.auto() STOPPED = enum.auto() def __init__(self, name: str, + number_of_peers: int, socket_factory: SocketFactory, listen_ip: IPAddress, group_identifier: str, @@ -84,6 +87,7 @@ def __init__(self, config: PeerCommunicatorConfig, trace_logging: bool, background_peer_state_factory: BackgroundPeerStateBuilder = create_background_peer_state_builder()): + self._number_of_peers = number_of_peers self._config = config self._background_peer_state_factory = background_peer_state_factory self._register_peer_connection: Optional[RegisterPeerConnection] = None @@ -112,16 +116,16 @@ def run(self): try: self._run_message_loop() finally: - self._close() + self._stop() - def _close(self): + def _stop(self): self._logger.info("start") if self._register_peer_connection is not None: self._register_peer_connection.close()
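# --- Editor's sketch (not part of the diff): the two-phase shutdown that the
# --- PrepareToStop/IsReadyToStop messages above introduce, reduced to a toy
# --- state machine. The Status names mirror the diff; everything else is a
# --- simplified assumption.
import enum

class Status(enum.Enum):
    RUNNING = enum.auto()
    PREPARE_TO_STOP = enum.auto()
    STOPPED = enum.auto()

def handle_control_message(message_type: str, status: Status) -> Status:
    # Mirrors BackgroundListenerThread._handle_control_message below: Stop
    # ends the loop immediately, while PrepareToStop only switches to a
    # draining state in which the thread keeps polling until all peers report
    # readiness and it can answer the frontend with IsReadyToStop.
    if message_type == "Stop":
        return Status.STOPPED
    if message_type == "PrepareToStop":
        return Status.PREPARE_TO_STOP
    return status

# PeerCommunicator.stop() drives exactly this sequence:
status = handle_control_message("PrepareToStop", Status.RUNNING)
assert status is Status.PREPARE_TO_STOP   # drain until IsReadyToStop is sent
status = handle_control_message("Stop", status)
assert status is Status.STOPPED           # then the message loop exits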
self._out_control_socket.close(linger=0) self._in_control_socket.close(linger=0) for peer_state in self._peer_state.values(): - peer_state.close() + peer_state.stop() self._listener_socket.close(linger=0) self._logger.info("end") @@ -146,31 +150,51 @@ def _create_poller(self): def _run_message_loop(self): try: - while self._status == BackgroundListenerThread.Status.RUNNING: - poll = self.poller.poll(timeout_in_ms=self._config.poll_timeout_in_ms) - if self._in_control_socket in poll and PollerFlag.POLLIN in poll[self._in_control_socket]: - message = self._in_control_socket.receive() - self._status = self._handle_control_message(message) - if self._listener_socket in poll and PollerFlag.POLLIN in poll[self._listener_socket]: - message = self._listener_socket.receive_multipart() - self._handle_listener_message(message) - if self._status == BackgroundListenerThread.Status.RUNNING: - for peer_state in self._peer_state.values(): - peer_state.resend_if_necessary() + while self._status != BackgroundListenerThread.Status.STOPPED: + self._handle_message() + self._try_send() + self._check_is_ready_to_stop() except Exception as e: self._logger.exception("Exception in message loop") + def _check_is_ready_to_stop(self): + if self._status == BackgroundListenerThread.Status.PREPARE_TO_STOP: + if self._is_ready_to_stop(): + self._out_control_socket.send(serialize_message(IsReadyToStop())) + + def _is_ready_to_stop(self): + peers_status = [peer_state.is_ready_to_stop() + for peer_state in self._peer_state.values()] + is_ready_to_stop = all(peers_status) and len(peers_status) == self._number_of_peers - 1 + return is_ready_to_stop + + def _try_send(self): + if self._status != BackgroundListenerThread.Status.STOPPED: + for peer_state in self._peer_state.values(): + peer_state.try_send() + + def _handle_message(self): + poll = self.poller.poll(timeout_in_ms=self._config.poll_timeout_in_ms) + if self._in_control_socket in poll and PollerFlag.POLLIN in poll[self._in_control_socket]: + message = self._in_control_socket.receive() + self._status = self._handle_control_message(message) + if self._listener_socket in poll and PollerFlag.POLLIN in poll[self._listener_socket]: + message = self._listener_socket.receive_multipart() + self._handle_listener_message(message) + def _handle_control_message(self, message: bytes) -> Status: try: message_obj: messages.Message = deserialize_message(message, messages.Message) specific_message_obj = message_obj.__root__ if isinstance(specific_message_obj, messages.Stop): return BackgroundListenerThread.Status.STOPPED + elif isinstance(specific_message_obj, PrepareToStop): + return BackgroundListenerThread.Status.PREPARE_TO_STOP elif isinstance(specific_message_obj, messages.RegisterPeer): if self._is_register_peer_message_allowed_as_control_message(): self._handle_register_peer_message(specific_message_obj) else: - self._logger.error("RegisterPeerMessage message not allowed", + self._logger.error("RegisterPeer message not allowed", message_obj=specific_message_obj.dict()) else: self._logger.error("Unknown message type", message_obj=specific_message_obj.dict()) @@ -180,11 +204,11 @@ def _handle_control_message(self, message: bytes) -> Status: def _is_register_peer_message_allowed_as_control_message(self): return ( - ( - self._config.forward_register_peer_config.is_enabled - and self._config.forward_register_peer_config.is_leader - ) - or not self._config.forward_register_peer_config.is_enabled + ( + self._config.forward_register_peer_config.is_enabled + and 
self._config.forward_register_peer_config.is_leader + ) + or not self._config.forward_register_peer_config.is_enabled ) def _add_peer(self, @@ -227,7 +251,7 @@ def _handle_listener_message(self, message: List[Frame]): if self.is_register_peer_message_allowed_as_listener_message(): self._handle_register_peer_message(specific_message_obj) else: - logger.error("RegisterPeerMessage message not allowed", message_obj=specific_message_obj.dict()) + logger.error("RegisterPeer message not allowed", message_obj=specific_message_obj.dict()) elif isinstance(specific_message_obj, messages.AcknowledgeRegisterPeer): self._handle_acknowledge_register_peer_message(specific_message_obj) elif isinstance(specific_message_obj, messages.RegisterPeerComplete): self._handle_register_peer_complete_message(specific_message_obj) @@ -276,7 +300,7 @@ def _handle_register_peer_message(self, message: messages.RegisterPeer): self._add_peer( message.peer, connection_establisher_behavior_config=ConnectionEstablisherBehaviorConfig( - acknowledge_register_peer=True, + acknowledge_register_peer=not self._config.forward_register_peer_config.is_leader, needs_register_peer_complete=True) ) return @@ -285,7 +309,7 @@ def _handle_register_peer_message(self, message: messages.RegisterPeer): message.peer, connection_establisher_behavior_config=ConnectionEstablisherBehaviorConfig( forward_register_peer=True, - acknowledge_register_peer=True, + acknowledge_register_peer=not self._config.forward_register_peer_config.is_leader, needs_register_peer_complete=True) ) @@ -313,12 +337,12 @@ def _create_register_peer_connection(self, message: messages.RegisterPeer): def _handle_acknowledge_register_peer_message(self, message: messages.AcknowledgeRegisterPeer): if self._register_peer_connection.successor != message.source: - self._logger.error("AcknowledgeRegisterPeerMessage message not from successor", message_obj=message.dict()) + self._logger.error("AcknowledgeRegisterPeer message not from successor", message_obj=message.dict()) peer = message.peer self._peer_state[peer].received_acknowledge_register_peer() def _handle_register_peer_complete_message(self, message: messages.RegisterPeerComplete): if self._register_peer_connection.predecessor != message.source: - self._logger.error("RegisterPeerCompleteMessage message not from predecssor", message_obj=message.dict()) + self._logger.error("RegisterPeerComplete message not from predecessor", message_obj=message.dict()) peer = message.peer self._peer_state[peer].received_register_peer_complete() diff --git a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_peer_state.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_peer_state.py index 7eb093ce..2746d910 100644 --- a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_peer_state.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/background_peer_state.py @@ -42,7 +42,7 @@ def _create_receive_socket(self): receive_socket_address = get_peer_receive_socket_name(self._peer) self._receive_socket.bind(receive_socket_address) - def resend_if_necessary(self): - self._logger.debug("resend_if_necessary") + def try_send(self): + self._logger.debug("try_send") self._connection_establisher.try_send() @@ -61,5 +61,8 @@ def received_register_peer_complete(self): def forward_payload(self, frames: List[Frame]): self._receive_socket.send_multipart(frames) - def close(self): + def stop(self): self._receive_socket.close(linger=0) + + def is_ready_to_stop(self): + return self._connection_establisher.is_ready_to_stop() diff --git
a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/connection_establisher.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/connection_establisher.py index de794869..9290b053 100644 --- a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/connection_establisher.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/connection_establisher.py @@ -33,8 +33,8 @@ def __init__(self, register_peer_sender: RegisterPeerSender, synchronize_connection_sender: SynchronizeConnectionSender): self._synchronize_connection_sender = synchronize_connection_sender - self._register_peer_sender = register_peer_sender self._peer_is_ready_sender = peer_is_ready_sender + self._register_peer_sender = register_peer_sender self._acknowledge_register_peer_sender = acknowledge_register_peer_sender self._abort_timeout_sender = abort_timeout_sender self._register_peer_connection = register_peer_connection @@ -84,3 +84,15 @@ def try_send(self): self._abort_timeout_sender.try_send() self._peer_is_ready_sender.try_send() self._acknowledge_register_peer_sender.try_send() + + def is_ready_to_stop(self): + peer_is_ready_sender = self._peer_is_ready_sender.is_ready_to_stop() + register_peer_sender = self._register_peer_sender.is_ready_to_stop() + self._logger.debug("is_ready_to_stop", + peer_is_ready_sender=peer_is_ready_sender, + register_peer_sender=register_peer_sender, + ) + return ( + peer_is_ready_sender + and register_peer_sender + ) diff --git a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/frontend_peer_state.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/frontend_peer_state.py index 2ef43359..d490c18d 100644 --- a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/frontend_peer_state.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/frontend_peer_state.py @@ -2,6 +2,9 @@ from typing import Optional, Generator, List from exasol_advanced_analytics_framework.udf_communication import messages +import structlog +from structlog.typing import FilteringBoundLogger + from exasol_advanced_analytics_framework.udf_communication.connection_info import ConnectionInfo from exasol_advanced_analytics_framework.udf_communication.peer import Peer from exasol_advanced_analytics_framework.udf_communication.peer_communicator.get_peer_receive_socket_name import \ @@ -10,6 +13,8 @@ from exasol_advanced_analytics_framework.udf_communication.socket_factory.abstract import SocketFactory, \ SocketType, Socket, Frame, PollerFlag +LOGGER: FilteringBoundLogger = structlog.getLogger() + class FrontendPeerState: @@ -54,10 +59,11 @@ def send(self, payload: List[Frame]): serialized_message = serialize_message(message) frame = self._socket_factory.create_frame(serialized_message) send_socket.send_multipart([frame] + payload) + send_socket.close(linger=100) def recv(self, timeout_in_milliseconds: Optional[int] = None) -> List[Frame]: if self._receive_socket.poll(flags=PollerFlag.POLLIN, timeout_in_ms=timeout_in_milliseconds) != 0: return self._receive_socket.receive_multipart() - def close(self): + def stop(self): self._receive_socket.close(linger=0) diff --git a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_communicator.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_communicator.py index f91ce8ac..5f266fef 100644 --- 
a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_communicator.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_communicator.py @@ -1,6 +1,6 @@ import time from dataclasses import asdict -from typing import Optional, Dict, List +from typing import Optional, Dict, List, Callable import structlog from structlog.types import FilteringBoundLogger @@ -11,9 +11,9 @@ from exasol_advanced_analytics_framework.udf_communication.peer import Peer from exasol_advanced_analytics_framework.udf_communication.peer_communicator.background_listener_interface import \ BackgroundListenerInterface +from exasol_advanced_analytics_framework.udf_communication.peer_communicator.clock import Clock from exasol_advanced_analytics_framework.udf_communication.peer_communicator.forward_register_peer_config import \ ForwardRegisterPeerConfig -from exasol_advanced_analytics_framework.udf_communication.peer_communicator.clock import Clock from exasol_advanced_analytics_framework.udf_communication.peer_communicator.frontend_peer_state import \ FrontendPeerState from exasol_advanced_analytics_framework.udf_communication.peer_communicator.peer_communicator_config import \ @@ -60,6 +60,7 @@ def __init__(self, self._logger.info("init") self._background_listener = BackgroundListenerInterface( name=self._name, + number_of_peers=number_of_peers, socket_factory=self._socket_factory, listen_ip=listen_ip, group_identifier=self._group_identifier, @@ -73,18 +74,18 @@ def __init__(self, self._peer_states: Dict[Peer, FrontendPeerState] = {} def _handle_messages(self, timeout_in_milliseconds: Optional[int] = 0): - if self._are_all_peers_connected(): - return - - for message in self._background_listener.receive_messages(timeout_in_milliseconds): - if isinstance(message, messages.PeerIsReadyToReceive): - peer = message.peer + for message_obj in self._background_listener.receive_messages(timeout_in_milliseconds): + specific_message_obj = message_obj.__root__ + if isinstance(specific_message_obj, messages.PeerIsReadyToReceive): + peer = specific_message_obj.peer self._add_peer_state(peer) self._peer_states[peer].received_peer_is_ready_to_receive() - elif isinstance(message, messages.Timeout): - raise TimeoutError(message.reason) + elif isinstance(specific_message_obj, messages.Timeout): + raise TimeoutError(specific_message_obj.reason) else: - self._logger.error("Unknown message", message=message.dict()) + self._logger.error( + "Unknown message", + message_obj=specific_message_obj.dict()) def _add_peer_state(self, peer: Peer): if peer not in self._peer_states: @@ -94,9 +95,11 @@ def _add_peer_state(self, peer: Peer): peer=peer ) - def wait_for_peers(self, timeout_in_milliseconds: Optional[int] = None) -> bool: + def _wait_for_condition(self, condition: Callable[[], bool], + timeout_in_milliseconds: Optional[int] = None) -> bool: start_time_ns = time.monotonic_ns() - while True: + self._handle_messages(timeout_in_milliseconds=0) + while not condition(): if timeout_in_milliseconds is not None: handle_message_timeout_ms = _compute_handle_message_timeout(start_time_ns, timeout_in_milliseconds) if handle_message_timeout_ms < 0: @@ -104,10 +107,10 @@ def wait_for_peers(self, timeout_in_milliseconds: Optional[int] = None) -> bool: else: handle_message_timeout_ms = None self._handle_messages(timeout_in_milliseconds=handle_message_timeout_ms) - if self._are_all_peers_connected(): - break - connected = self._are_all_peers_connected() - return connected + return condition() + + def 
wait_for_peers(self, timeout_in_milliseconds: Optional[int] = None) -> bool: + return self._wait_for_condition(self._are_all_peers_connected, timeout_in_milliseconds) def peers(self, timeout_in_milliseconds: Optional[int] = None) -> Optional[List[Peer]]: self.wait_for_peers(timeout_in_milliseconds) @@ -163,13 +166,32 @@ def recv(self, peer: Peer, timeout_in_milliseconds: Optional[int] = None) -> Lis assert self.are_all_peers_connected() return self._peer_states[peer].recv(timeout_in_milliseconds) - def close(self): - self._logger.info("close") + def stop(self): + self._logger.info("stop") if self._background_listener is not None: - self._background_listener.close() + try: + self._stop_background_listener() + finally: + self._stop_peer_states() + + def _stop_background_listener(self): + self._logger.info("stop background_listener") + self._background_listener.prepare_to_stop() + try: + is_ready_to_stop = \ + self._wait_for_condition(self._background_listener.is_ready_to_stop, + timeout_in_milliseconds=self._config.close_timeout_in_ms) + if not is_ready_to_stop: + raise TimeoutError("Timeout expired, could not gracefully stop PeerCommunicator.") + finally: + self._background_listener.stop() self._background_listener = None - for peer_state in self._peer_states.values(): - peer_state.close() + + def _stop_peer_states(self): + self._logger.info("stop peer_states") + for peer_state_key in list(self._peer_states.keys()): + self._peer_states[peer_state_key].stop() + del self._peer_states[peer_state_key] def __del__(self): - self.close() + self.stop() diff --git a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_communicator_config.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_communicator_config.py index fd24ddfb..9d40cf0c 100644 --- a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_communicator_config.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_communicator_config.py @@ -12,3 +12,4 @@ class PeerCommunicatorConfig: forward_register_peer_config: ForwardRegisterPeerConfig = ForwardRegisterPeerConfig() poll_timeout_in_ms: int = 200 send_socket_linger_time_in_ms: int = 100 + close_timeout_in_ms: int = 100000 diff --git a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_is_ready_sender.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_is_ready_sender.py index bc0382a4..bc167eab 100644 --- a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_is_ready_sender.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/peer_is_ready_sender.py @@ -116,6 +116,9 @@ def _send_peer_is_ready_to_frontend(self): serialized_message = serialize_message(message) self._out_control_socket.send(serialized_message) + def is_ready_to_stop(self): + return _States.FINISHED in self._states + class PeerIsReadySenderFactory: def create(self, diff --git a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/register_peer_sender.py b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/register_peer_sender.py index 6a5841f8..bca61e1e 100644 --- a/exasol_advanced_analytics_framework/udf_communication/peer_communicator/register_peer_sender.py +++ b/exasol_advanced_analytics_framework/udf_communication/peer_communicator/register_peer_sender.py @@ -21,8 +21,8 @@ def __init__(self, timer: Timer): self._needs_to_send_for_peer = needs_to_send_for_peer
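# --- Editor's sketch (not part of the diff): the readiness predicate shared by
# --- RegisterPeerSender.is_ready_to_stop() below and the
# --- AcknowledgeRegisterPeerSender version above, written as a free function
# --- with assumed parameter names. It shows that
# ---     (finished and needs_to_send) or not needs_to_send
# --- reduces to: finished or not needs_to_send — a sender that never had to
# --- send for this peer is trivially ready to stop.
def is_ready_to_stop(finished: bool, needs_to_send_for_peer: bool) -> bool:
    return finished or not needs_to_send_for_peer

# Exhaustive check of the equivalence over all four input combinations:
for finished in (False, True):
    for needs in (False, True):
        assert is_ready_to_stop(finished, needs) == \
               ((finished and needs) or not needs)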
self._register_peer_connection = register_peer_connection - if self._needs_to_send_for_peer and self._register_peer_connection is None: - raise ValueError("_register_peer_connection is None while _needs_to_send_for_peer is true") + if needs_to_send_for_peer and self._register_peer_connection is None: + raise ValueError("_register_peer_connection is None while needs_to_send_for_peer is true") self._my_connection_info = my_connection_info self._timer = timer self._finished = False @@ -56,6 +56,9 @@ def _should_we_send(self): result = is_time and not self._finished return result + def is_ready_to_stop(self): + return (self._finished and self._needs_to_send_for_peer) or not self._needs_to_send_for_peer + class RegisterPeerSenderFactory(): def create(self, diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_close.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_close.py new file mode 100644 index 00000000..817f2be5 --- /dev/null +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_close.py @@ -0,0 +1,157 @@ +import sys +import time +import traceback +from pathlib import Path +from typing import Dict, List + +import pytest +import structlog +import zmq +from numpy.random import RandomState +from structlog import WriteLoggerFactory +from structlog.tracebacks import ExceptionDictTransformer +from structlog.types import FilteringBoundLogger + +from exasol_advanced_analytics_framework.udf_communication.connection_info import ConnectionInfo +from exasol_advanced_analytics_framework.udf_communication.ip_address import IPAddress +from exasol_advanced_analytics_framework.udf_communication.peer_communicator import PeerCommunicator +from exasol_advanced_analytics_framework.udf_communication.peer_communicator.forward_register_peer_config import \ + ForwardRegisterPeerConfig +from exasol_advanced_analytics_framework.udf_communication.peer_communicator.peer_communicator_config import \ + PeerCommunicatorConfig +from exasol_advanced_analytics_framework.udf_communication.socket_factory.fault_injection import \ + FaultInjectionSocketFactory +from exasol_advanced_analytics_framework.udf_communication.socket_factory.zmq_wrapper import ZMQSocketFactory +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ + ConditionalMethodDropper +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \ + PeerCommunicatorTestProcessParameter, BidirectionalQueue, TestProcess, assert_processes_finish + +structlog.configure( + context_class=dict, + logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")), + processors=[ + structlog.contextvars.merge_contextvars, + ConditionalMethodDropper(method_name="debug"), + structlog.processors.add_log_level, + structlog.processors.TimeStamper(), + structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)), + structlog.processors.CallsiteParameterAdder(), + structlog.processors.JSONRenderer() + ] +) + +LOGGER: FilteringBoundLogger = structlog.get_logger() + + +def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQueue): + logger = LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name) + try: + listen_ip = IPAddress(ip_address=f"127.1.0.1") + context = zmq.Context() + socket_factory = ZMQSocketFactory(context) + socket_factory = 
FaultInjectionSocketFactory(socket_factory, 0.01, RandomState(parameter.seed)) + com = PeerCommunicator( + name=parameter.instance_name, + number_of_peers=parameter.number_of_instances, + listen_ip=listen_ip, + group_identifier=parameter.group_identifier, + socket_factory=socket_factory, + config=PeerCommunicatorConfig( + forward_register_peer_config=ForwardRegisterPeerConfig( + is_leader=False, + is_enabled=False + ) + ), + ) + try: + queue.put(com.my_connection_info) + peer_connection_infos = queue.get() + for index, connection_info in peer_connection_infos.items(): + com.register_peer(connection_info) + finally: + try: + com.stop() + queue.put("Success") + except: + logger.exception("Exception during stop") + queue.put("Failed") + context.destroy(linger=0) + for frame in sys._current_frames().values(): + stacktrace = traceback.format_stack(frame) + logger.info("Frame", stacktrace=stacktrace) + except Exception as e: + queue.put("Failed") + logger.exception("Exception during test") + + +@pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100), (50, 10)]) +def test_reliability(number_of_instances: int, repetitions: int): + run_test_with_repetitions(number_of_instances, repetitions) + + +REPETITIONS_FOR_FUNCTIONALITY = 1 + + +def test_functionality_2(): + run_test_with_repetitions(2, REPETITIONS_FOR_FUNCTIONALITY) + + +def test_functionality_3(): + run_test_with_repetitions(3, REPETITIONS_FOR_FUNCTIONALITY) + + +def test_functionality_10(): + run_test_with_repetitions(10, REPETITIONS_FOR_FUNCTIONALITY) + + +def test_functionality_25(): + run_test_with_repetitions(25, REPETITIONS_FOR_FUNCTIONALITY) + + +def test_functionality_50(): + run_test_with_repetitions(50, REPETITIONS_FOR_FUNCTIONALITY) + + +def run_test_with_repetitions(number_of_instances: int, repetitions: int): + for i in range(repetitions): + LOGGER.info(f"Start iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances) + start_time = time.monotonic() + group = f"{time.monotonic_ns()}" + expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, seed=i) + assert expected_peers_of_threads == peers_of_threads + end_time = time.monotonic() + LOGGER.info(f"Finish iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances, + duration=end_time - start_time) + + +def run_test(group: str, number_of_instances: int, seed: int): + connection_infos: Dict[int, ConnectionInfo] = {} + parameters = [ + PeerCommunicatorTestProcessParameter( + instance_name=f"i{i}", group_identifier=group, + number_of_instances=number_of_instances, + seed=seed + i) + for i in range(number_of_instances)] + processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \ + [TestProcess(parameter, run=run) for parameter in parameters] + for i in range(number_of_instances): + processes[i].start() + connection_infos[i] = processes[i].get() + for i in range(number_of_instances): + t = processes[i].put(connection_infos) + assert_processes_finish(processes, timeout_in_seconds=180) + result_of_threads: Dict[int, List[ConnectionInfo]] = {} + for i in range(number_of_instances): + result_of_threads[i] = processes[i].get() + expected_results_of_threads = { + i: "Success" + for i in range(number_of_instances) + } + return expected_results_of_threads, result_of_threads diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_close.py 
b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_close.py new file mode 100644 index 00000000..cfa95413 --- /dev/null +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_close.py @@ -0,0 +1,132 @@ +import sys +import time +import traceback +from pathlib import Path +from typing import Dict, List + +import structlog +import zmq +from structlog import WriteLoggerFactory +from structlog.tracebacks import ExceptionDictTransformer +from structlog.types import FilteringBoundLogger + +from exasol_advanced_analytics_framework.udf_communication.connection_info import ConnectionInfo +from exasol_advanced_analytics_framework.udf_communication.ip_address import IPAddress +from exasol_advanced_analytics_framework.udf_communication.peer_communicator import PeerCommunicator +from exasol_advanced_analytics_framework.udf_communication.peer_communicator.forward_register_peer_config import \ + ForwardRegisterPeerConfig +from exasol_advanced_analytics_framework.udf_communication.peer_communicator.peer_communicator_config import \ + PeerCommunicatorConfig +from exasol_advanced_analytics_framework.udf_communication.socket_factory.zmq_wrapper import ZMQSocketFactory +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ + ConditionalMethodDropper +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \ + PeerCommunicatorTestProcessParameter, TestProcess, assert_processes_finish, BidirectionalQueue + +structlog.configure( + context_class=dict, + logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")), + processors=[ + structlog.contextvars.merge_contextvars, + ConditionalMethodDropper(method_name="debug"), + structlog.processors.add_log_level, + structlog.processors.TimeStamper(), + structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)), + structlog.processors.CallsiteParameterAdder(), + structlog.processors.JSONRenderer() + ] +) + +LOGGER: FilteringBoundLogger = structlog.get_logger() + + +def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQueue): + logger = LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name) + try: + listen_ip = IPAddress(ip_address=f"127.1.0.1") + context = zmq.Context() + socket_factory = ZMQSocketFactory(context) + com = PeerCommunicator( + name=parameter.instance_name, + number_of_peers=parameter.number_of_instances, + listen_ip=listen_ip, + group_identifier=parameter.group_identifier, + socket_factory=socket_factory, + config=PeerCommunicatorConfig( + forward_register_peer_config=ForwardRegisterPeerConfig( + is_leader=False, + is_enabled=False + ) + ), + ) + try: + queue.put(com.my_connection_info) + peer_connection_infos = queue.get() + for index, connection_info in peer_connection_infos.items(): + com.register_peer(connection_info) + time.sleep(150) + finally: + try: + com.stop() + queue.put("Success") + except Exception as e: + logger.exception("Exception during test") + queue.put("Failed") + context.destroy(linger=0) + for frame in sys._current_frames().values(): + stacktrace = traceback.format_stack(frame) + logger.info("Frame", stacktrace=stacktrace) + except Exception as e: + logger.exception("Exception during test") + queue.put("Failed") + + +REPETITIONS_FOR_FUNCTIONALITY = 1 + + +def test_functionality_2(): + run_test_with_repetitions(2, 
REPETITIONS_FOR_FUNCTIONALITY) + + +def run_test_with_repetitions(number_of_instances: int, repetitions: int): + for i in range(repetitions): + LOGGER.info(f"Start iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances) + start_time = time.monotonic() + group = f"{time.monotonic_ns()}" + expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, seed=i) + assert expected_peers_of_threads == peers_of_threads + end_time = time.monotonic() + LOGGER.info(f"Finish iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances, + duration=end_time - start_time) + + +def run_test(group: str, number_of_instances: int, seed: int): + connection_infos: Dict[int, ConnectionInfo] = {} + parameters = [ + PeerCommunicatorTestProcessParameter( + instance_name=f"i{i}", group_identifier=group, + number_of_instances=number_of_instances, + seed=seed + i) + for i in range(number_of_instances)] + processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \ + [TestProcess(parameter, run=run) for parameter in parameters] + for i in range(number_of_instances): + processes[i].start() + connection_infos[i] = processes[i].get() + for i in range(number_of_instances): + t = processes[i].put(connection_infos) + assert_processes_finish(processes, timeout_in_seconds=180) + result_of_threads: Dict[int, List[ConnectionInfo]] = {} + for i in range(number_of_instances): + result_of_threads[i] = processes[i].get() + expected_results_of_threads = { + i: "Success" + for i in range(number_of_instances) + } + return expected_results_of_threads, result_of_threads diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_for_peers.py similarity index 95% rename from tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer.py rename to tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_for_peers.py index 04c7eed4..fa9128ca 100644 --- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer.py +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_and_wait_for_peers.py @@ -19,14 +19,15 @@ from exasol_advanced_analytics_framework.udf_communication.peer_communicator.forward_register_peer_config import \ ForwardRegisterPeerConfig from exasol_advanced_analytics_framework.udf_communication.peer_communicator.peer_communicator import key_for_peer -from exasol_advanced_analytics_framework.udf_communication.socket_factory.fault_injection import \ - FaultInjectionSocketFactory from exasol_advanced_analytics_framework.udf_communication.peer_communicator.peer_communicator_config import \ PeerCommunicatorConfig +from exasol_advanced_analytics_framework.udf_communication.socket_factory.fault_injection import \ + FaultInjectionSocketFactory from exasol_advanced_analytics_framework.udf_communication.socket_factory.zmq_wrapper import ZMQSocketFactory -from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import ConditionalMethodDropper -from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, BidirectionalQueue, assert_processes_finish, \ - PeerCommunicatorTestProcessParameter +from 
tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ + ConditionalMethodDropper +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \ + PeerCommunicatorTestProcessParameter, BidirectionalQueue, TestProcess, assert_processes_finish structlog.configure( context_class=dict, @@ -75,7 +76,7 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue logger.info("peers", number_of_peers=len(peers)) queue.put(peers) finally: - com.close() + com.stop() logger.info("after close") context.destroy(linger=0) logger.info("after destroy") @@ -99,6 +100,10 @@ def test_functionality_2(): run_test_with_repetitions(2, REPETITIONS_FOR_FUNCTIONALITY) +def test_functionality_3(): + run_test_with_repetitions(3, REPETITIONS_FOR_FUNCTIONALITY) + + def test_functionality_10(): run_test_with_repetitions(10, REPETITIONS_FOR_FUNCTIONALITY) diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_close.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_close.py new file mode 100644 index 00000000..dcc02d94 --- /dev/null +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_close.py @@ -0,0 +1,157 @@ +import sys +import time +import traceback +from pathlib import Path +from typing import Dict, List + +import pytest +import structlog +import zmq +from numpy.random import RandomState +from structlog import WriteLoggerFactory +from structlog.tracebacks import ExceptionDictTransformer +from structlog.types import FilteringBoundLogger + +from exasol_advanced_analytics_framework.udf_communication.connection_info import ConnectionInfo +from exasol_advanced_analytics_framework.udf_communication.ip_address import IPAddress +from exasol_advanced_analytics_framework.udf_communication.peer_communicator import PeerCommunicator +from exasol_advanced_analytics_framework.udf_communication.peer_communicator.forward_register_peer_config import \ + ForwardRegisterPeerConfig +from exasol_advanced_analytics_framework.udf_communication.peer_communicator.peer_communicator_config import \ + PeerCommunicatorConfig +from exasol_advanced_analytics_framework.udf_communication.socket_factory.fault_injection import \ + FaultInjectionSocketFactory +from exasol_advanced_analytics_framework.udf_communication.socket_factory.zmq_wrapper import ZMQSocketFactory +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ + ConditionalMethodDropper +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \ + PeerCommunicatorTestProcessParameter, BidirectionalQueue, TestProcess, assert_processes_finish + +structlog.configure( + context_class=dict, + logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")), + processors=[ + structlog.contextvars.merge_contextvars, + ConditionalMethodDropper(method_name="debug"), + structlog.processors.add_log_level, + structlog.processors.TimeStamper(), + structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)), + structlog.processors.CallsiteParameterAdder(), + structlog.processors.JSONRenderer() + ] +) + +LOGGER: FilteringBoundLogger = structlog.get_logger() + + +def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQueue): + logger = 
LOGGER.bind(group_identifier=parameter.group_identifier, name=parameter.instance_name) + try: + listen_ip = IPAddress(ip_address=f"127.1.0.1") + context = zmq.Context() + socket_factory = ZMQSocketFactory(context) + socket_factory = FaultInjectionSocketFactory(socket_factory, 0.01, RandomState(parameter.seed)) + leader = False + leader_name = "i0" + if parameter.instance_name == leader_name: + leader = True + com = PeerCommunicator( + name=parameter.instance_name, + number_of_peers=parameter.number_of_instances, + listen_ip=listen_ip, + group_identifier=parameter.group_identifier, + config=PeerCommunicatorConfig( + forward_register_peer_config=ForwardRegisterPeerConfig( + is_leader=leader, + is_enabled=True + ), + ), + socket_factory=socket_factory + ) + try: + queue.put(com.my_connection_info) + peer_connection_infos = queue.get() + if parameter.instance_name == leader_name: + for index, connection_info in peer_connection_infos.items(): + com.register_peer(connection_info) + finally: + try: + com.stop() + queue.put("Success") + except: + queue.put("Failed") + context.destroy(linger=0) + for frame in sys._current_frames().values(): + stacktrace = traceback.format_stack(frame) + logger.info("Frame", stacktrace=stacktrace) + except Exception as e: + queue.put("Failed") + logger.exception("Exception during test") + + +@pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100), (50, 10)]) +def test_reliability(number_of_instances: int, repetitions: int): + run_test_with_repetitions(number_of_instances, repetitions) + + +REPETITIONS_FOR_FUNCTIONALITY = 3 + + +def test_functionality_2(): + run_test_with_repetitions(2, REPETITIONS_FOR_FUNCTIONALITY) + + +def test_functionality_3(): + run_test_with_repetitions(3, REPETITIONS_FOR_FUNCTIONALITY) + + +def test_functionality_10(): + run_test_with_repetitions(10, REPETITIONS_FOR_FUNCTIONALITY) + + +def test_functionality_25(): + run_test_with_repetitions(25, REPETITIONS_FOR_FUNCTIONALITY) + + +def run_test_with_repetitions(number_of_instances: int, repetitions: int): + for i in range(repetitions): + LOGGER.info(f"Start iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances) + start_time = time.monotonic() + group = f"{time.monotonic_ns()}" + expected_peers_of_threads, peers_of_threads = run_test(group, number_of_instances, seed=i) + assert expected_peers_of_threads == peers_of_threads + end_time = time.monotonic() + LOGGER.info(f"Finish iteration", + iteration=i + 1, + repetitions=repetitions, + number_of_instances=number_of_instances, + duration=end_time - start_time) + + +def run_test(group: str, number_of_instances: int, seed: int): + connection_infos: Dict[int, ConnectionInfo] = {} + parameters = [ + PeerCommunicatorTestProcessParameter( + instance_name=f"i{i}", group_identifier=group, + number_of_instances=number_of_instances, + seed=seed + i) + for i in range(number_of_instances)] + processes: List[TestProcess[PeerCommunicatorTestProcessParameter]] = \ + [TestProcess(parameter, run=run) for parameter in parameters] + for i in range(number_of_instances): + processes[i].start() + connection_infos[i] = processes[i].get() + for i in range(number_of_instances): + t = processes[i].put(connection_infos) + assert_processes_finish(processes, timeout_in_seconds=180) + result_of_threads: Dict[int, List[ConnectionInfo]] = {} + for i in range(number_of_instances): + result_of_threads[i] = processes[i].get() + expected_results_of_threads = { + i: "Success" + for i in 
range(number_of_instances) + } + return expected_results_of_threads, result_of_threads diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_wait_for_peers.py similarity index 93% rename from tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward.py rename to tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_wait_for_peers.py index 845d6199..3eeea31d 100644 --- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward.py +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_add_peer_forward_and_wait_for_peers.py @@ -23,17 +23,18 @@ from exasol_advanced_analytics_framework.udf_communication.peer_communicator.peer_communicator_config import \ PeerCommunicatorConfig from exasol_advanced_analytics_framework.udf_communication.socket_factory.zmq_wrapper import ZMQSocketFactory -from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import ConditionalMethodDropper -from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, BidirectionalQueue, assert_processes_finish, \ - PeerCommunicatorTestProcessParameter +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ + ConditionalMethodDropper +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import \ + PeerCommunicatorTestProcessParameter, BidirectionalQueue, TestProcess, assert_processes_finish structlog.configure( context_class=dict, logger_factory=WriteLoggerFactory(file=Path(__file__).with_suffix(".log").open("wt")), processors=[ structlog.contextvars.merge_contextvars, - ConditionalMethodDropper(method_name="debug"), - ConditionalMethodDropper(method_name="info"), + #ConditionalMethodDropper(method_name="debug"), + #ConditionalMethodDropper(method_name="info"), structlog.processors.add_log_level, structlog.processors.TimeStamper(fmt="ISO"), structlog.processors.ExceptionRenderer(exception_formatter=ExceptionDictTransformer(locals_max_string=320)), @@ -76,12 +77,15 @@ def run(parameter: PeerCommunicatorTestProcessParameter, queue: BidirectionalQue com.register_peer(connection_info) peers = com.peers(timeout_in_milliseconds=None) logger.info("peers", peers=len(peers)) - queue.put(peers) finally: - com.close() + logger.info("com stop before") + com.stop() + logger.info("com stop after") + queue.put(peers) except Exception as e: traceback.print_exc() logger.exception("Exception during test", exception=e) + queue.put([]) @pytest.mark.parametrize("number_of_instances, repetitions", [(2, 1000), (10, 100)]) diff --git a/tests/integration_tests/without_db/udf_communication/peer_communication/test_send_recv.py b/tests/integration_tests/without_db/udf_communication/peer_communication/test_send_recv.py index 5db422e4..6897dc42 100644 --- a/tests/integration_tests/without_db/udf_communication/peer_communication/test_send_recv.py +++ b/tests/integration_tests/without_db/udf_communication/peer_communication/test_send_recv.py @@ -18,8 +18,10 @@ from exasol_advanced_analytics_framework.udf_communication.socket_factory.zmq_wrapper import ZMQSocketFactory from exasol_advanced_analytics_framework.udf_communication.peer_communicator.peer_communicator_config import \ 
PeerCommunicatorConfig -from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import ConditionalMethodDropper -from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, BidirectionalQueue, assert_processes_finish, \ +from tests.integration_tests.without_db.udf_communication.peer_communication.conditional_method_dropper import \ + ConditionalMethodDropper +from tests.integration_tests.without_db.udf_communication.peer_communication.utils import TestProcess, \ + BidirectionalQueue, assert_processes_finish, \ PeerCommunicatorTestProcessParameter structlog.configure( diff --git a/tests/unit_tests/udf_communication/peer_communication/test_background_peer_state.py b/tests/unit_tests/udf_communication/peer_communication/test_background_peer_state.py index 5c5750cd..204f02eb 100644 --- a/tests/unit_tests/udf_communication/peer_communication/test_background_peer_state.py +++ b/tests/unit_tests/udf_communication/peer_communication/test_background_peer_state.py @@ -89,7 +89,7 @@ def test_init(): def test_resend_if_necessary(): test_setup = create_test_setup() test_setup.reset_mocks() - test_setup.background_peer_state.resend_if_necessary() + test_setup.background_peer_state.try_send() assert ( test_setup.connection_establisher_mock.mock_calls == [call.try_send()] and test_setup.sender_mock.mock_calls == [] @@ -162,7 +162,7 @@ def test_forward_payload(): def test_close(): test_setup = create_test_setup() test_setup.reset_mocks() - test_setup.background_peer_state.close() + test_setup.background_peer_state.stop() assert ( test_setup.connection_establisher_mock.mock_calls == [] and test_setup.sender_mock.mock_calls == [] diff --git a/tests/unit_tests/udf_framework/__init__.py b/tests/unit_tests/udf_framework/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unit_tests/udf_framework/mock_query_handlers.py b/tests/unit_tests/udf_framework/mock_query_handlers.py deleted file mode 100644 index 086fbe2b..00000000 --- a/tests/unit_tests/udf_framework/mock_query_handlers.py +++ /dev/null @@ -1,169 +0,0 @@ -from typing import Dict, Any, Union - -from exasol_data_science_utils_python.schema.column import \ - Column -from exasol_data_science_utils_python.schema.column_name import \ - ColumnName -from exasol_data_science_utils_python.schema.column_type import \ - ColumnType - -from exasol_advanced_analytics_framework.query_handler.context.scope_query_handler_context import \ - ScopeQueryHandlerContext -from exasol_advanced_analytics_framework.query_handler.query.select_query import SelectQueryWithColumnDefinition, \ - SelectQuery -from exasol_advanced_analytics_framework.query_handler.result \ - import Finish, Continue -from exasol_advanced_analytics_framework.query_result.query_result \ - import QueryResult -from exasol_advanced_analytics_framework.udf_framework.udf_query_handler import UDFQueryHandler -from exasol_advanced_analytics_framework.udf_framework.udf_query_handler_factory import UDFQueryHandlerFactory - -TEST_CONNECTION = "TEST_CONNECTION" - -TEST_INPUT = "<>" -FINAL_RESULT = '<>' -QUERY_LIST = [SelectQuery("SELECT 1 FROM DUAL"), SelectQuery("SELECT 2 FROM DUAL")] - - -class MockQueryHandlerWithOneIteration(UDFQueryHandler): - - def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerContext): - super().__init__(parameter, query_handler_context) - if not isinstance(parameter, str): - raise AssertionError(f"Expected parameter={parameter} to be a string.") - if 
parameter != TEST_INPUT: - raise AssertionError(f"Expected parameter={parameter} to be '{TEST_INPUT}'.") - - def start(self) -> Union[Continue, Finish[str]]: - return Finish(result=FINAL_RESULT) - - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: - pass - - -class MockQueryHandlerWithOneIterationFactory(UDFQueryHandlerFactory): - - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return MockQueryHandlerWithOneIteration(parameter, query_handler_context) - - -class MockQueryHandlerWithTwoIterations(UDFQueryHandler): - def __init__(self, - parameter: str, - query_handler_context: ScopeQueryHandlerContext): - super().__init__(parameter, query_handler_context) - self._parameter = parameter - - def start(self) -> Union[Continue, Finish[str]]: - return_query = "SELECT a, table1.b, c FROM table1, table2 " \ - "WHERE table1.b=table2.b" - return_query_columns = [ - Column(ColumnName("a"), ColumnType("INTEGER")), - Column(ColumnName("b"), ColumnType("INTEGER"))] - query_handler_return_query = SelectQueryWithColumnDefinition( - query_string=return_query, - output_columns=return_query_columns) - query_handler_result = Continue( - query_list=QUERY_LIST, - input_query=query_handler_return_query) - return query_handler_result - - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[Dict[str, Any]]]: - a = query_result.a - if a != 1: - raise AssertionError(f"Expected query_result.a={a} to be 1.") - b = query_result.b - if b != 2: - raise AssertionError(f"Expected query_result.b={b} to be 2.") - has_next = query_result.next() - if has_next: - raise AssertionError(f"No next row expected") - query_handler_result = Finish(result=FINAL_RESULT) - return query_handler_result - - -class MockQueryHandlerWithTwoIterationsFactory(UDFQueryHandlerFactory): - - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return MockQueryHandlerWithTwoIterations(parameter, query_handler_context) - - -class QueryHandlerTestWithOneIterationAndTempTable(UDFQueryHandler): - - def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerContext): - super().__init__(parameter, query_handler_context) - - def start(self) -> Union[Continue, Finish[str]]: - self._query_handler_context.get_temporary_table_name() - return Finish(result=FINAL_RESULT) - - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: - pass - - -class QueryHandlerTestWithOneIterationAndTempTableFactory(UDFQueryHandlerFactory): - - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return QueryHandlerTestWithOneIterationAndTempTable(parameter, query_handler_context) - - -class MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext(UDFQueryHandler): - - def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerContext): - super().__init__(parameter, query_handler_context) - self.child = None - - def start(self) -> Union[Continue, Finish[str]]: - self.child = self._query_handler_context.get_child_query_handler_context() - return Finish(result=FINAL_RESULT) - - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: - pass - - -class MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContextFactory(UDFQueryHandlerFactory): - - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> 
UDFQueryHandler: - return MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContext(parameter, query_handler_context) - - -class MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObject(UDFQueryHandler): - - def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerContext): - super().__init__(parameter, query_handler_context) - self.proxy = None - self.child = None - - def start(self) -> Union[Continue, Finish[str]]: - self.child = self._query_handler_context.get_child_query_handler_context() - self.proxy = self.child.get_temporary_table_name() - return Finish(result=FINAL_RESULT) - - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: - pass - - -class MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObjectFactory(UDFQueryHandlerFactory): - - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObject(parameter, query_handler_context) - - -class MockQueryHandlerUsingConnection(UDFQueryHandler): - - def __init__(self, parameter: str, query_handler_context: ScopeQueryHandlerContext): - super().__init__(parameter, query_handler_context) - - def start(self) -> Union[Continue, Finish[str]]: - connection = self._query_handler_context.get_connection(TEST_CONNECTION) - return Finish( - f"{connection.name},{connection.address},{connection.user},{connection.password}") - - def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[str]]: - pass - - -class MockQueryHandlerUsingConnectionFactory(UDFQueryHandlerFactory): - - def create(self, parameter: str, query_handler_context: ScopeQueryHandlerContext) -> UDFQueryHandler: - return MockQueryHandlerUsingConnection(parameter, query_handler_context) diff --git a/tests/unit_tests/udf_framework/test_json_udf_query_handler.py b/tests/unit_tests/udf_framework/test_json_udf_query_handler.py deleted file mode 100644 index a5ef46e0..00000000 --- a/tests/unit_tests/udf_framework/test_json_udf_query_handler.py +++ /dev/null @@ -1,129 +0,0 @@ -import json -from json import JSONDecodeError -from pathlib import PurePosixPath -from typing import Union - -import pytest -from exasol_bucketfs_utils_python.localfs_mock_bucketfs_location import LocalFSMockBucketFSLocation -from exasol_data_science_utils_python.schema.column import Column -from exasol_data_science_utils_python.schema.column_name import ColumnName -from exasol_data_science_utils_python.schema.column_type import ColumnType - -from exasol_advanced_analytics_framework.query_handler.context.scope_query_handler_context import \ - ScopeQueryHandlerContext -from exasol_advanced_analytics_framework.query_handler.context.top_level_query_handler_context import \ - TopLevelQueryHandlerContext -from exasol_advanced_analytics_framework.query_handler.json_udf_query_handler import JSONQueryHandler, JSONType -from exasol_advanced_analytics_framework.query_handler.result import Continue, Finish -from exasol_advanced_analytics_framework.query_result.mock_query_result import MockQueryResult -from exasol_advanced_analytics_framework.query_result.query_result import QueryResult -from exasol_advanced_analytics_framework.udf_framework.json_udf_query_handler_factory import JsonUDFQueryHandler - - -@pytest.fixture() -def temporary_schema_name(): - return "temp_schema_name" - - -@pytest.fixture() -def top_level_query_handler_context(tmp_path, - temporary_schema_name, - test_connection_lookup): - 
-
-
-class StartReturnParameterTestJSONQueryHandler(JSONQueryHandler):
-
-    def __init__(self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext):
-        super().__init__(parameter, query_handler_context)
-        self._parameter = parameter
-
-    def start(self) -> Union[Continue, Finish[JSONType]]:
-        return Finish[JSONType](self._parameter)
-
-    def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[JSONType]]:
-        raise AssertionError("Should not be called")
-
-
-def test_start_return_parameter(top_level_query_handler_context):
-    parameter = {
-        "test_key": "test_value"
-    }
-    json_str_parameter = json.dumps(parameter)
-    query_handler = JsonUDFQueryHandler(
-        parameter=json_str_parameter,
-        query_handler_context=top_level_query_handler_context,
-        wrapped_json_query_handler_class=StartReturnParameterTestJSONQueryHandler
-    )
-    result = query_handler.start()
-    assert isinstance(result, Finish) and result.result == json_str_parameter
-
-
-class HandleQueryResultCheckQueryResultTestJSONQueryHandler(JSONQueryHandler):
-
-    def __init__(self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext):
-        super().__init__(parameter, query_handler_context)
-        self._parameter = parameter
-
-    def start(self) -> Union[Continue, Finish[JSONType]]:
-        raise AssertionError("Should not be called")
-
-    def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[JSONType]]:
-        a = query_result.a
-        return Finish[JSONType]({"a": a})
-
-
-def test_handle_query_result_check_query_result(top_level_query_handler_context):
-    parameter = {
-        "test_key": "test_value"
-    }
-    json_str_parameter = json.dumps(parameter)
-    query_handler = JsonUDFQueryHandler(
-        parameter=json_str_parameter,
-        query_handler_context=top_level_query_handler_context,
-        wrapped_json_query_handler_class=HandleQueryResultCheckQueryResultTestJSONQueryHandler
-    )
-    result = query_handler.handle_query_result(
-        MockQueryResult(data=[(1,)],
-                        columns=[Column(ColumnName("a"),
-                                        ColumnType("INTEGER"))]))
-    assert isinstance(result, Finish) and result.result == '{"a": 1}'
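The assertions above rely on attribute-style access (query_result.a) over a (data, columns) pair. A toy stand-in showing how such access can be backed, purely to make the tests easier to follow; this is not the real MockQueryResult implementation:

```python
from typing import Any, List, Tuple


class ToyQueryResult:
    """Toy row cursor: attribute access resolves against the column names."""

    def __init__(self, data: List[Tuple], column_names: List[str]):
        self._data = data
        self._column_names = column_names
        self._row = 0

    def __getattr__(self, name: str) -> Any:
        # Called only when normal attribute lookup fails, i.e. for column names.
        return self._data[self._row][self._column_names.index(name)]

    def next(self) -> bool:
        # Advance to the next row; report whether one exists.
        self._row += 1
        return self._row < len(self._data)


result = ToyQueryResult(data=[(1,)], column_names=["a"])
assert result.a == 1
assert result.next() is False  # single-row result: no next row
```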
diff --git a/tests/unit_tests/udf_framework/test_json_udf_query_handler_factory.py b/tests/unit_tests/udf_framework/test_json_udf_query_handler_factory.py
deleted file mode 100644
index 6c5ed94a..00000000
--- a/tests/unit_tests/udf_framework/test_json_udf_query_handler_factory.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import json
-from pathlib import PurePosixPath
-from typing import Union
-
-import pytest
-from exasol_bucketfs_utils_python.localfs_mock_bucketfs_location import LocalFSMockBucketFSLocation
-from exasol_data_science_utils_python.schema.column import Column
-from exasol_data_science_utils_python.schema.column_name import ColumnName
-from exasol_data_science_utils_python.schema.column_type import ColumnType
-
-from exasol_advanced_analytics_framework.query_handler.context.scope_query_handler_context import \
-    ScopeQueryHandlerContext
-from exasol_advanced_analytics_framework.query_handler.context.top_level_query_handler_context import \
-    TopLevelQueryHandlerContext
-from exasol_advanced_analytics_framework.query_handler.json_udf_query_handler import JSONQueryHandler, JSONType
-from exasol_advanced_analytics_framework.query_handler.result import Continue, Finish
-from exasol_advanced_analytics_framework.query_result.mock_query_result import MockQueryResult
-from exasol_advanced_analytics_framework.query_result.query_result import QueryResult
-from exasol_advanced_analytics_framework.udf_framework.json_udf_query_handler_factory import JsonUDFQueryHandlerFactory
-from exasol_advanced_analytics_framework.udf_framework.udf_query_handler import UDFQueryHandler
-
-
-@pytest.fixture()
-def temporary_schema_name():
-    return "temp_schema_name"
-
-
-@pytest.fixture()
-def top_level_query_handler_context(tmp_path,
-                                    temporary_schema_name,
-                                    test_connection_lookup):
-    top_level_query_handler_context = TopLevelQueryHandlerContext(
-        temporary_bucketfs_location=LocalFSMockBucketFSLocation(base_path=PurePosixPath(tmp_path) / "bucketfs"),
-        temporary_db_object_name_prefix="temp_db_object",
-        connection_lookup=test_connection_lookup,
-        temporary_schema_name=temporary_schema_name,
-    )
-    return top_level_query_handler_context
-
-
-class TestJSONQueryHandler(JSONQueryHandler):
-    def __init__(self, parameter: JSONType, query_handler_context: ScopeQueryHandlerContext):
-        super().__init__(parameter, query_handler_context)
-        self._parameter = parameter
-
-    def start(self) -> Union[Continue, Finish[JSONType]]:
-        return Finish[JSONType](self._parameter)
-
-    def handle_query_result(self, query_result: QueryResult) -> Union[Continue, Finish[JSONType]]:
-        return Finish[JSONType](self._parameter)
-
-
-class TestJsonUDFQueryHandlerFactory(JsonUDFQueryHandlerFactory):
-    def __init__(self):
-        super().__init__(TestJSONQueryHandler)
-
-
-def test(top_level_query_handler_context):
-    test_input = {"a": 1}
-    json_str = json.dumps(test_input)
-    query_handler = TestJsonUDFQueryHandlerFactory().create(json_str, top_level_query_handler_context)
-    start_result = query_handler.start()
-    handle_query_result = query_handler.handle_query_result(
-        MockQueryResult(data=[(1,)],
-                        columns=[Column(ColumnName("a"),
-                                        ColumnType("INTEGER"))])
-    )
-    assert isinstance(query_handler, UDFQueryHandler) \
-           and start_result.result == json_str and handle_query_result.result == json_str
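The factory test boils down to a delegation pattern: a subclass binds a concrete JSON handler class once via super().__init__, and create() instantiates it per call. Reduced to its core with simplified stand-in names:

```python
class JsonHandlerFactory:
    """Stand-in for the factory base class: binds a handler class, creates instances."""

    def __init__(self, handler_class):
        self._handler_class = handler_class

    def create(self, parameter: str, context=None):
        return self._handler_class(parameter, context)


class EchoJsonHandler:
    def __init__(self, parameter, context):
        self._parameter = parameter

    def start(self):
        return self._parameter


class EchoJsonHandlerFactory(JsonHandlerFactory):
    def __init__(self):
        # Bind the concrete handler class once, like TestJsonUDFQueryHandlerFactory above.
        super().__init__(EchoJsonHandler)


handler = EchoJsonHandlerFactory().create('{"a": 1}')
assert handler.start() == '{"a": 1}'
```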
diff --git a/tests/unit_tests/udf_framework/test_query_handler_runner_udf_mock.py b/tests/unit_tests/udf_framework/test_query_handler_runner_udf_mock.py
deleted file mode 100644
index aef8b069..00000000
--- a/tests/unit_tests/udf_framework/test_query_handler_runner_udf_mock.py
+++ /dev/null
@@ -1,289 +0,0 @@
-import re
-from tempfile import TemporaryDirectory
-
-from exasol_bucketfs_utils_python.bucketfs_factory import BucketFSFactory
-from exasol_udf_mock_python.column import Column
-from exasol_udf_mock_python.connection import Connection
-from exasol_udf_mock_python.group import Group
-from exasol_udf_mock_python.mock_exa_environment import MockExaEnvironment
-from exasol_udf_mock_python.mock_meta_data import MockMetaData
-from exasol_udf_mock_python.udf_mock_executor import UDFMockExecutor
-
-from exasol_advanced_analytics_framework.udf_framework.query_handler_runner_udf import QueryHandlerStatus
-from tests.unit_tests.udf_framework import mock_query_handlers
-from tests.unit_tests.udf_framework.mock_query_handlers import TEST_CONNECTION
-from tests.utils.test_utils import pytest_regex
-
-TEMPORARY_NAME_PREFIX = "temporary_name_prefix"
-
-BUCKETFS_DIRECTORY = "directory"
-
-BUCKETFS_CONNECTION_NAME = "bucketfs_connection"
-
-
-def _udf_wrapper():
-    from exasol_udf_mock_python.udf_context import UDFContext
-    from exasol_advanced_analytics_framework.udf_framework. \
-        query_handler_runner_udf import QueryHandlerRunnerUDF
-
-    udf = QueryHandlerRunnerUDF(exa)
-
-    def run(ctx: UDFContext):
-        udf.run(ctx)
-
-
-def create_mock_data():
-    meta = MockMetaData(
-        script_code_wrapper_function=_udf_wrapper,
-        input_type="SET",
-        input_columns=[
-            Column("0", int, "INTEGER"),  # iter_num
-            Column("1", str, "VARCHAR(2000000)"),  # temporary_bfs_location_conn
-            Column("2", str, "VARCHAR(2000000)"),  # temporary_bfs_location_directory
-            Column("3", str, "VARCHAR(2000000)"),  # temporary_name_prefix
-            Column("4", str, "VARCHAR(2000000)"),  # temporary_schema_name
-            Column("5", str, "VARCHAR(2000000)"),  # python_class_name
-            Column("6", str, "VARCHAR(2000000)"),  # python_class_module
-            Column("7", str, "VARCHAR(2000000)"),  # parameters
-        ],
-        output_type="EMITS",
-        output_columns=[
-            Column("outputs", str, "VARCHAR(2000000)")
-        ],
-        is_variadic_input=True
-    )
-    return meta
-
-
-def test_query_handler_udf_with_one_iteration():
-    executor = UDFMockExecutor()
-    meta = create_mock_data()
-
-    with TemporaryDirectory() as path:
-        bucketfs_connection = Connection(address=f"file://{path}/query_handler")
-        exa = MockExaEnvironment(
-            metadata=meta,
-            connections={"bucketfs_connection": bucketfs_connection})
-
-        input_data = (
-            0,
-            BUCKETFS_CONNECTION_NAME,
-            BUCKETFS_DIRECTORY,
-            TEMPORARY_NAME_PREFIX,
-            "temp_schema",
-            "MockQueryHandlerWithOneIterationFactory",
-            "tests.unit_tests.udf_framework.mock_query_handlers",
-            mock_query_handlers.TEST_INPUT
-        )
-        result = executor.run([Group([input_data])], exa)
-        rows = [row[0] for row in result[0].rows]
-        expected_rows = [None, None, QueryHandlerStatus.FINISHED.name, mock_query_handlers.FINAL_RESULT]
-        assert rows == expected_rows
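Reading the emitted rows: across the tests in this file, the runner UDF appears to emit, in order, the DDL for the next input view (or None), the wrapped return query (or None), the status name, the result payload (final result or serialized state), and then any further statements such as the handler's query_list or cleanup queries. This is an inference from the expected_rows values in this file, not a documented contract; a small helper illustrating that reading:

```python
# Hypothetical helper, only to name the row positions asserted in these tests.
def describe_output_rows(rows):
    input_view_ddl, return_query, status, result, *further_statements = rows
    return {
        "input_view_ddl": input_view_ddl,        # None when the handler finished
        "return_query": return_query,            # None when there is no next iteration
        "status": status,                        # e.g. "FINISHED", "CONTINUE", "ERROR"
        "result": result,                        # final result or serialized state
        "further_statements": further_statements,  # query_list and/or cleanup queries
    }


parsed = describe_output_rows([None, None, "FINISHED", "Final result"])
assert parsed["status"] == "FINISHED" and parsed["further_statements"] == []
```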
-
-
-def test_query_handler_udf_with_one_iteration_with_not_released_child_query_handler_context():
-    executor = UDFMockExecutor()
-    meta = create_mock_data()
-
-    with TemporaryDirectory() as path:
-        bucketfs_connection = Connection(address=f"file://{path}/query_handler")
-        exa = MockExaEnvironment(
-            metadata=meta,
-            connections={"bucketfs_connection": bucketfs_connection})
-
-        input_data = (
-            0,
-            BUCKETFS_CONNECTION_NAME,
-            BUCKETFS_DIRECTORY,
-            TEMPORARY_NAME_PREFIX,
-            "temp_schema",
-            "MockQueryHandlerWithOneIterationWithNotReleasedChildQueryHandlerContextFactory",
-            "tests.unit_tests.udf_framework.mock_query_handlers",
-            "{}"
-        )
-        result = executor.run([Group([input_data])], exa)
-        rows = [row[0] for row in result[0].rows]
-        expected_rows = [None,
-                         None,
-                         QueryHandlerStatus.ERROR.name,
-                         pytest_regex(r".*The following child contexts were not released:*", re.DOTALL)]
-        assert rows == expected_rows
-
-
-def test_query_handler_udf_with_one_iteration_with_not_released_temporary_object():
-    executor = UDFMockExecutor()
-    meta = create_mock_data()
-
-    with TemporaryDirectory() as path:
-        bucketfs_connection = Connection(address=f"file://{path}/query_handler")
-        exa = MockExaEnvironment(
-            metadata=meta,
-            connections={"bucketfs_connection": bucketfs_connection})
-
-        input_data = (
-            0,
-            BUCKETFS_CONNECTION_NAME,
-            BUCKETFS_DIRECTORY,
-            TEMPORARY_NAME_PREFIX,
-            "temp_schema",
-            "MockQueryHandlerWithOneIterationWithNotReleasedTemporaryObjectFactory",
-            "tests.unit_tests.udf_framework.mock_query_handlers",
-            "{}"
-        )
-        result = executor.run([Group([input_data])], exa)
-        rows = [row[0] for row in result[0].rows]
-        expected_rows = [None,
-                         None,
-                         QueryHandlerStatus.ERROR.name,
-                         pytest_regex(r".*The following child contexts were not released.*", re.DOTALL),
-                         'DROP TABLE IF EXISTS "temp_schema"."temporary_name_prefix_2_1";']
-        assert rows == expected_rows
-
-
-def test_query_handler_udf_with_one_iteration_and_temp_table():
-    executor = UDFMockExecutor()
-    meta = create_mock_data()
-
-    with TemporaryDirectory() as path:
-        bucketfs_connection = Connection(address=f"file://{path}/query_handler")
-        exa = MockExaEnvironment(
-            metadata=meta,
-            connections={"bucketfs_connection": bucketfs_connection})
-
-        input_data = (
-            0,
-            BUCKETFS_CONNECTION_NAME,
-            BUCKETFS_DIRECTORY,
-            TEMPORARY_NAME_PREFIX,
-            "temp_schema",
-            "QueryHandlerTestWithOneIterationAndTempTableFactory",
-            "tests.unit_tests.udf_framework.mock_query_handlers",
-            "{}"
-        )
-        result = executor.run([Group([input_data])], exa)
-        rows = [row[0] for row in result[0].rows]
-        table_cleanup_query = 'DROP TABLE IF EXISTS "temp_schema"."temporary_name_prefix_1";'
-        expected_rows = [None, None, QueryHandlerStatus.FINISHED.name, mock_query_handlers.FINAL_RESULT,
-                         table_cleanup_query]
-        assert rows == expected_rows
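tests.utils.test_utils.pytest_regex is not part of this diff; the ERROR-path tests above use it as an "equal if the regex matches" placeholder inside expected_rows. A plausible minimal stand-in, under the assumption that the real helper works the same way:

```python
import re


class pytest_regex:
    """Equality-by-regex matcher; a guessed stand-in, not the real helper."""

    def __init__(self, pattern: str, flags: int = 0):
        self._regex = re.compile(pattern, flags)

    def __eq__(self, actual) -> bool:
        # List equality falls back to this reflected __eq__ when a plain
        # string is compared against a pytest_regex instance.
        return isinstance(actual, str) and self._regex.match(actual) is not None

    def __repr__(self) -> str:
        return self._regex.pattern


error_text = "Error: The following child contexts were not released:\n  context_1"
assert pytest_regex(r".*not released.*", re.DOTALL) == error_text
```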
-
-
-def test_query_handler_udf_with_two_iteration(tmp_path):
-    executor = UDFMockExecutor()
-    meta = create_mock_data()
-
-    bucketfs_connection = Connection(address=f"file://{tmp_path}/query_handler")
-    exa = MockExaEnvironment(
-        metadata=meta,
-        connections={BUCKETFS_CONNECTION_NAME: bucketfs_connection})
-
-    input_data = (
-        0,
-        BUCKETFS_CONNECTION_NAME,
-        BUCKETFS_DIRECTORY,
-        TEMPORARY_NAME_PREFIX,
-        "temp_schema",
-        "MockQueryHandlerWithTwoIterationsFactory",
-        "tests.unit_tests.udf_framework.mock_query_handlers",
-        "{}"
-    )
-    result = executor.run([Group([input_data])], exa)
-    rows = [row[0] for row in result[0].rows]
-    expected_return_query_view = 'CREATE VIEW "temp_schema"."temporary_name_prefix_2_1" AS ' \
-                                 'SELECT a, table1.b, c ' \
-                                 'FROM table1, table2 ' \
-                                 'WHERE table1.b=table2.b;'
-    return_query = 'SELECT "TEST_SCHEMA"."AAF_QUERY_HANDLER_UDF"(' \
-                   '1,' \
-                   "'bucketfs_connection','directory','temporary_name_prefix'," \
-                   '"a","b") ' \
-                   'FROM "temp_schema"."temporary_name_prefix_2_1";'
-    expected_rows = [expected_return_query_view, return_query, QueryHandlerStatus.CONTINUE.name, "{}"] + \
-                    [query.query_string for query in mock_query_handlers.QUERY_LIST]
-    assert rows == expected_rows
-
-    prev_state_exist = _is_state_exist(0, bucketfs_connection)
-    current_state_exist = _is_state_exist(1, bucketfs_connection)
-    assert prev_state_exist == False and current_state_exist == True
-
-    exa = MockExaEnvironment(
-        metadata=MockMetaData(
-            script_code_wrapper_function=_udf_wrapper,
-            input_type="SET",
-            input_columns=[
-                Column("0", int, "INTEGER"),  # iter_num
-                Column("1", str, "VARCHAR(2000000)"),  # temporary_bfs_location_conn
-                Column("2", str, "VARCHAR(2000000)"),  # temporary_bfs_location_directory
-                Column("3", str, "VARCHAR(2000000)"),  # temporary_name_prefix
-                Column("4", int, "INTEGER"),  # column a of the input query
-                Column("5", int, "INTEGER"),  # column b of the input query
-            ], output_type="EMITS",
-            output_columns=[
-                Column("outputs", str, "VARCHAR(2000000)")
-            ],
-            is_variadic_input=True),
-        connections={BUCKETFS_CONNECTION_NAME: bucketfs_connection})
-
-    input_data = (
-        1,
-        BUCKETFS_CONNECTION_NAME,
-        BUCKETFS_DIRECTORY,
-        TEMPORARY_NAME_PREFIX,
-        1,
-        2
-    )
-    result = executor.run([Group([input_data])], exa)
-    rows = [row[0] for row in result[0].rows]
-    cleanup_return_query_view = 'DROP VIEW IF EXISTS "temp_schema"."temporary_name_prefix_2_1";'
-    expected_rows = [None, None, QueryHandlerStatus.FINISHED.name, mock_query_handlers.FINAL_RESULT,
-                     cleanup_return_query_view]
-    assert rows == expected_rows
-
-
-def test_query_handler_udf_using_connection():
-    executor = UDFMockExecutor()
-    meta = create_mock_data()
-
-    with TemporaryDirectory() as path:
-        bucketfs_connection = Connection(address=f"file://{path}/query_handler")
-        test_connection = Connection(address=f"test_connection",
-                                     user="test_connection_user",
-                                     password="test_connection_pwd")
-
-        exa = MockExaEnvironment(
-            metadata=meta,
-            connections={
-                "bucketfs_connection": bucketfs_connection,
-                TEST_CONNECTION: test_connection}
-        )
-
-        input_data = (
-            0,
-            BUCKETFS_CONNECTION_NAME,
-            BUCKETFS_DIRECTORY,
-            TEMPORARY_NAME_PREFIX,
-            "temp_schema",
-            "MockQueryHandlerUsingConnectionFactory",
-            "tests.unit_tests.udf_framework.mock_query_handlers",
-            "{}"
-        )
-        result = executor.run([Group([input_data])], exa)
-        rows = [row[0] for row in result[0].rows]
-        expected_rows = [
-            None, None, QueryHandlerStatus.FINISHED.name,
-            f"{TEST_CONNECTION},{test_connection.address},{test_connection.user},{test_connection.password}"
-        ]
-        assert rows == expected_rows
-
-
-def _is_state_exist(
-        iter_num: int,
-        model_connection: Connection) -> bool:
-    bucketfs_location = BucketFSFactory().create_bucketfs_location(
-        url=model_connection.address,
-        user=model_connection.user,
-        pwd=model_connection.password)
-    bucketfs_path = f"{BUCKETFS_DIRECTORY}/{TEMPORARY_NAME_PREFIX}/state/"
-    state_file = f"{str(iter_num)}.pkl"
-    files = bucketfs_location.list_files_in_bucketfs(bucketfs_path)
-    return state_file in files
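_is_state_exist above checks BucketFS for the pickled state of a given iteration; the two-iteration test asserts that iteration 0's state is gone once iteration 1's state has been written, which suggests the runner persists state for the next iteration and removes the previous one. A local-filesystem analogue of that check, with paths mirroring the test constants and plain pathlib in place of the BucketFS access (the function name is hypothetical):

```python
import pickle
from pathlib import Path
from tempfile import TemporaryDirectory


def is_state_present(base_path: Path, iter_num: int) -> bool:
    # Mirrors the "<directory>/<prefix>/state/<iter_num>.pkl" layout used above.
    return (base_path / "directory/temporary_name_prefix/state" / f"{iter_num}.pkl").exists()


with TemporaryDirectory() as tmp:
    state_dir = Path(tmp) / "directory/temporary_name_prefix/state"
    state_dir.mkdir(parents=True)
    # Simulate iteration 0 handing over: write state 1, leave no state 0.
    (state_dir / "1.pkl").write_bytes(pickle.dumps({"iteration": 1}))
    assert not is_state_present(Path(tmp), 0)  # previous state was cleaned up
    assert is_state_present(Path(tmp), 1)      # next iteration's state exists
```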