feat: add torchscript runtime
CuriousDolphin committed Jan 28, 2025
1 parent 1e7596c commit 3956533
Showing 12 changed files with 561 additions and 366 deletions.
4 changes: 2 additions & 2 deletions Makefile
@@ -10,11 +10,11 @@ venv:
@uv venv --python=python3.12

install: .uv .pre-commit
@uv pip install -e ".[cpu,dev]"
@uv pip install -e ".[dev]" --no-cache-dir
@pre-commit install

install-gpu: .uv .pre-commit
@uv pip install -e ".[dev,gpu]"
@uv pip install -e ".[dev,onnx,tensorrt,torch]" --no-cache-dir
@pre-commit install

lint:
63 changes: 41 additions & 22 deletions README.md
@@ -1,4 +1,4 @@
# Focoos Foundational Models
# Focoos pre-trained models

| Model Name | Task | Metrics | Domain |
| ------------------- | --------------------- | ------- | ------------------------------- |
@@ -14,50 +14,69 @@
| focoos_isaid_nano | Semantic Segmentation | - | Satellite Imagery, 15 classes |
| focoos_isaid_medium | Semantic Segmentation | - | Satellite Imagery, 15 classes |

# Focoos SDK
# Focoos
Focoos is a comprehensive SDK designed for computer vision tasks such as object detection, semantic segmentation, instance segmentation, and more. It provides pre-trained models that can be easily integrated and customized by users for various applications.
Focoos supports both cloud and local inference and enables cloud training, making it a versatile tool for developers working across domains such as autonomous driving, common scenes, drone aerial scenes, and satellite imagery.

![Tests](https://github.com/FocoosAI/focoos/actions/workflows/test.yml/badge.svg?event=push&branch=main)

## Requirements

### CUDA 12
### Key Features

For **local inference**, ensure that you have CUDA 12 and cuDNN 9 installed, as they are required for onnxruntime version 1.20.1.
- **Pre-trained Models**: A wide range of pre-trained models for different tasks and domains.
- **Multiple Inference Runtimes**: Run models on CPU, Torchscript CUDA, OnnxRuntime CUDA, or OnnxRuntime TensorRT.
- **Cloud Inference**: Run inference through the Focoos cloud API.
- **Local Inference**: Run models directly on your local machine.
- **Cloud Training**: Train your own models on the Focoos cloud.
- **Model Monitoring**: Monitor model performance and metrics.

To install cuDNN 9:
![Tests](https://github.com/FocoosAI/focoos/actions/workflows/test.yml/badge.svg?event=push&branch=main)

# 🐍 Setup
We recommend using [UV](https://docs.astral.sh/uv/) as a package manager and environment manager for a streamlined dependency management experience.
Here’s how to create a new virtual environment with UV:
```bash
apt-get -y install cudnn9-cuda-12
pip install uv
uv venv --python 3.12
source .venv/bin/activate
```

### (Optional) TensorRT
Focoos models support multiple inference runtimes.
To keep the library lightweight, optional dependencies (e.g., torch, onnxruntime, tensorrt) are not installed by default.
You can install the required optional dependencies using the following syntax:

To perform inference using TensorRT, ensure you have TensorRT version 10.5 installed.
## CPU-only or Remote Usage

```bash
sudo apt-get install tensorrt
uv pip install focoos git+https://github.com/FocoosAI/focoos.git
```

# Install
## GPU Runtimes
### Torchscript CUDA
```bash
uv pip install focoos[torch] git+https://github.com/FocoosAI/focoos.git
```

Nvidia GPU:
### OnnxRuntime CUDA
Ensure that you have CUDA 12 and cuDNN 9 installed, as they are required for onnxruntime version 1.20.1.

```bash
pip install '.[gpu]'
apt-get -y install cudnn9-cuda-12
```

Nvidia GPU,TensorRT:

```bash
pip install '.[gpu,tensorrt]'
uv pip install focoos[onnx] git+https://github.com/FocoosAI/focoos.git
```

CPU,COREML:
### OnnxRuntime TensorRT

To perform inference using TensorRT, ensure you have TensorRT version 10.5 installed.
```bash
sudo apt-get install tensorrt
```

```bash
pip install '.[cpu]'
uv pip install focoos[tensorrt] git+https://github.com/FocoosAI/focoos.git
```
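
After installing any of the variants above, a quick way to verify the environment is the system-info helper exported by the package (a hedged sketch, assuming `get_system_info()` returns the `SystemInfo` object shown in `focoos/ports.py` further down):

```python
# Quick environment check after install; get_system_info and
# SystemInfo.pretty_print() are both visible in this commit
# (focoos/__init__.py and focoos/ports.py below).
from focoos import get_system_info

get_system_info().pretty_print()
```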


## 🤖 Cloud Inference

```python
@@ -74,7 +93,7 @@ detections = model.infer("./image.jpg", threshold=0.4)
Set the FOCOOS_API_KEY_GRADIO environment variable to your Focoos API key.

```bash
pip install '.[gradio]'
uv pip install focoos[gradio] git+https://github.com/FocoosAI/focoos.git
```

```bash
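The Cloud Inference example above is mostly collapsed in this diff; below is a minimal sketch of that flow. The `Focoos(api_key=...)` constructor and the model reference are assumptions; only `get_remote_model` and `infer(..., threshold=...)` are visible in this commit.

```python
# Hedged sketch, not the full README snippet: Focoos(api_key=...) and the model
# reference are assumptions; get_remote_model and infer(threshold=...) appear in
# this commit (focoos/focoos.py and the README hunk above).
from focoos import Focoos

focoos = Focoos(api_key="<YOUR_FOCOOS_API_KEY>")

# Any model reference from the table at the top of the README.
model = focoos.get_remote_model("focoos_isaid_nano")

detections = model.infer("./image.jpg", threshold=0.4)
print(detections)
```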
8 changes: 4 additions & 4 deletions focoos/__init__.py
@@ -18,14 +18,14 @@
ModelMetadata,
ModelPreview,
ModelStatus,
OnnxEngineOpts,
OnnxRuntimeOpts,
RuntimeTypes,
SystemInfo,
TrainingInfo,
TrainInstance,
)
from .remote_model import RemoteModel
from .runtime import ONNXRuntime, get_runtime
from .runtime import ONNXRuntime, load_runtime
from .utils.logger import get_logger
from .utils.system import get_system_info
from .utils.vision import (
@@ -57,14 +57,14 @@
"Hyperparameters",
"LatencyMetrics",
"ModelPreview",
"OnnxEngineOpts",
"OnnxRuntimeOpts",
"RuntimeTypes",
"SystemInfo",
"TrainingInfo",
"TrainInstance",
"get_system_info",
"ONNXRuntime",
"get_runtime",
"load_runtime",
"DEV_API_URL",
"LOCAL_API_URL",
"PROD_API_URL",
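A small sketch of the renamed exports (`OnnxEngineOpts` → `OnnxRuntimeOpts`, `get_runtime` → `load_runtime`); the option fields used below are the ones visible in `focoos/ports.py`:

```python
# Illustrative only: the renamed package-level exports introduced by this commit.
from focoos import OnnxRuntimeOpts, load_runtime  # formerly OnnxEngineOpts, get_runtime

# fp16, cuda and warmup_iter are fields of the (renamed) options dataclass in ports.py.
opts = OnnxRuntimeOpts(cuda=True, fp16=False, warmup_iter=2)

# load_runtime replaces get_runtime and is called with
# (runtime_type, model_path, metadata, warmup_iter), as shown in local_model.py below.
```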
17 changes: 11 additions & 6 deletions focoos/focoos.py
@@ -22,6 +22,7 @@
from focoos.local_model import LocalModel
from focoos.ports import (
DatasetMetadata,
ModelFormat,
ModelMetadata,
ModelNotFound,
ModelPreview,
@@ -164,7 +165,7 @@ def list_focoos_models(self) -> list[ModelPreview]:
def get_local_model(
self,
model_ref: str,
runtime_type: Optional[RuntimeTypes] = None,
runtime_type: Optional[RuntimeTypes] = RuntimeTypes.ONNX_CUDA32,
) -> LocalModel:
"""
Retrieves a local model for the specified reference.
@@ -187,8 +188,12 @@
"""
runtime_type = runtime_type or FOCOOS_CONFIG.runtime_type
model_dir = os.path.join(self.cache_dir, model_ref)
if not os.path.exists(os.path.join(model_dir, "model.onnx")):
self._download_model(model_ref)
format = ModelFormat.TORCHSCRIPT if runtime_type == RuntimeTypes.TORCHSCRIPT_32 else ModelFormat.ONNX
if not os.path.exists(os.path.join(model_dir, f"model.{format.value}")):
self._download_model(
model_ref,
format=format,
)
return LocalModel(model_dir, runtime_type)

def get_remote_model(self, model_ref: str) -> RemoteModel:
@@ -249,7 +254,7 @@ def list_shared_datasets(self) -> list[DatasetMetadata]:
raise ValueError(f"Failed to list datasets: {res.status_code} {res.text}")
return [DatasetMetadata.from_json(dataset) for dataset in res.json()]

def _download_model(self, model_ref: str) -> str:
def _download_model(self, model_ref: str, format: ModelFormat = ModelFormat.ONNX) -> str:
"""
Downloads a model from the Focoos API.
@@ -263,14 +268,14 @@ def _download_model(self, model_ref: str) -> str:
ValueError: If the API request fails or the download fails.
"""
model_dir = os.path.join(self.cache_dir, model_ref)
model_path = os.path.join(model_dir, "model.onnx")
model_path = os.path.join(model_dir, f"model.{format.value}")
metadata_path = os.path.join(model_dir, "focoos_metadata.json")
if os.path.exists(model_path) and os.path.exists(metadata_path):
logger.info("📥 Model already downloaded")
return model_path

## download model metadata
res = self.http_client.get(f"models/{model_ref}/download?format=onnx")
res = self.http_client.get(f"models/{model_ref}/download?format={format.value}")
if res.status_code != 200:
logger.error(f"Failed to download model: {res.status_code} {res.text}")
raise ValueError(f"Failed to download model: {res.status_code} {res.text}")
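A hedged usage sketch of the updated `get_local_model`: the runtime now defaults to `ONNX_CUDA32`, and requesting `TORCHSCRIPT_32` makes the client download and load `model.pt` instead of `model.onnx`. The API key and model reference are placeholders, not taken from this commit.

```python
# Hedged sketch: Focoos(api_key=...) and "<model_ref>" are placeholders.
from focoos import Focoos, RuntimeTypes

focoos = Focoos(api_key="<YOUR_FOCOOS_API_KEY>")

# Default runtime (ONNX_CUDA32): expects <cache_dir>/<model_ref>/model.onnx
onnx_model = focoos.get_local_model("<model_ref>")

# TorchScript runtime added by this commit: downloads and loads model.pt instead
torchscript_model = focoos.get_local_model(
    "<model_ref>", runtime_type=RuntimeTypes.TORCHSCRIPT_32
)
```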
25 changes: 19 additions & 6 deletions focoos/local_model.py
@@ -32,10 +32,11 @@
FocoosDetections,
FocoosTask,
LatencyMetrics,
ModelFormat,
ModelMetadata,
RuntimeTypes,
)
from focoos.runtime import ONNXRuntime, get_runtime
from focoos.runtime import BaseRuntime, load_runtime
from focoos.utils.logger import get_logger
from focoos.utils.vision import (
image_preprocess,
@@ -82,20 +83,32 @@ def __init__(
and initializes the runtime for inference using the provided runtime type. Annotation
utilities are also prepared for visualizing model outputs.
"""
# Determine runtime type and model format
runtime_type = runtime_type or FOCOOS_CONFIG.runtime_type
model_format = ModelFormat.TORCHSCRIPT if runtime_type == RuntimeTypes.TORCHSCRIPT_32 else ModelFormat.ONNX

logger.debug(f"Runtime type: {runtime_type}, Loading model from {model_dir},")
if not os.path.exists(model_dir):
raise FileNotFoundError(f"Model directory not found: {model_dir}")
# Set model directory and path
self.model_dir: Union[str, Path] = model_dir
self.model_path = os.path.join(model_dir, f"model.{model_format.value}")
logger.debug(f"Runtime type: {runtime_type}, Loading model from {self.model_path}..")

# Check if model path exists
if not os.path.exists(self.model_path):
raise FileNotFoundError(f"Model path not found: {self.model_path}")

# Load metadata and set model reference
self.metadata: ModelMetadata = self._read_metadata()
self.model_ref = self.metadata.ref

# Initialize annotation utilities
self.label_annotator = sv.LabelAnnotator(text_padding=10, border_radius=10)
self.box_annotator = sv.BoxAnnotator()
self.mask_annotator = sv.MaskAnnotator()
self.runtime: ONNXRuntime = get_runtime(

# Load runtime for inference
self.runtime: BaseRuntime = load_runtime(
runtime_type,
str(os.path.join(model_dir, "model.onnx")),
str(self.model_path),
self.metadata,
FOCOOS_CONFIG.warmup_iter,
)
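For completeness, a sketch of constructing `LocalModel` directly with the new runtime selection; the directory path is hypothetical and must contain the model file (`model.pt` here) plus `focoos_metadata.json`:

```python
# Hedged sketch: the directory path is hypothetical. LocalModel resolves the model
# format from the runtime type and loads it through load_runtime(), as shown above.
from focoos import RuntimeTypes
from focoos.local_model import LocalModel

model = LocalModel(
    "/path/to/cache/<model_ref>",              # must contain model.pt and focoos_metadata.json
    runtime_type=RuntimeTypes.TORCHSCRIPT_32,  # selects ModelFormat.TORCHSCRIPT ("pt")
)
```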
20 changes: 19 additions & 1 deletion focoos/ports.py
@@ -211,7 +211,7 @@ class FocoosDetections(FocoosBaseModel):


@dataclass
class OnnxEngineOpts:
class OnnxRuntimeOpts:
fp16: Optional[bool] = False
cuda: Optional[bool] = False
vino: Optional[bool] = False
@@ -221,6 +221,13 @@ class OnnxEngineOpts:
warmup_iter: int = 0


@dataclass
class TorchscriptRuntimeOpts:
warmup_iter: int = 0
optimize_for_inference: bool = True
set_fusion_strategy: bool = True


@dataclass
class LatencyMetrics:
fps: int
Expand All @@ -239,6 +246,12 @@ class RuntimeTypes(str, Enum):
ONNX_TRT16 = "onnx_trt16"
ONNX_CPU = "onnx_cpu"
ONNX_COREML = "onnx_coreml"
TORCHSCRIPT_32 = "torchscript_32"


class ModelFormat(str, Enum):
ONNX = "onnx"
TORCHSCRIPT = "pt"


class GPUInfo(FocoosBaseModel):
@@ -266,6 +279,7 @@ class SystemInfo(FocoosBaseModel):
gpu_cuda_version: Optional[str] = None
gpus_info: Optional[list[GPUInfo]] = None
packages_versions: Optional[dict[str, str]] = None
environment: Optional[dict[str, str]] = None

def pretty_print(self):
print("================ SYSTEM INFO ====================")
@@ -286,6 +300,10 @@ def pretty_print(self):
print(f"{key}:")
for pkg_name, pkg_version in value.items():
print(f" - {pkg_name}: {pkg_version}")
elif isinstance(value, dict) and key == "environment": # Special formatting for environment
print(f"{key}:")
for env_key, env_value in value.items():
print(f" - {env_key}: {env_value}")
else:
print(f"{key}: {value}")
print("================================================")
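A short sketch of the new ports: the runtime-to-format rule used by `focoos.py` and `local_model.py`, and the `TorchscriptRuntimeOpts` defaults. This mirrors logic already in the diff; it is illustrative, not an additional API.

```python
# Mirrors the selection rule used in Focoos.get_local_model / LocalModel.__init__.
from focoos.ports import ModelFormat, RuntimeTypes, TorchscriptRuntimeOpts


def model_format_for(runtime_type: RuntimeTypes) -> ModelFormat:
    return ModelFormat.TORCHSCRIPT if runtime_type == RuntimeTypes.TORCHSCRIPT_32 else ModelFormat.ONNX


assert model_format_for(RuntimeTypes.TORCHSCRIPT_32).value == "pt"   # -> model.pt
assert model_format_for(RuntimeTypes.ONNX_CUDA32).value == "onnx"    # -> model.onnx

# New options dataclass for the TorchScript runtime; both flags default to True.
opts = TorchscriptRuntimeOpts(warmup_iter=2, optimize_for_inference=True, set_fusion_strategy=True)
```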
