This repository has been archived by the owner on Jun 25, 2023. It is now read-only.

Commit: Merge pull request #3 from electriclizard/dev

Dev

electriclizard authored May 20, 2023
2 parents 06e3469 + ae6a534 commit e13fdea
Showing 6 changed files with 61 additions and 22 deletions.
24 changes: 20 additions & 4 deletions solution/app.py
@@ -1,6 +1,8 @@
 from typing import List
 from configs.config import AppConfig, ModelConfig
 
+import asyncio
+
 import uvicorn
 from fastapi import FastAPI, APIRouter
 from fastapi.openapi.docs import get_swagger_ui_html
@@ -26,18 +28,32 @@ def build_models(model_configs: List[ModelConfig]) -> List[TransformerTextClassi
 models = build_models(config.models)
 
 recognition_service = TextClassificationService(models)
-recognition_handler = PredictionHandler(recognition_service)
+recognition_handler = PredictionHandler(recognition_service, config.timeout)
 
 app = FastAPI()
 router = APIRouter()
 
 
+@app.on_event("startup")
+async def create_queue():
+    app.models_queues = {}
+    for md in models:
+        task_queue = asyncio.Queue()
+        app.models_queues[md.name] = task_queue
+        asyncio.create_task(recognition_handler.handle(md.name, task_queue))
+
+
 @router.post("/process", response_model=ResponseSchema)
 async def process(request: Request):
     text = (await request.body()).decode()
-    # call handler
-    result = recognition_handler.handle(text)
-    return result
+
+    results = []
+    response_q = asyncio.Queue()  # init a response queue for every request, one for all models
+    for model_name, model_queue in app.models_queues.items():
+        await model_queue.put((text, response_q))
+        model_res = await response_q.get()
+        results.append(model_res)
+    return recognition_handler.serialize_answer(results)
 
 
 app.include_router(router)
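
The new /process flow is a fan-out over per-model queues: the startup hook creates one asyncio.Queue and one long-lived consumer task per model, and each request pushes its text onto every model's queue together with a per-request response queue, then collects one answer per model. A minimal self-contained sketch of the same pattern, with a stubbed model in place of the transformers pipelines (EchoModel, worker and main are illustrative names, not code from this repository):

    import asyncio


    class EchoModel:
        """Stand-in for a transformers-backed model: one result per input text."""

        def __init__(self, name: str):
            self.name = name

        def __call__(self, texts):
            return [f"{self.name}:{t}" for t in texts]


    async def worker(model, queue: asyncio.Queue):
        # One long-lived consumer per model, as created by the startup hook above.
        while True:
            text, response_q = await queue.get()
            out = model([text])[0]
            await response_q.put(out)


    async def main():
        models = [EchoModel("m1"), EchoModel("m2")]
        queues = {m.name: asyncio.Queue() for m in models}
        workers = [asyncio.create_task(worker(m, queues[m.name])) for m in models]

        response_q = asyncio.Queue()  # one response queue per request, shared by all models
        results = []
        for q in queues.values():
            await q.put(("hello", response_q))
            results.append(await response_q.get())
        print(results)  # ['m1:hello', 'm2:hello']

        for w in workers:
            w.cancel()


    asyncio.run(main())
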
3 changes: 3 additions & 0 deletions solution/configs/app_config.yaml
@@ -17,3 +17,6 @@ models:

 port: 8080
 workers: 1
+
+timeout: 0.01
+
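
With timeout: 0.01, each model's consumer waits at most 10 ms for a further request before running whatever batch it has collected (see handlers/recognition.py below), trading a small per-request latency cost for larger batches under concurrent load.
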
2 changes: 2 additions & 0 deletions solution/configs/config.py
@@ -15,4 +15,6 @@ class AppConfig(YamlModel):
     # app parameters
     port: int
     workers: int
+    # async queues parameters
+    timeout: float

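AppConfig extends pydantic-yaml's YamlModel, so the new timeout field should be populated straight from app_config.yaml. A hypothetical load, assuming YamlModel's YAML-aware parse_file and a working directory of solution/ (the actual call in app.py sits outside this diff):

    # Hypothetical usage; the path and the parse_file call are assumptions, not repo code.
    from configs.config import AppConfig

    config = AppConfig.parse_file("configs/app_config.yaml")
    print(config.port, config.workers, config.timeout)  # 8080 1 0.01
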
36 changes: 28 additions & 8 deletions solution/handlers/recognition.py
@@ -1,4 +1,5 @@
 from typing import List
+import asyncio
 
 from pydantic import ValidationError

@@ -9,17 +10,36 @@

 class PredictionHandler:
 
-    def __init__(self, recognition_service: TextClassificationService):
+    def __init__(self, recognition_service: TextClassificationService, timeout: float):
         self.recognition_service = recognition_service
-
-    def handle(self, body: str) -> ResponseSchema:
-        query_results = self.recognition_service.get_results(body)
-        result = self.serialize_answer(query_results)
-        return result
+        self.timeout = timeout
+
+    async def handle(self, model_name, model_queue):
+        while True:
+            texts = []
+            queues = []
+
+            try:
+                while True:
+                    (text, response_queue) = await asyncio.wait_for(model_queue.get(), timeout=self.timeout)
+                    texts.append(text)
+                    queues.append(response_queue)
+            except asyncio.exceptions.TimeoutError:
+                pass
+
+            if texts:
+                model = next(
+                    (model for model in self.recognition_service.service_models if model.name == model_name),
+                    None
+                )
+                if model:
+                    outs = model(texts)
+                    for rq, out in zip(queues, outs):
+                        await rq.put(out)
 
     def serialize_answer(self, results: List[TextClassificationModelData]) -> ResponseSchema:
-        results = {rec.model_name: self._recognitions_to_schema(rec) for rec in results}
-        return ResponseSchema(**results)
+        res_model = {rec.model_name: self._recognitions_to_schema(rec) for rec in results}
+        return ResponseSchema(**res_model)
 
     def _recognitions_to_schema(self, recognition: TextClassificationModelData) -> RecognitionSchema:
         if recognition.model_name != "ivanlau":
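
The rewritten handle() is a micro-batching consumer: it keeps draining its queue until the timeout elapses with no new item, runs the accumulated texts through the model in a single call, and then routes each output back through the response queue it arrived with. The same collect-until-quiet idea in isolation (batch_worker and its driver are illustrative, not repository code):

    import asyncio


    async def batch_worker(queue: asyncio.Queue, timeout: float = 0.01):
        while True:
            batch = []
            try:
                while True:  # drain until the queue stays quiet for `timeout` seconds
                    batch.append(await asyncio.wait_for(queue.get(), timeout=timeout))
            except asyncio.TimeoutError:  # alias of asyncio.exceptions.TimeoutError
                pass
            if batch:
                print(f"flushing batch of {len(batch)}: {batch}")


    async def main():
        q = asyncio.Queue()
        task = asyncio.create_task(batch_worker(q))
        for i in range(5):  # all five arrive inside one 10 ms window -> one batch
            await q.put(f"text-{i}")
        await asyncio.sleep(0.05)
        task.cancel()


    asyncio.run(main())
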
14 changes: 6 additions & 8 deletions solution/infrastructure/models.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import List
 
 import torch
 from transformers import pipeline
@@ -27,7 +28,7 @@ def _load_model(self) -> Callable:
         ...
 
     @abstractmethod
-    def __call__(self, input_text: str) -> TextClassificationModelData:
+    def __call__(self, input_texts: List[str]) -> List[TextClassificationModelData]:
         ...


@@ -42,11 +43,8 @@ def _load_model(self):
         )
         return sentiment_task
 
-    def __call__(self, input_text: str) -> TextClassificationModelData:
-        if isinstance(input_text, str):
-            prediction = self.model(input_text)[0]
-            prediction = TextClassificationModelData(self.name, **prediction)
-            return prediction
-        else:
-            raise TypeError("Model input text must be str type")
+    def __call__(self, input_texts: List[str]) -> List[TextClassificationModelData]:
+        predictions = self.model(input_texts, batch_size=len(input_texts))
+        predictions = [TextClassificationModelData(self.name, **prediction) for prediction in predictions]
+        return predictions

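The batched __call__ relies on the transformers pipeline accepting a list of strings; with batch_size set, all texts go through the model in one padded forward pass instead of one pass per text. An illustration using an arbitrary public sentiment checkpoint, not necessarily one this service loads:

    from transformers import pipeline

    clf = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
    preds = clf(["i love this", "terrible service"], batch_size=2)  # single padded batch
    print(preds)  # [{'label': 'POSITIVE', 'score': ...}, {'label': 'NEGATIVE', 'score': ...}]
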
4 changes: 2 additions & 2 deletions solution/service/recognition.py
@@ -10,7 +10,7 @@ class TextClassificationService:
     def __init__(self, models: List[BaseTextClassificationModel]):
         self.service_models = models
 
-    def get_results(self, input_text: str) -> List[TextClassificationModelData]:
-        results = [model(input_text) for model in self.service_models]
+    def get_results(self, input_texts: List[str]) -> List[List[TextClassificationModelData]]:
+        results = [model(input_texts) for model in self.service_models]
         return results
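
get_results is now list-in, nested-list-out: one inner list per model, holding one result per input text. A toy shape check with stub models (StubModel is illustrative, standing in for BaseTextClassificationModel subclasses):

    class StubModel:
        def __init__(self, name: str):
            self.name = name

        def __call__(self, input_texts):
            return [f"{self.name}({t})" for t in input_texts]


    service_models = [StubModel("a"), StubModel("b")]
    results = [model(["x", "y"]) for model in service_models]
    print(results)  # [['a(x)', 'a(y)'], ['b(x)', 'b(y)']] -- List[List[...]]: per model, then per text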
