diff --git a/src/openjd/adaptor_runtime_client/__init__.py b/src/openjd/adaptor_runtime_client/__init__.py index a85384b..1253490 100644 --- a/src/openjd/adaptor_runtime_client/__init__.py +++ b/src/openjd/adaptor_runtime_client/__init__.py @@ -1,11 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. from .action import Action -from .client_interface import ( +from .posix_client_interface import ( HTTPClientInterface, - PathMappingRule, ) +from .base_client_interface import PathMappingRule + __all__ = [ "Action", "HTTPClientInterface", diff --git a/src/openjd/adaptor_runtime_client/client_interface.py b/src/openjd/adaptor_runtime_client/base_client_interface.py similarity index 72% rename from src/openjd/adaptor_runtime_client/client_interface.py rename to src/openjd/adaptor_runtime_client/base_client_interface.py index 1c28ecf..8d32e6f 100644 --- a/src/openjd/adaptor_runtime_client/client_interface.py +++ b/src/openjd/adaptor_runtime_client/base_client_interface.py @@ -1,12 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + from __future__ import annotations import json as _json -import signal as _signal import sys as _sys -from abc import ABC as _ABC from abc import abstractmethod as _abstractmethod +from abc import ABC as _ABC + from dataclasses import dataclass as _dataclass from functools import lru_cache as _lru_cache from http import HTTPStatus as _HTTPStatus @@ -18,14 +19,7 @@ List as _List, Tuple as _Tuple, ) -from urllib.parse import urlencode as _urlencode - from .action import Action as _Action -from .connection import UnixHTTPConnection as _UnixHTTPConnection - -# Set timeout to None so our requests are blocking calls with no timeout. -# See socket.settimeout -_REQUEST_TIMEOUT = None # Based on adaptor runtime's PathMappingRule class @@ -39,25 +33,71 @@ class PathMappingRule: destination_os: str -class HTTPClientInterface(_ABC): - actions: _Dict[str, _Callable[..., None]] - socket_path: str +@_dataclass +class Response: + """ + A response wrapper class + """ - def __init__(self, socket_path: str) -> None: - """When the client is created, we need the port number to connect to the server. + status: int + body: str + reason: str + length: int - Args: - socket_path (str): The path to the UNIX domain socket to use. + +class BaseClientInterface(_ABC): + actions: _Dict[str, _Callable[..., None]] + + def __init__(self) -> None: + """ + When the client is created, we need the port number to connect to the server. """ - self.socket_path = socket_path self.actions = { "close": self.close, } - # NOTE: The signals SIGKILL and SIGSTOP cannot be caught, blocked, or ignored. - # Reference: https://man7.org/linux/man-pages/man7/signal.7.html - # SIGTERM graceful shutdown. - _signal.signal(_signal.SIGTERM, self.graceful_shutdown) + @_abstractmethod + def _send_request( + self, method: str, request_path: str, *, query_string_params: _Dict | None = None + ) -> Response: + """ + Send a request to the server and return the response. + + This abstract method should be implemented by subclasses to handle + sending the actual request. + + Args: + method (str): The HTTP method, e.g. 'GET', 'POST'. + request_path (str): The path for the request. + query_string_params (_Dict | None, optional): Query string parameters to include + in the request. Defaults to None. + + Returns: + Response: The response from the server. + """ + pass + + @_abstractmethod + def close(self, args: _Dict[str, _Any] | None) -> None: + """This is the close function which will be called to cleanup the Application. + + Args: + args (_Dict[str, _Any] | None): The arguments (if any) required to perform the + cleanup. + """ + pass + + @_abstractmethod + def graceful_shutdown(self, signum: int, frame: _FrameType | None) -> None: + """This is the function when we cancel. This function is called when a SIGTERM signal is + received. This functions will need to be implemented for each application we want to + support because the clean up will be different for each application. + + Args: + signum (int): The signal number. + frame (_FrameType | None): The current stack frame (None or a frame object). + """ + pass def _request_next_action(self) -> _Tuple[int, str, _Action | None]: """Sending a get request to the server on the /action endpoint. @@ -67,17 +107,12 @@ def _request_next_action(self) -> _Tuple[int, str, _Action | None]: _Tuple[int, str, _Action | None]: Returns the status code (int), the status reason (str), the action if one was received (_Action | None). """ - headers = { - "Content-type": "application/json", - } - connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT) - connection.request("GET", "/action", headers=headers) - response = connection.getresponse() - connection.close() + response = self._send_request("GET", "/action") action = None if response.length: - action = _Action.from_bytes(response.read()) + response_body = _json.loads(response.body) + action = _Action(response_body["name"], response_body["args"]) return response.status, response.reason, action @_lru_cache(maxsize=None) @@ -91,22 +126,17 @@ def map_path(self, path: str) -> str: Raises: RuntimeError: When the client fails to get a mapped path from the server. """ - headers = { - "Content-type": "application/json", - } - connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT) print(f"Requesting Path Mapping for path '{path}'.", flush=True) - connection.request("GET", "/path_mapping?" + _urlencode({"path": path}), headers=headers) - response = connection.getresponse() - connection.close() + + response = self._send_request("GET", "/path_mapping", query_string_params={"path": path}) if response.status == _HTTPStatus.OK and response.length: - response_dict = _json.loads(response.read().decode()) + response_dict = _json.loads(response.body) mapped_path = response_dict.get("path") if mapped_path is not None: # pragma: no branch # HTTP 200 guarantees a mapped path print(f"Mapped path '{path}' to '{mapped_path}'.", flush=True) return mapped_path - reason = response.read().decode() if response.length else "" + reason = response.body if response.length else "" raise RuntimeError( f"ERROR: Failed to get a mapped path for path '{path}'. " f"Server response: Status: {int(response.status)}, Response: '{reason}'", @@ -123,24 +153,18 @@ def path_mapping_rules(self) -> _List[PathMappingRule]: Raises: RuntimeError: When the client fails to get a mapped path from the server. """ - headers = { - "Content-type": "application/json", - } - connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT) print("Requesting Path Mapping Rules.", flush=True) - connection.request("GET", "/path_mapping_rules", headers=headers) - response = connection.getresponse() - connection.close() + response = self._send_request("GET", "/path_mapping_rules") if response.status != _HTTPStatus.OK or not response.length: - reason = response.read().decode() if response.length else "" + reason = response.body if response.length else "" raise RuntimeError( f"ERROR: Failed to get a path mapping rules. " f"Server response: Status: {int(response.status)}, Response: '{reason}'", ) try: - response_dict = _json.loads(response.read().decode()) + response_dict = _json.loads(response.body) except _json.JSONDecodeError as e: raise RuntimeError( f"Expected JSON string from /path_mapping_rules endpoint, but got error: {e}", @@ -200,25 +224,3 @@ def _perform_action(self, a: _Action) -> None: ) else: action_func(a.args) - - @_abstractmethod - def close(self, args: _Dict[str, _Any] | None) -> None: # pragma: no cover - """This is the close function which will be called to cleanup the Application. - - Args: - args (_Dict[str, _Any] | None): The arguments (if any) required to perform the - cleanup. - """ - pass - - @_abstractmethod - def graceful_shutdown(self, signum: int, frame: _FrameType | None) -> None: # pragma: no cover - """This is the function when we cancel. This function is called when a SIGTERM signal is - received. This functions will need to be implemented for each application we want to - support because the clean up will be different for each application. - - Args: - signum (int): The signal number. - frame (_FrameType | None): The current stack frame (None or a frame object). - """ - pass diff --git a/src/openjd/adaptor_runtime_client/posix_client_interface.py b/src/openjd/adaptor_runtime_client/posix_client_interface.py new file mode 100644 index 0000000..0cac2d2 --- /dev/null +++ b/src/openjd/adaptor_runtime_client/posix_client_interface.py @@ -0,0 +1,61 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import signal as _signal +from .base_client_interface import Response as _Response +from typing import Dict as _Dict + +from .base_client_interface import BaseClientInterface +from .connection import UnixHTTPConnection as _UnixHTTPConnection +from urllib.parse import urlencode as _urlencode + + +# Set timeout to None so our requests are blocking calls with no timeout. +# See socket.settimeout +_REQUEST_TIMEOUT = None + + +class HTTPClientInterface(BaseClientInterface): + socket_path: str + + def __init__(self, socket_path: str) -> None: + """When the client is created, we need the port number to connect to the server. + + Args: + socket_path (str): The path to the UNIX domain socket to use. + """ + super().__init__() + self.socket_path = socket_path + # NOTE: The signals SIGKILL and SIGSTOP cannot be caught, blocked, or ignored. + # Reference: https://man7.org/linux/man-pages/man7/signal.7.html + # SIGTERM graceful shutdown. + _signal.signal(_signal.SIGTERM, self.graceful_shutdown) + + def _send_request( + self, method: str, request_path: str, *, query_string_params: _Dict | None = None + ) -> _Response: + """ + Send a request to the server and return the response. + + Args: + method (str): The HTTP method, e.g. 'GET', 'POST'. + request_path (str): The path for the request. + query_string_params (_Dict | None, optional): Query string parameters to include in the request. + Defaults to None. In Linux, the query string parameters will be added to the URL + + Returns: + Response: The response from the server. + """ + headers = { + "Content-type": "application/json", + } + connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT) + if query_string_params: + request_path += "?" + _urlencode(query_string_params) + connection.request(method, request_path, headers=headers) + response = connection.getresponse() + connection.close() + length = response.length if response.length else 0 + body = response.read().decode() if length else "" + return _Response(response.status, body, response.reason, length)