diff --git a/solution/app.py b/solution/app.py
index d084f61..702be3e 100644
--- a/solution/app.py
+++ b/solution/app.py
@@ -1,6 +1,8 @@
 from typing import List
 
 from configs.config import AppConfig, ModelConfig
+import asyncio
+
 import uvicorn
 from fastapi import FastAPI, APIRouter
 from fastapi.openapi.docs import get_swagger_ui_html
@@ -26,18 +28,32 @@ def build_models(model_configs: List[ModelConfig]) -> List[TransformerTextClassi
 
 models = build_models(config.models)
 recognition_service = TextClassificationService(models)
-recognition_handler = PredictionHandler(recognition_service)
+recognition_handler = PredictionHandler(recognition_service, config.timeout)
 
 
 app = FastAPI()
 router = APIRouter()
 
+@app.on_event("startup")
+async def create_queue():
+    app.models_queues = {}
+    for md in models:
+        task_queue = asyncio.Queue()
+        app.models_queues[md.name] = task_queue
+        asyncio.create_task(recognition_handler.handle(md.name, task_queue))
+
+
 @router.post("/process", response_model=ResponseSchema)
 async def process(request: Request):
     text = (await request.body()).decode()
-    # call handler
-    result = recognition_handler.handle(text)
-    return result
+
+    results = []
+    response_q = asyncio.Queue()  # one response queue per request, shared by all models
+    for model_name, model_queue in app.models_queues.items():
+        await model_queue.put((text, response_q))
+        model_res = await response_q.get()
+        results.append(model_res)
+    return recognition_handler.serialize_answer(results)
 
 
 app.include_router(router)
diff --git a/solution/configs/app_config.yaml b/solution/configs/app_config.yaml
index 987f42c..aea7f52 100644
--- a/solution/configs/app_config.yaml
+++ b/solution/configs/app_config.yaml
@@ -17,3 +17,6 @@ models:
 
 port: 8080
 workers: 1
+
+timeout: 0.01
+
diff --git a/solution/configs/config.py b/solution/configs/config.py
index 97cc56f..b4ebadc 100644
--- a/solution/configs/config.py
+++ b/solution/configs/config.py
@@ -15,4 +15,6 @@ class AppConfig(YamlModel):
 
     # app parameters
     port: int
    workers: int
+    # async queue parameters
+    timeout: float
diff --git a/solution/handlers/recognition.py b/solution/handlers/recognition.py
index d536190..2913918 100644
--- a/solution/handlers/recognition.py
+++ b/solution/handlers/recognition.py
@@ -1,4 +1,5 @@
 from typing import List
+import asyncio
 
 from pydantic import ValidationError
 
@@ -9,17 +10,36 @@
 
 class PredictionHandler:
 
-    def __init__(self, recognition_service: TextClassificationService):
+    def __init__(self, recognition_service: TextClassificationService, timeout: float):
         self.recognition_service = recognition_service
-
-    def handle(self, body: str) -> ResponseSchema:
-        query_results = self.recognition_service.get_results(body)
-        result = self.serialize_answer(query_results)
-        return result
+        self.timeout = timeout
+
+    async def handle(self, model_name, model_queue):
+        while True:
+            texts = []
+            queues = []
+
+            try:
+                while True:
+                    (text, response_queue) = await asyncio.wait_for(model_queue.get(), timeout=self.timeout)
+                    texts.append(text)
+                    queues.append(response_queue)
+            except asyncio.exceptions.TimeoutError:
+                pass
+
+            if texts:
+                model = next(
+                    (model for model in self.recognition_service.service_models if model.name == model_name),
+                    None
+                )
+                if model:
+                    outs = model(texts)
+                    for rq, out in zip(queues, outs):
+                        await rq.put(out)
 
     def serialize_answer(self, results: List[TextClassificationModelData]) -> ResponseSchema:
-        results = {rec.model_name: self._recognitions_to_schema(rec) for rec in results}
-        return ResponseSchema(**results)
+        res_model = {rec.model_name: self._recognitions_to_schema(rec) for rec in results}
+        return ResponseSchema(**res_model)
 
     def _recognitions_to_schema(self, recognition: TextClassificationModelData) -> RecognitionSchema:
         if recognition.model_name != "ivanlau":
diff --git a/solution/infrastructure/models.py b/solution/infrastructure/models.py
index d63f8ed..5ed6729 100644
--- a/solution/infrastructure/models.py
+++ b/solution/infrastructure/models.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import List
 
 import torch
 from transformers import pipeline
@@ -27,7 +28,7 @@ def _load_model(self) -> Callable:
         ...
 
     @abstractmethod
-    def __call__(self, input_text: str) -> TextClassificationModelData:
+    def __call__(self, input_texts: List[str]) -> List[TextClassificationModelData]:
         ...
 
 
@@ -42,11 +43,8 @@ def _load_model(self):
         )
         return sentiment_task
 
-    def __call__(self, input_text: str) -> TextClassificationModelData:
-        if isinstance(input_text, str):
-            prediction = self.model(input_text)[0]
-            prediction = TextClassificationModelData(self.name, **prediction)
-            return prediction
-        else:
-            raise TypeError("Model input text must be str type")
+    def __call__(self, input_texts: List[str]) -> List[TextClassificationModelData]:
+        predictions = self.model(input_texts, batch_size=len(input_texts))
+        predictions = [TextClassificationModelData(self.name, **prediction) for prediction in predictions]
+        return predictions
 
diff --git a/solution/service/recognition.py b/solution/service/recognition.py
index 47adbba..d72cbb1 100644
--- a/solution/service/recognition.py
+++ b/solution/service/recognition.py
@@ -10,7 +10,7 @@ class TextClassificationService:
     def __init__(self, models: List[BaseTextClassificationModel]):
         self.service_models = models
 
-    def get_results(self, input_text: str) -> List[TextClassificationModelData]:
-        results = [model(input_text) for model in self.service_models]
+    def get_results(self, input_texts: List[str]) -> List[List[TextClassificationModelData]]:
+        results = [model(input_texts) for model in self.service_models]
         return results
 
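
A quick way to exercise the new batching path end-to-end is to fire several concurrent requests at `/process` and check that each response carries one entry per model. The sketch below is illustrative only and not part of this diff: it assumes the service is running locally on the port from `app_config.yaml` (8080), it uses `httpx` as an extra dependency, and the sample texts are arbitrary.

```python
# Smoke test for the batched /process endpoint -- illustrative sketch, not part of the PR.
# Assumes the app is up on localhost:8080 (port taken from app_config.yaml)
# and that httpx is installed.
import asyncio

import httpx

URL = "http://localhost:8080/process"  # assumed local deployment


async def main() -> None:
    texts = ["great product", "terrible support", "como estas?"]
    async with httpx.AsyncClient() as client:
        # Requests sent concurrently should land inside one `timeout`
        # window (0.01s) and be served from a single batched model call.
        responses = await asyncio.gather(
            *(client.post(URL, content=text) for text in texts)
        )
    for text, response in zip(texts, responses):
        print(text, "->", response.json())


if __name__ == "__main__":
    asyncio.run(main())
```

Since the per-model consumers drain their queues in `timeout`-sized windows, requests that arrive in the same window share one `model(texts)` forward pass, which is the point of the change.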