diff --git a/src/benchmarks/locustfile.py b/src/benchmarks/locustfile.py
new file mode 100644
index 0000000..2d0f90b
--- /dev/null
+++ b/src/benchmarks/locustfile.py
@@ -0,0 +1,19 @@
+import requests
+from locust import HttpUser, task
+
+headers = {
+    "accept": "application/json",
+    "Content-Type": "application/json",
+}
+
+
+class FastServePerfUser(HttpUser):
+    @task
+    def hello_world(self):
+        data = {
+            "prompt": "An astronaut riding a green horse",
+            "negative_prompt": "ugly, blurry, poor quality",
+        }
+        base_url = self.client.base_url
+        response = self.client.post("/endpoint", headers=headers, json=data)
+        response.raise_for_status()
diff --git a/src/fastserve/base_fastserve.py b/src/fastserve/base_fastserve.py
index c210ccb..9631cb3 100644
--- a/src/fastserve/base_fastserve.py
+++ b/src/fastserve/base_fastserve.py
@@ -1,3 +1,4 @@
+import logging
 from contextlib import asynccontextmanager
 from typing import Any, List
 
@@ -6,6 +7,12 @@
 
 from .batching import BatchProcessor
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler()],
+)
+
 
 class BaseRequest(BaseModel):
     request: Any
diff --git a/src/fastserve/batching.py b/src/fastserve/batching.py
index 5d64183..40df613 100644
--- a/src/fastserve/batching.py
+++ b/src/fastserve/batching.py
@@ -1,13 +1,14 @@
+import logging
 import random
+import signal
 import time
 import uuid
 from dataclasses import dataclass, field
-from logging import INFO, Logger
 from queue import Empty, Queue
 from threading import Event, Thread
 from typing import Any, Callable, Dict, List
 
-logger = Logger(__name__, level=INFO)
+logger = logging.getLogger(__name__)
 
 
 class BatchedQueue:
@@ -98,24 +99,24 @@ def __init__(
         self.func = func
         self._event = Event()
         self._cancel_signal = Event()
+        signal.signal(signal.SIGINT, self.signal_handler)
 
-        self._thread = Thread(target=self._process_queue)
+        self._thread = Thread(target=self._process_queue, daemon=True)
         self._thread.start()
 
     def _process_queue(self):
         logger.info("Started processing")
         while True:
             if self._cancel_signal.is_set():
-                logger.info("Stopped batch processor")
                 return
             t0 = time.time()
             batch: List[WaitedObject] = self._batched_queue.get()
             logger.debug(batch)
             t1 = time.time()
-            logger.debug(f"waited {t1-t0:.2f}s for batch")
             if not batch:
                 logger.debug("no batch")
                 continue
+            logger.info(f"Aggregated batch size {len(batch)} in {t1-t0:.2f}s")
             batch_items = [b.item for b in batch]
             logger.debug(batch_items)
             results = self.func(batch_items)
@@ -135,3 +136,7 @@ def cancel(self):
         self._cancel_signal.set()
         self._thread.join()
         logger.info("Batch Processor terminated!")
+
+    def signal_handler(self, sig, frame):
+        logger.info("Received signal to terminate the thread.")
+        self.cancel()
diff --git a/src/fastserve/models/__main__.py b/src/fastserve/models/__main__.py
index 6f44aa1..6410db6 100644
--- a/src/fastserve/models/__main__.py
+++ b/src/fastserve/models/__main__.py
@@ -7,6 +7,20 @@
 parser = argparse.ArgumentParser(description="Serve models with FastServe")
 parser.add_argument("--model", type=str, required=True, help="Name of the model")
 parser.add_argument("--device", type=str, required=False, help="Device")
+parser.add_argument(
+    "--batch_size",
+    type=int,
+    default=1,
+    required=False,
+    help="Maximum batch size for the ML endpoint",
+)
+parser.add_argument(
+    "--timeout",
+    type=float,
+    default=0.0,
+    required=False,
+    help="Timeout to aggregate maximum batch size",
+)
 
 args = parser.parse_args()
 
@@ -15,7 +29,7 @@
 device = args.device or get_default_device()
 
 if args.model == "ssd-1b":
-    app = FastServeSSD(device=device)
+    app = FastServeSSD(device=device, timeout=args.timeout, batch_size=args.batch_size)
 else:
     raise Exception(f"FastServe.models doesn't implement model={args.model}")
 
diff --git a/src/fastserve/models/ssd.py b/src/fastserve/models/ssd.py
index 16803de..a1e7567 100644
--- a/src/fastserve/models/ssd.py
+++ b/src/fastserve/models/ssd.py
@@ -4,9 +4,8 @@
 import torch
 from diffusers import StableDiffusionXLPipeline
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
-
 from fastserve import BaseRequest, FastServe
+from pydantic import BaseModel
 
 
 class PromptRequest(BaseModel):
@@ -16,7 +15,7 @@ class PromptRequest(BaseModel):
 
 class FastServeSSD(FastServe):
     def __init__(
-        self, batch_size=2, timeout=0.5, device="cuda", num_inference_steps: int = 1
+        self, batch_size=2, timeout=0.5, device="cuda", num_inference_steps: int = 50
     ) -> None:
         super().__init__(batch_size, timeout)
         self.num_inference_steps = num_inference_steps