Skip to content

Commit

Permalink
refactor: Extract the common code from client_interface.py to the bas…
Browse files Browse the repository at this point in the history
…e_client_interface.py.

Signed-off-by: Hongli Chen <[email protected]>
  • Loading branch information
Honglichenn committed Dec 1, 2023
1 parent 2f4946b commit 2766d8b
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 71 deletions.
5 changes: 3 additions & 2 deletions src/openjd/adaptor_runtime_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from .action import Action
from .client_interface import (
from .posix_client_interface import (
HTTPClientInterface,
PathMappingRule,
)

from .base_client_interface import PathMappingRule

__all__ = [
"Action",
"HTTPClientInterface",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.


from __future__ import annotations

import json as _json
import signal as _signal
import sys as _sys
from abc import ABC as _ABC
from abc import abstractmethod as _abstractmethod
from abc import ABC as _ABC

from dataclasses import dataclass as _dataclass
from functools import lru_cache as _lru_cache
from http import HTTPStatus as _HTTPStatus
Expand All @@ -18,14 +19,7 @@
List as _List,
Tuple as _Tuple,
)
from urllib.parse import urlencode as _urlencode

from .action import Action as _Action
from .connection import UnixHTTPConnection as _UnixHTTPConnection

# Set timeout to None so our requests are blocking calls with no timeout.
# See socket.settimeout
_REQUEST_TIMEOUT = None


# Based on adaptor runtime's PathMappingRule class
Expand All @@ -39,25 +33,71 @@ class PathMappingRule:
destination_os: str


class HTTPClientInterface(_ABC):
actions: _Dict[str, _Callable[..., None]]
socket_path: str
@_dataclass
class Response:
"""
A response wrapper class
"""

def __init__(self, socket_path: str) -> None:
"""When the client is created, we need the port number to connect to the server.
status: int
body: str
reason: str
length: int

Args:
socket_path (str): The path to the UNIX domain socket to use.

class BaseClientInterface(_ABC):
actions: _Dict[str, _Callable[..., None]]

def __init__(self) -> None:
"""
When the client is created, we need the port number to connect to the server.
"""
self.socket_path = socket_path
self.actions = {
"close": self.close,
}

# NOTE: The signals SIGKILL and SIGSTOP cannot be caught, blocked, or ignored.
# Reference: https://man7.org/linux/man-pages/man7/signal.7.html
# SIGTERM graceful shutdown.
_signal.signal(_signal.SIGTERM, self.graceful_shutdown)
@_abstractmethod
def _send_request(
self, method: str, request_path: str, *, query_string_params: _Dict | None = None
) -> Response:
"""
Send a request to the server and return the response.
This abstract method should be implemented by subclasses to handle
sending the actual request.
Args:
method (str): The HTTP method, e.g. 'GET', 'POST'.
request_path (str): The path for the request.
query_string_params (_Dict | None, optional): Query string parameters to include
in the request. Defaults to None.
Returns:
Response: The response from the server.
"""
pass

@_abstractmethod
def close(self, args: _Dict[str, _Any] | None) -> None:
"""This is the close function which will be called to cleanup the Application.
Args:
args (_Dict[str, _Any] | None): The arguments (if any) required to perform the
cleanup.
"""
pass

@_abstractmethod
def graceful_shutdown(self, signum: int, frame: _FrameType | None) -> None:
"""This is the function when we cancel. This function is called when a SIGTERM signal is
received. This functions will need to be implemented for each application we want to
support because the clean up will be different for each application.
Args:
signum (int): The signal number.
frame (_FrameType | None): The current stack frame (None or a frame object).
"""
pass

def _request_next_action(self) -> _Tuple[int, str, _Action | None]:
"""Sending a get request to the server on the /action endpoint.
Expand All @@ -67,17 +107,12 @@ def _request_next_action(self) -> _Tuple[int, str, _Action | None]:
_Tuple[int, str, _Action | None]: Returns the status code (int), the status reason
(str), the action if one was received (_Action | None).
"""
headers = {
"Content-type": "application/json",
}
connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT)
connection.request("GET", "/action", headers=headers)
response = connection.getresponse()
connection.close()
response = self._send_request("GET", "/action")

action = None
if response.length:
action = _Action.from_bytes(response.read())
response_body = _json.loads(response.body)
action = _Action(response_body["name"], response_body["args"])
return response.status, response.reason, action

@_lru_cache(maxsize=None)
Expand All @@ -91,22 +126,17 @@ def map_path(self, path: str) -> str:
Raises:
RuntimeError: When the client fails to get a mapped path from the server.
"""
headers = {
"Content-type": "application/json",
}
connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT)
print(f"Requesting Path Mapping for path '{path}'.", flush=True)
connection.request("GET", "/path_mapping?" + _urlencode({"path": path}), headers=headers)
response = connection.getresponse()
connection.close()

response = self._send_request("GET", "/path_mapping", query_string_params={"path": path})

if response.status == _HTTPStatus.OK and response.length:
response_dict = _json.loads(response.read().decode())
response_dict = _json.loads(response.body)
mapped_path = response_dict.get("path")
if mapped_path is not None: # pragma: no branch # HTTP 200 guarantees a mapped path
print(f"Mapped path '{path}' to '{mapped_path}'.", flush=True)
return mapped_path
reason = response.read().decode() if response.length else ""
reason = response.body if response.length else ""
raise RuntimeError(
f"ERROR: Failed to get a mapped path for path '{path}'. "
f"Server response: Status: {int(response.status)}, Response: '{reason}'",
Expand All @@ -123,24 +153,18 @@ def path_mapping_rules(self) -> _List[PathMappingRule]:
Raises:
RuntimeError: When the client fails to get a mapped path from the server.
"""
headers = {
"Content-type": "application/json",
}
connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT)
print("Requesting Path Mapping Rules.", flush=True)
connection.request("GET", "/path_mapping_rules", headers=headers)
response = connection.getresponse()
connection.close()
response = self._send_request("GET", "/path_mapping_rules")

if response.status != _HTTPStatus.OK or not response.length:
reason = response.read().decode() if response.length else ""
reason = response.body if response.length else ""
raise RuntimeError(
f"ERROR: Failed to get a path mapping rules. "
f"Server response: Status: {int(response.status)}, Response: '{reason}'",
)

try:
response_dict = _json.loads(response.read().decode())
response_dict = _json.loads(response.body)
except _json.JSONDecodeError as e:
raise RuntimeError(
f"Expected JSON string from /path_mapping_rules endpoint, but got error: {e}",
Expand Down Expand Up @@ -200,25 +224,3 @@ def _perform_action(self, a: _Action) -> None:
)
else:
action_func(a.args)

@_abstractmethod
def close(self, args: _Dict[str, _Any] | None) -> None: # pragma: no cover
"""This is the close function which will be called to cleanup the Application.
Args:
args (_Dict[str, _Any] | None): The arguments (if any) required to perform the
cleanup.
"""
pass

@_abstractmethod
def graceful_shutdown(self, signum: int, frame: _FrameType | None) -> None: # pragma: no cover
"""This is the function when we cancel. This function is called when a SIGTERM signal is
received. This functions will need to be implemented for each application we want to
support because the clean up will be different for each application.
Args:
signum (int): The signal number.
frame (_FrameType | None): The current stack frame (None or a frame object).
"""
pass
61 changes: 61 additions & 0 deletions src/openjd/adaptor_runtime_client/posix_client_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from __future__ import annotations

import signal as _signal
from .base_client_interface import Response as _Response
from typing import Dict as _Dict

from .base_client_interface import BaseClientInterface
from .connection import UnixHTTPConnection as _UnixHTTPConnection
from urllib.parse import urlencode as _urlencode


# Set timeout to None so our requests are blocking calls with no timeout.
# See socket.settimeout
_REQUEST_TIMEOUT = None


class HTTPClientInterface(BaseClientInterface):
socket_path: str

def __init__(self, socket_path: str) -> None:
"""When the client is created, we need the port number to connect to the server.
Args:
socket_path (str): The path to the UNIX domain socket to use.
"""
super().__init__()
self.socket_path = socket_path
# NOTE: The signals SIGKILL and SIGSTOP cannot be caught, blocked, or ignored.
# Reference: https://man7.org/linux/man-pages/man7/signal.7.html
# SIGTERM graceful shutdown.
_signal.signal(_signal.SIGTERM, self.graceful_shutdown)

def _send_request(
self, method: str, request_path: str, *, query_string_params: _Dict | None = None
) -> _Response:
"""
Send a request to the server and return the response.
Args:
method (str): The HTTP method, e.g. 'GET', 'POST'.
request_path (str): The path for the request.
query_string_params (_Dict | None, optional): Query string parameters to include in the request.
Defaults to None. In Linux, the query string parameters will be added to the URL
Returns:
Response: The response from the server.
"""
headers = {
"Content-type": "application/json",
}
connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT)
if query_string_params:
request_path += "?" + _urlencode(query_string_params)
connection.request(method, request_path, headers=headers)
response = connection.getresponse()
connection.close()
length = response.length if response.length else 0
body = response.read().decode() if length else ""
return _Response(response.status, body, response.reason, length)

0 comments on commit 2766d8b

Please sign in to comment.