Commit 7c1d8ac
Merge branch 'rvankoert:master' into master
rvankoert authored Nov 13, 2023
2 parents 938ed63 + c1f8ddb commit 7c1d8ac
Showing 8 changed files with 354 additions and 122 deletions.
README.md: 10 changes (8 additions, 2 deletions)
@@ -304,7 +304,6 @@ GUNICORN_ACCESSLOG # Default: "-": Access log settings.

```bash
LOGHI_MODEL_PATH # Path to the model.
-LOGHI_CHARLIST_PATH # Path to the character list.
LOGHI_BATCH_SIZE # Default: "256": Batch size for processing.
LOGHI_OUTPUT_PATH # Directory where predictions are saved.
LOGHI_MAX_QUEUE_SIZE # Default: "10000": Maximum size of the processing queue.
@@ -326,7 +325,14 @@ Once the API is up and running, you can send HTR requests using curl. Here's how:
curl -X POST -F "image=@$input_path" -F "group_id=$group_id" -F "identifier=$filename" http://localhost:5000/predict
```

-Replace `$input_path`, `$group_id`, and `$filename` with your specific values. The model processes the image, predicts the handwritten text, and saves the predictions in the specified output path (from the `LOGHI_OUTPUT_PATH` environment variable).
+Replace `$input_path`, `$group_id`, and `$filename` with your respective file paths and identifiers. If you're considering switching the recognition model, use the `model` field cautiously (see the examples below):

+- The `model` field (`-F "model=$model_path"`) specifies which handwritten text recognition model the API should use for the current request.
+- To avoid the slowdown of loading a different model on every request, prefer setting the model before starting the API via the `LOGHI_MODEL_PATH` environment variable, as sketched below.
+- Only use the `model` field if you are certain that a particular task needs a different model and you understand its performance characteristics.
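
For instance, a startup configuration could look like the following sketch; the paths are illustrative placeholders, and the two defaults are taken from the environment-variable list above:

```bash
# Hypothetical pre-startup configuration; paths are placeholders.
export LOGHI_MODEL_PATH=/opt/loghi/models/htr-model
export LOGHI_OUTPUT_PATH=/opt/loghi/output
export LOGHI_BATCH_SIZE=256        # README default
export LOGHI_MAX_QUEUE_SIZE=10000  # README default
```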

+> [!WARNING]
+> Continuous model switching with `$model_path` can lead to severe processing delays. For most users, it's best to set `LOGHI_MODEL_PATH` once and use the same model consistently, restarting the API with a new value only when necessary.
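
If a per-request model is genuinely needed, the request might look like this; `$model_path` must point to a model directory that exists on the server:

```bash
curl -X POST \
  -F "image=@$input_path" \
  -F "group_id=$group_id" \
  -F "identifier=$filename" \
  -F "model=$model_path" \
  http://localhost:5000/predict
```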
---

src/api/app_utils.py: 27 changes (17 additions, 10 deletions)
@@ -62,25 +62,27 @@ def setup_logging(level: str = "INFO") -> logging.Logger:
    return logging.getLogger(__name__)


-def extract_request_data() -> Tuple[bytes, str, str]:
+def extract_request_data() -> Tuple[bytes, str, str, str]:
    """
    Extract image and other form data from the current request.

    Returns
    -------
-    tuple of (bytes, str, str)
+    tuple of (bytes, str, str, str)
        image_content : bytes
            Content of the uploaded image.
        group_id : str
            ID of the group from form data.
        identifier : str
            Identifier from form data.
+        model : str
+            Location of the model to use for prediction.

    Raises
    ------
    ValueError
-        If required data (image, group_id, identifier) is missing or if the
-        image format is invalid.
+        If required data (image, group_id, identifier) is missing, if the
+        image format is invalid, or if a provided model path does not exist.
    """

    # Extract the uploaded image
@@ -106,7 +108,12 @@ def extract_request_data() -> Tuple[bytes, str, str]:
    if not identifier:
        raise ValueError("No identifier provided.")

-    return image_content, group_id, identifier
+    model = request.form.get('model')
+    if model:
+        if not os.path.exists(model):
+            raise ValueError(f"Model directory {model} does not exist.")

+    return image_content, group_id, identifier, model
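
For context, here is a minimal sketch of how a route handler might unpack the new four-value return; the `app_utils` import path and the route body are illustrative assumptions:

```python
# Hypothetical caller of extract_request_data(); assumes the function is
# importable from app_utils and that the app runs under Flask.
from flask import Flask, jsonify

from app_utils import extract_request_data

app = Flask(__name__)

@app.route("/predict", methods=["POST"])
def predict():
    try:
        # model may be None when the form field is omitted, in which
        # case the model preloaded via LOGHI_MODEL_PATH should be used.
        image_content, group_id, identifier, model = extract_request_data()
    except ValueError as exc:
        return jsonify({"error": str(exc)}), 400
    # ... enqueue (image_content, group_id, identifier, model) here ...
    return jsonify({"status": "queued", "identifier": identifier}), 202
```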


def get_env_variable(var_name: str, default_value: str = None) -> str:
@@ -150,19 +157,21 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:
    return value


-def start_processes(batch_size: int, max_queue_size: int, model_path: str,
-                    charlist_path: str, output_path: str, gpus: str):
+def start_processes(batch_size: int, max_queue_size: int,
+                    output_path: str, gpus: str, model_path: str):
    logger = logging.getLogger(__name__)

    # Create a thread-safe Queue
    logger.info("Initializing request queue")
    manager = Manager()
    request_queue = manager.JoinableQueue(maxsize=max_queue_size//2)
+    logger.info(f"Request queue size: {max_queue_size//2}")

    # Max size of prepared queue is half of the max size of request queue
    # expressed in number of batches
    max_prepared_queue_size = max_queue_size // 2 // batch_size
    prepared_queue = manager.JoinableQueue(maxsize=max_prepared_queue_size)
+    logger.info(f"Prediction queue size: {max_prepared_queue_size}")

    # Start the image preparation process
    logger.info("Starting image preparation process")
@@ -178,9 +187,7 @@ def start_processes(batch_size: int, max_queue_size: int, model_path: str,
logger.info("Starting batch prediction process")
prediction_process = Process(
target=batch_prediction_worker,
args=(prepared_queue, model_path,
charlist_path, output_path,
gpus),
args=(prepared_queue, output_path, model_path, gpus),
name="Batch Prediction Process")
prediction_process.daemon = True
prediction_process.start()
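
As a sanity check on the queue sizing in `start_processes`, here is a worked sketch assuming the README defaults (`LOGHI_BATCH_SIZE=256`, `LOGHI_MAX_QUEUE_SIZE=10000`):

```python
# Worked example of the queue-sizing arithmetic in start_processes,
# assuming the README defaults (batch size 256, max queue size 10000).
batch_size = 256
max_queue_size = 10_000

# Half of the budget holds individual pending requests.
request_queue_size = max_queue_size // 2                   # 5000 images

# The other half holds prepared batches, expressed in batches.
prepared_queue_size = max_queue_size // 2 // batch_size    # 19 batches

print(request_queue_size, prepared_queue_size)  # 5000 19
```

Integer division rounds the prepared queue down: with these defaults it holds 19 batches (4,864 images), slightly under half of the total budget.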
