Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

strict type client modules #16223

Merged
merged 4 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/prefect/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class PrefectResponse(httpx.Response):
Provides more informative error messages.
"""

def raise_for_status(self) -> None:
def raise_for_status(self) -> Response:
"""
Raise an exception if the response contains an HTTPStatusError.

Expand All @@ -174,7 +174,7 @@ def raise_for_status(self) -> None:
raise PrefectHTTPStatusError.from_httpx_error(exc) from exc.__cause__

@classmethod
def from_httpx_response(cls: Type[Self], response: httpx.Response) -> Self:
def from_httpx_response(cls: Type[Self], response: httpx.Response) -> Response:
"""
Create a `PrefectReponse` from an `httpx.Response`.

Expand All @@ -200,10 +200,10 @@ class PrefectHttpxAsyncClient(httpx.AsyncClient):

def __init__(
self,
*args,
*args: Any,
enable_csrf_support: bool = False,
raise_on_all_errors: bool = True,
**kwargs,
**kwargs: Any,
):
self.enable_csrf_support: bool = enable_csrf_support
self.csrf_token: Optional[str] = None
Expand All @@ -222,10 +222,10 @@ async def _send_with_retry(
self,
request: Request,
send: Callable[[Request], Awaitable[Response]],
send_args: Tuple,
send_kwargs: Dict,
send_args: Tuple[Any, ...],
send_kwargs: Dict[str, Any],
retry_codes: Set[int] = set(),
retry_exceptions: Tuple[Exception, ...] = tuple(),
retry_exceptions: Tuple[Type[Exception], ...] = tuple(),
):
"""
Send a request and retry it if it fails.
Expand Down Expand Up @@ -297,7 +297,7 @@ async def _send_with_retry(
if exc_info
else (
"Received response with retryable status code"
f" {response.status_code}. "
f" {response.status_code if response else 'unknown'}. "
)
)
+ f"Another attempt will be made in {retry_seconds}s. "
Expand All @@ -314,7 +314,7 @@ async def _send_with_retry(
# We ran out of retries, return the failed response
return response

async def send(self, request: Request, *args, **kwargs) -> Response:
async def send(self, request: Request, *args: Any, **kwargs: Any) -> Response:
"""
Send a request with automatic retry behavior for the following status codes:

Expand Down Expand Up @@ -414,10 +414,10 @@ class PrefectHttpxSyncClient(httpx.Client):

def __init__(
self,
*args,
*args: Any,
enable_csrf_support: bool = False,
raise_on_all_errors: bool = True,
**kwargs,
**kwargs: Any,
):
self.enable_csrf_support: bool = enable_csrf_support
self.csrf_token: Optional[str] = None
Expand All @@ -436,10 +436,10 @@ def _send_with_retry(
self,
request: Request,
send: Callable[[Request], Response],
send_args: Tuple,
send_kwargs: Dict,
send_args: Tuple[Any, ...],
send_kwargs: Dict[str, Any],
retry_codes: Set[int] = set(),
retry_exceptions: Tuple[Exception, ...] = tuple(),
retry_exceptions: Tuple[Type[Exception], ...] = tuple(),
):
"""
Send a request and retry it if it fails.
Expand Down Expand Up @@ -511,7 +511,7 @@ def _send_with_retry(
if exc_info
else (
"Received response with retryable status code"
f" {response.status_code}. "
f" {response.status_code if response else 'unknown'}. "
)
)
+ f"Another attempt will be made in {retry_seconds}s. "
Expand All @@ -528,7 +528,7 @@ def _send_with_retry(
# We ran out of retries, return the failed response
return response

def send(self, request: Request, *args, **kwargs) -> Response:
def send(self, request: Request, *args: Any, **kwargs: Any) -> Response:
"""
Send a request with automatic retry behavior for the following status codes:

Expand Down
11 changes: 7 additions & 4 deletions src/prefect/client/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
def get_cloud_client(
host: Optional[str] = None,
api_key: Optional[str] = None,
httpx_settings: Optional[dict] = None,
httpx_settings: Optional[Dict[str, Any]] = None,
infer_cloud_url: bool = False,
) -> "CloudClient":
"""
Expand All @@ -45,6 +45,9 @@ def get_cloud_client(
configured_url = prefect.settings.PREFECT_API_URL.value()
host = re.sub(PARSE_API_URL_REGEX, "", configured_url)

if host is None:
raise ValueError("Host was not provided and could not be inferred")

return CloudClient(
host=host,
api_key=api_key or PREFECT_API_KEY.value(),
Expand Down Expand Up @@ -176,7 +179,7 @@ async def __aenter__(self):
await self._client.__aenter__()
return self

async def __aexit__(self, *exc_info):
async def __aexit__(self, *exc_info: Any) -> None:
return await self._client.__aexit__(*exc_info)

def __enter__(self):
Expand All @@ -188,10 +191,10 @@ def __enter__(self):
def __exit__(self, *_):
assert False, "This should never be called but must be defined for __enter__"

async def get(self, route, **kwargs):
async def get(self, route: str, **kwargs: Any) -> Any:
return await self.request("GET", route, **kwargs)

async def request(self, method, route, **kwargs):
async def request(self, method: str, route: str, **kwargs: Any) -> Any:
try:
res = await self._client.request(method, route, **kwargs)
res.raise_for_status()
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/client/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ async def read_worker_metadata(self) -> Dict[str, Any]:
async def __aenter__(self) -> "CollectionsMetadataClient":
...

async def __aexit__(self, *exc_info) -> Any:
async def __aexit__(self, *exc_info: Any) -> Any:
...


def get_collections_metadata_client(
httpx_settings: Optional[Dict] = None,
httpx_settings: Optional[Dict[str, Any]] = None,
) -> "CollectionsMetadataClient":
"""
Creates a client that can be used to fetch metadata for
Expand Down
26 changes: 19 additions & 7 deletions src/prefect/client/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,33 @@ def __init__(
):
self.model = model
self.client_id = client_id
base_url = base_url.replace("http", "ws", 1)
base_url = base_url.replace("http", "ws", 1) if base_url else None
self.subscription_url = f"{base_url}{path}"

self.keys = list(keys)

self._connect = websockets.connect(
self.subscription_url,
subprotocols=["prefect"],
subprotocols=[websockets.Subprotocol("prefect")],
)
self._websocket = None

def __aiter__(self) -> Self:
return self

@property
def websocket(self) -> websockets.WebSocketClientProtocol:
if not self._websocket:
raise RuntimeError("Subscription is not connected")
return self._websocket

async def __anext__(self) -> S:
while True:
try:
await self._ensure_connected()
message = await self._websocket.recv()
message = await self.websocket.recv()

await self._websocket.send(orjson.dumps({"type": "ack"}).decode())
await self.websocket.send(orjson.dumps({"type": "ack"}).decode())

return self.model.model_validate_json(message)
except (
Expand Down Expand Up @@ -84,13 +90,19 @@ async def _ensure_connected(self):
AssertionError,
websockets.exceptions.ConnectionClosedError,
) as e:
if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
if isinstance(e, AssertionError) or (
e.rcvd and e.rcvd.code == WS_1008_POLICY_VIOLATION
):
if isinstance(e, AssertionError):
reason = e.args[0]
elif isinstance(e, websockets.exceptions.ConnectionClosedError):
elif e.rcvd and e.rcvd.reason:
reason = e.rcvd.reason
else:
reason = "unknown"
else:
reason = None

if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
if reason:
raise Exception(
"Unable to authenticate to the subscription. Please "
"ensure the provided `PREFECT_API_KEY` you are using is "
Expand Down
19 changes: 11 additions & 8 deletions src/prefect/client/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@
Optional,
Tuple,
TypeVar,
Union,
cast,
)

from typing_extensions import Concatenate, ParamSpec

if TYPE_CHECKING:
from prefect.client.orchestration import PrefectClient
from prefect.client.orchestration import PrefectClient, SyncPrefectClient

P = ParamSpec("P")
R = TypeVar("R")


def get_or_create_client(
client: Optional["PrefectClient"] = None,
) -> Tuple["PrefectClient", bool]:
) -> Tuple[Union["PrefectClient", "SyncPrefectClient"], bool]:
"""
Returns provided client, infers a client from context if available, or creates a new client.

Expand All @@ -48,7 +49,7 @@ def get_or_create_client(
flow_run_context = FlowRunContext.get()
task_run_context = TaskRunContext.get()

if async_client_context and async_client_context.client._loop == get_running_loop():
if async_client_context and async_client_context.client._loop == get_running_loop(): # type: ignore[reportPrivateUsage]
return async_client_context.client, True
elif (
flow_run_context
Expand All @@ -72,7 +73,7 @@ def client_injector(
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
client, _ = get_or_create_client()
return await func(client, *args, **kwargs)
return await func(cast("PrefectClient", client), *args, **kwargs)

return wrapper

Expand All @@ -90,16 +91,18 @@ def inject_client(

@wraps(fn)
async def with_injected_client(*args: P.args, **kwargs: P.kwargs) -> R:
client = cast(Optional["PrefectClient"], kwargs.pop("client", None))
client, inferred = get_or_create_client(client)
client, inferred = get_or_create_client(
cast(Optional["PrefectClient"], kwargs.pop("client", None))
)
_client = cast("PrefectClient", client)
if not inferred:
context = client
context = _client
else:
from prefect.utilities.asyncutils import asyncnullcontext

context = asyncnullcontext()
async with context as new_client:
kwargs.setdefault("client", new_client or client)
kwargs.setdefault("client", new_client or _client)
return await fn(*args, **kwargs)

return with_injected_client
5 changes: 4 additions & 1 deletion src/prefect/utilities/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from functools import partial, wraps
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
Expand Down Expand Up @@ -410,7 +411,9 @@ async def ctx_call():


@asynccontextmanager
async def asyncnullcontext(value=None, *args, **kwargs):
async def asyncnullcontext(
value: Optional[Any] = None, *args: Any, **kwargs: Any
) -> AsyncGenerator[Any, None]:
yield value


Expand Down
14 changes: 9 additions & 5 deletions src/prefect/utilities/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import random


def poisson_interval(average_interval, lower=0, upper=1):
def poisson_interval(
average_interval: float, lower: float = 0, upper: float = 1
) -> float:
"""
Generates an "inter-arrival time" for a Poisson process.

Expand All @@ -16,12 +18,12 @@ def poisson_interval(average_interval, lower=0, upper=1):
return -math.log(max(1 - random.uniform(lower, upper), 1e-10)) * average_interval


def exponential_cdf(x, average_interval):
def exponential_cdf(x: float, average_interval: float) -> float:
ld = 1 / average_interval
return 1 - math.exp(-ld * x)


def lower_clamp_multiple(k):
def lower_clamp_multiple(k: float) -> float:
"""
Computes a lower clamp multiple that can be used to bound a random variate drawn
from an exponential distribution.
Expand All @@ -38,7 +40,9 @@ def lower_clamp_multiple(k):
return math.log(max(2**k / (2**k - 1), 1e-10), 2)


def clamped_poisson_interval(average_interval, clamping_factor=0.3):
def clamped_poisson_interval(
average_interval: float, clamping_factor: float = 0.3
) -> float:
"""
Bounds Poisson "inter-arrival times" to a range defined by the clamping factor.

Expand All @@ -57,7 +61,7 @@ def clamped_poisson_interval(average_interval, clamping_factor=0.3):
return poisson_interval(average_interval, lower_rv, upper_rv)


def bounded_poisson_interval(lower_bound, upper_bound):
def bounded_poisson_interval(lower_bound: float, upper_bound: float) -> float:
"""
Bounds Poisson "inter-arrival times" to a range.

Expand Down
Loading