generated from aniketmaurya/python-project-template
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2b7a636
commit 3d3f0e4
Showing
7 changed files
with
76 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
"""A Python Project""" | ||
from .base_fastserve import BaseRequest, FastServe | ||
|
||
__version__ = "0.0.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
import uvicorn | ||
|
||
from .fastserve import FastServe | ||
from .base_fastserve import FastServe | ||
|
||
serve = FastServe() | ||
serve.run_server() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .ssd import FastServeSSD |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import argparse | ||
|
||
from .ssd import FastServeSSD | ||
|
||
parser = argparse.ArgumentParser(description="Serve models with FastServe") | ||
parser.add_argument("--model", type=str, required=True, help="Name of the model") | ||
|
||
args = parser.parse_args() | ||
|
||
app = None | ||
if args.model == "ssd-1b": | ||
app = FastServeSSD(device="mps") | ||
else: | ||
raise Exception(f"FastServe.models doesn't implement model={args.model}") | ||
|
||
app.run_server() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import io | ||
from typing import List | ||
|
||
import torch | ||
from diffusers import StableDiffusionXLPipeline | ||
from fastapi.responses import StreamingResponse | ||
from pydantic import BaseModel | ||
|
||
from fastserve import BaseRequest, FastServe | ||
|
||
|
||
class PromptRequest(BaseModel): | ||
prompt: str # "An astronaut riding a green horse" | ||
negative_prompt: str = "ugly, blurry, poor quality" | ||
|
||
|
||
class FastServeSSD(FastServe): | ||
def __init__(self, batch_size=2, timeout=0.5, device="cuda") -> None: | ||
super().__init__(batch_size, timeout) | ||
self.input_schema = PromptRequest | ||
self.pipe = StableDiffusionXLPipeline.from_pretrained( | ||
"segmind/SSD-1B", | ||
torch_dtype=torch.float16, | ||
use_safetensors=True, | ||
variant="fp16", | ||
) | ||
self.pipe.to(device) | ||
|
||
def handle(self, batch: List[PromptRequest]) -> List[StreamingResponse]: | ||
prompts = [b.prompt for b in batch] | ||
negative_prompts = [b.negative_prompt for b in batch] | ||
|
||
pil_images = self.pipe( | ||
prompt=prompts, negative_prompt=negative_prompts, num_inference_steps=1 | ||
).images | ||
image_bytes_list = [] | ||
for pil_image in pil_images: | ||
image_bytes = io.BytesIO() | ||
pil_image.save(image_bytes, format="JPEG") | ||
image_bytes_list.append(image_bytes) | ||
return [ | ||
StreamingResponse(image_bytes, media_type="image/jpeg") | ||
for image_bytes in image_bytes_list | ||
] |