Merge pull request #5 from electriclizard/splitted-tokenization
Splitted tokenization
electriclizard authored Jun 6, 2023
2 parents e13fdea + b3d0281 commit 2a44f61
Showing 4 changed files with 73 additions and 29 deletions.
2 changes: 1 addition & 1 deletion solution/Dockerfile
@@ -13,4 +13,4 @@ RUN apt-get update && apt upgrade -y && \

COPY . $WORKDIR

ENTRYPOINT [ "python3", "app.py" ]
ENTRYPOINT [ "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8080", "--workers", "1" ]
37 changes: 25 additions & 12 deletions solution/app.py
@@ -1,6 +1,4 @@
from typing import List
from configs.config import AppConfig, ModelConfig

import asyncio

import uvicorn
@@ -10,22 +8,18 @@
from fastapi.responses import HTMLResponse
from starlette.requests import Request

from configs.config import AppConfig, ModelConfig
from infrastructure.models import TransformerTextClassificationModel
from service.recognition import TextClassificationService
from handlers.recognition import PredictionHandler
from handlers.data_models import ResponseSchema


def build_models(model_configs: List[ModelConfig]) -> List[TransformerTextClassificationModel]:
models = [
config = AppConfig.parse_file("./configs/app_config.yaml")
models = [
TransformerTextClassificationModel(conf.model, conf.model_path, conf.tokenizer)
for conf in model_configs
for conf in config.models
]
return models


config = AppConfig.parse_file("./configs/app_config.yaml")
models = build_models(config.models)

recognition_service = TextClassificationService(models)
recognition_handler = PredictionHandler(recognition_service, config.timeout)
@@ -35,12 +29,31 @@ def build_models(model_configs: List[ModelConfig]) -> List[TransformerTextClassi


@app.on_event("startup")
async def create_queue():
async def count_max_batch_size():
print("Calculating Max batch size")
batch_size = 100

try:
while True:
text = ["this is simple text"]*batch_size
inputs = [model.tokenize_texts(text) for model in models]
outputs = [model(m_inputs) for model, m_inputs in zip(models, inputs)]
batch_size += 100

except RuntimeError as err:
if "CUDA out of memory" in str(err):
batch_size -= 100
app.max_batch_size = batch_size
print(f"Max batch size calculated = {app.max_batch_size}")


@app.on_event("startup")
def create_queues():
app.models_queues = {}
for md in models:
task_queue = asyncio.Queue()
app.models_queues[md.name] = task_queue
asyncio.create_task(recognition_handler.handle(md.name, task_queue))
asyncio.create_task(recognition_handler.handle(md.name, task_queue, app.max_batch_size))


@router.post("/process", response_model=ResponseSchema)
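
Taken together, the two startup hooks calibrate and then use a maximum batch size: count_max_batch_size grows a dummy batch in steps of 100 until the GPU raises a "CUDA out of memory" RuntimeError, backs off one step, and stores the result on app.max_batch_size; create_queues then passes that value to every handler task. FastAPI (via Starlette) runs startup handlers in registration order, so the calibration finishes before the queues are created. A standalone, per-model sketch of the same calibration idea (the helper name and signature are illustrative), using the tokenize_texts/forward interface from solution/infrastructure/models.py:

def find_max_batch_size(model, step: int = 100, probe_text: str = "this is simple text") -> int:
    # Grow the probe batch until the GPU reports an out-of-memory error,
    # then back off one step. Mirrors the startup hook above, per model.
    batch_size = step
    try:
        while True:
            batch = [probe_text] * batch_size
            model(model.tokenize_texts(batch))  # forward pass only; outputs are discarded
            batch_size += step
    except RuntimeError as err:
        if "CUDA out of memory" not in str(err):
            raise  # unrelated failure, do not mask it
        batch_size -= step
    return batch_size
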
15 changes: 11 additions & 4 deletions solution/handlers/recognition.py
@@ -14,8 +14,9 @@ def __init__(self, recognition_service: TextClassificationService, timeout: floa
self.recognition_service = recognition_service
self.timeout = timeout

async def handle(self, model_name, model_queue):
async def handle(self, model_name, model_queue, max_batch_size: int):
while True:
inputs = None
texts = []
queues = []

@@ -33,9 +34,11 @@ async def handle(self, model_name, model_queue):
None
)
if model:
outs = model(texts)
for rq, out in zip(queues, outs):
await rq.put(out)
outs = []
for text_batch in self._perform_batches(texts, max_batch_size):
    inputs = model.tokenize_texts(text_batch)
    outs.extend(model(inputs))
for rq, out in zip(queues, outs):
    await rq.put(out)

def serialize_answer(self, results: List[TextClassificationModelData]) -> ResponseSchema:
res_model = {rec.model_name: self._recognitions_to_schema(rec) for rec in results}
@@ -46,3 +49,7 @@ def _recognitions_to_schema(self, recognition: TextClassificationModelData) -> R
recognition.label = recognition.label.upper()
return RecognitionSchema(score=recognition.score, label=recognition.label)

def _perform_batches(self, texts: List[str], max_batch_size):
for i in range(0, len(texts), max_batch_size):
yield texts[i:i + max_batch_size]
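
The handler now drains up to max_batch_size texts per model pass: the texts collected from the queue are split into fixed-size chunks, each chunk is tokenized and run through the model, and the results are pushed back to the per-request response queues in the order the texts were dequeued. _perform_batches itself is a plain slicing generator; a standalone equivalent with hypothetical sample data behaves like this:

from typing import Iterator, List

def perform_batches(texts: List[str], max_batch_size: int) -> Iterator[List[str]]:
    # Standalone equivalent of PredictionHandler._perform_batches, for illustration.
    for i in range(0, len(texts), max_batch_size):
        yield texts[i:i + max_batch_size]

sample = ["first text", "second text", "third text", "fourth text", "fifth text"]
for chunk in perform_batches(sample, max_batch_size=2):
    print(chunk)
# ['first text', 'second text']
# ['third text', 'fourth text']
# ['fifth text']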

48 changes: 36 additions & 12 deletions solution/infrastructure/models.py
@@ -4,7 +4,8 @@
from typing import List

import torch
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification



@dataclass
@@ -21,30 +22,53 @@ def __init__(self, name: str, model_path: str, tokenizer: str):
self.model_path = model_path
self.tokenizer = tokenizer
self.device = 0 if torch.cuda.is_available() else -1
self.model = self._load_model()
self._load_model()

@abstractmethod
def _load_model(self) -> Callable:
def _load_model(self):
...

@abstractmethod
def __call__(self, input_texts: List[str]) -> List[TextClassificationModelData]:
def __call__(self, inputs) -> List[TextClassificationModelData]:
...


class TransformerTextClassificationModel(BaseTextClassificationModel):

def _load_model(self):
sentiment_task = pipeline(
"sentiment-analysis",
model=self.model_path,
tokenizer=self.model_path,
device=self.device
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)
self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
self.model = self.model.to(self.device)

def tokenize_texts(self, texts: List[str]):
inputs = self.tokenizer.batch_encode_plus(
texts,
add_special_tokens=True,
padding='longest',
truncation=True,
return_token_type_ids=True,
return_tensors='pt'
)
return sentiment_task
inputs = {k: v.to(self.device) for k, v in inputs.items()} # Move inputs to GPU
return inputs

def _results_from_logits(self, logits: torch.Tensor):
id2label = self.model.config.id2label

label_ids = logits.argmax(dim=1)
scores = logits.softmax(dim=-1)
results = [
{
"label": id2label[label_id.item()],
"score": score[label_id.item()].item()
}
for label_id, score in zip(label_ids, scores)
]
return results

def __call__(self, input_texts: List[str]) -> List[TextClassificationModelData]:
predictions = self.model(input_texts, batch_size=len(input_texts))
def __call__(self, inputs) -> List[TextClassificationModelData]:
logits = self.model(**inputs).logits
predictions = self._results_from_logits(logits)
predictions = [TextClassificationModelData(self.name, **prediction) for prediction in predictions]
return predictions
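
The model class now exposes tokenization and inference as separate steps: tokenize_texts pads a list of strings to the longest sequence and moves the tensors to the model's device, while __call__ runs the forward pass and converts the logits to (label, score) pairs via argmax, softmax, and the checkpoint's id2label mapping. A usage sketch of the split interface, assuming a CUDA device is available (the constructor resolves the device to GPU ordinal 0 when CUDA is present) and using an example Hugging Face checkpoint rather than the project's own config:

import torch

from infrastructure.models import TransformerTextClassificationModel

model = TransformerTextClassificationModel(
    name="sentiment",
    model_path="distilbert-base-uncased-finetuned-sst-2-english",  # example checkpoint with an id2label mapping
    tokenizer="distilbert-base-uncased-finetuned-sst-2-english",
)

inputs = model.tokenize_texts(["great service", "terrible delay"])  # padded batch on the model's device
with torch.no_grad():  # gradients are not needed at inference time; this guard is added by the sketch
    predictions = model(inputs)

for p in predictions:
    print(p.model_name, p.label, p.score)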
