Skip to content

Commit

Permalink
Merge branch 'main' into mattt/pydantic-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Aug 12, 2024
2 parents 56867b1 + e4f0efd commit d7cc128
Show file tree
Hide file tree
Showing 7 changed files with 776 additions and 627 deletions.
6 changes: 6 additions & 0 deletions python/cog/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ class PredictionResponse(PredictionBaseModel):

metrics: Optional[Dict[str, Any]] = None

# This is used to track a fatal exception that occurs during a prediction.
# "Fatal" means that we require the worker to be shut down to recover:
# regular exceptions raised during predict are handled and do not use this
# field.
_fatal_exception: Optional[BaseException] = pydantic.PrivateAttr(default=None)

@classmethod
def with_types(cls, input_type: Type[Any], output_type: Type[Any]) -> Any:
# [compat] Input is implicitly optional -- previous versions of the
Expand Down
135 changes: 82 additions & 53 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from enum import Enum, auto, unique
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union

import attrs
import structlog
import uvicorn
from fastapi import Body, FastAPI, Header, HTTPException, Path, Response
Expand All @@ -38,14 +37,15 @@
)
from ..types import PYDANTIC_V2, CogConfig
from .helpers import unwrap_pydantic_serialization_iterators
from .probes import ProbeHelper
from .runner import (
PredictionRunner,
RunnerBusyError,
SetupResult,
SetupTask,
UnknownPredictionError,
)
from .telemetry import make_trace_context, trace_context
from .worker import Worker

if TYPE_CHECKING:
from typing import ParamSpec, TypeVar # pylint: disable=import-outside-toplevel
Expand All @@ -63,11 +63,11 @@ class Health(Enum):
READY = auto()
BUSY = auto()
SETUP_FAILED = auto()
DEFUNCT = auto()


class MyState:
health: Health
setup_task: Optional[SetupTask]
setup_result: Optional[SetupResult]


Expand All @@ -87,16 +87,21 @@ def add_setup_failed_routes(
result = SetupResult(
started_at=started_at,
completed_at=datetime.now(tz=timezone.utc),
logs=msg,
logs=[msg],
status=schema.Status.FAILED,
)
app.state.setup_result = result
app.state.health = Health.SETUP_FAILED

@app.get("/health-check")
async def healthcheck_startup_failed() -> Any:
setup = attrs.asdict(app.state.setup_result)
return jsonable_encoder({"status": app.state.health.name, "setup": setup})
assert app.state.setup_result
return jsonable_encoder(
{
"status": app.state.health.name,
"setup": app.state.setup_result.to_dict(),
}
)


def create_app( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
Expand Down Expand Up @@ -239,15 +244,14 @@ def custom_openapi() -> Dict[str, Any]:
app.openapi = custom_openapi

app.state.health = Health.STARTING
app.state.setup_task = None
app.state.setup_result = None
started_at = datetime.now(tz=timezone.utc)

# shutdown is needed no matter what happens
@app.post("/shutdown")
async def start_shutdown() -> Any:
log.info("shutdown requested via http")
if shutdown_event is not None:
if shutdown_event:
shutdown_event.set()
return JSONResponse({}, status_code=200)

Expand All @@ -261,11 +265,8 @@ async def start_shutdown() -> Any:
add_setup_failed_routes(app, started_at, msg)
return app

runner = PredictionRunner(
predictor_ref=predictor_ref,
shutdown_event=shutdown_event,
upload_url=upload_url,
)
worker = Worker(predictor_ref=predictor_ref)
runner = PredictionRunner(worker=worker)

class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
pass
Expand Down Expand Up @@ -360,15 +361,15 @@ def startup() -> None:
and app.state.setup_result.status == schema.Status.FAILED
):
# signal shutdown if interactive run
if not await_explicit_shutdown:
if shutdown_event is not None:
shutdown_event.set()
if shutdown_event and not await_explicit_shutdown:
shutdown_event.set()
else:
app.state.setup_task = runner.setup()
setup_task = runner.setup()
setup_task.add_done_callback(_handle_setup_done)

@app.on_event("shutdown")
def shutdown() -> None:
runner.shutdown()
worker.terminate()

@app.get("/")
async def root() -> Any:
Expand All @@ -380,12 +381,11 @@ async def root() -> Any:

@app.get("/health-check")
async def healthcheck() -> Any:
_check_setup_result()
if app.state.health == Health.READY:
health = Health.BUSY if runner.is_busy() else Health.READY
else:
health = app.state.health
setup = attrs.asdict(app.state.setup_result) if app.state.setup_result else {}
setup = app.state.setup_result.to_dict() if app.state.setup_result else {}
return jsonable_encoder({"status": health.name, "setup": setup})

@limited
Expand All @@ -403,11 +403,6 @@ async def predict(
"""
Run a single prediction on the model
"""
if runner.is_busy():
return JSONResponse(
{"detail": "Already running a prediction"}, status_code=409
)

# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

Expand Down Expand Up @@ -445,6 +440,16 @@ async def predict_idempotent(
# set on the prediction object
request.id = prediction_id

# If the prediction service is already running a prediction with a
# matching ID, return its current state.
if runner.is_busy():
task = runner.get_predict_task(request.id)
if task:
return JSONResponse(
jsonable_encoder(task.result),
status_code=202,
)

# TODO: spec-compliant parsing of Prefer header.
respond_async = prefer == "respond-async"

Expand All @@ -469,28 +474,43 @@ def _predict(
if request.input is None:
request.input = {} # pylint: disable=attribute-defined-outside-init

try:
# For now, we only ask PredictionRunner to handle file uploads for
task_kwargs = {}
if respond_async:
# For now, we only ask PredictionRunner to handle file uploads for
# async predictions. This is unfortunate but required to ensure
# backwards-compatible behaviour for synchronous predictions.
initial_response, async_result = runner.predict(
request,
upload=respond_async,
)
task_kwargs["upload_url"] = upload_url

try:
predict_task = runner.predict(request, task_kwargs=task_kwargs)
except RunnerBusyError:
return JSONResponse(
{"detail": "Already running a prediction"}, status_code=409
)

if hasattr(request.input, "cleanup"):
predict_task.add_done_callback(lambda _: request.input.cleanup())

predict_task.add_done_callback(_handle_predict_done)

if respond_async:
return JSONResponse(jsonable_encoder(initial_response), status_code=202)
return JSONResponse(
jsonable_encoder(predict_task.result),
status_code=202,
)

response_object = (
unwrap_pydantic_serialization_iterators(async_result.get().model_dump())
if PYDANTIC_V2
else async_result.get().dict()
)
# Otherwise, wait for the prediction to complete...
predict_task.wait()

# ...and return the result.
try:
response_object = (
unwrap_pydantic_serialization_iterators(
predict_task.result.model_dump()
)
if PYDANTIC_V2
else predict_task.result.dict()
)
_ = PredictionResponse(**response_object)
except ValidationError as e:
_log_invalid_output(e)
Expand Down Expand Up @@ -518,24 +538,30 @@ async def cancel(prediction_id: str = Path(..., title="Prediction ID")) -> Any:
return JSONResponse({}, status_code=404)
return JSONResponse({}, status_code=200)

def _check_setup_result() -> Any:
if app.state.setup_task is None:
return

if not app.state.setup_task.ready():
return
def _handle_predict_done(response: schema.PredictionResponse) -> None:
    """Escalate a finished prediction that recorded a fatal exception.

    Registered as a done-callback on the predict task. When the response
    carries a fatal exception it is passed to ``_maybe_shutdown``; otherwise
    nothing happens.
    """
    fatal = response._fatal_exception
    if fatal:
        _maybe_shutdown(fatal)

result = app.state.setup_task.get()
def _handle_setup_done(setup_result: SetupResult) -> None:
app.state.setup_result = setup_result

if result.status == schema.Status.SUCCEEDED:
if app.state.setup_result.status == schema.Status.SUCCEEDED:
app.state.health = Health.READY
else:
app.state.health = Health.SETUP_FAILED

app.state.setup_result = result
# In kubernetes, mark the pod as ready now setup has completed.
probes = ProbeHelper()
probes.ready()
else:
_maybe_shutdown(Exception("setup failed"), status=Health.SETUP_FAILED)

# Reset app.state.setup_task so future calls are a no-op
app.state.setup_task = None
def _maybe_shutdown(exc: BaseException, *, status: Health = Health.DEFUNCT) -> None:
    """Record a fatal error and trigger shutdown unless told to wait.

    Logs ``exc``, stores ``status`` as the application health, and then
    either sets ``shutdown_event`` or logs that an explicit shutdown is
    being awaited.

    Args:
        exc: The exception that caused the fatal condition; attached to the
            log record via ``exc_info``.
        status: Health value written to ``app.state.health``.
    """
    log.error("encountered fatal error", exc_info=exc)
    app.state.health = status

    # Without an event to signal, or when an explicit shutdown is expected,
    # only report the situation and leave the process running.
    if not shutdown_event or await_explicit_shutdown:
        log.error("awaiting explicit shutdown")
        return

    log.error("shutting down immediately")
    shutdown_event.set()
return app

Expand Down Expand Up @@ -701,6 +727,9 @@ def _cpu_count() -> int:
s.stop()

# return error exit code when setup failed and cog is running in interactive mode (not k8s)
if app.state.setup_result and not await_explicit_shutdown:
if app.state.setup_result.status == schema.Status.FAILED:
sys.exit(-1)
if (
app.state.setup_result
and app.state.setup_result.status == schema.Status.FAILED
and not await_explicit_shutdown
):
sys.exit(-1)
Loading

0 comments on commit d7cc128

Please sign in to comment.