-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcaviar_yolo.py
45 lines (33 loc) · 1013 Bytes
/
caviar_yolo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Based on https://github.com/WongKinYiu/YOLO/blob/main/examples/notebook_inference.ipynb
import sys
from pathlib import Path
import torch
from hydra import compose, initialize
# Make the parent of this script's directory importable so that the
# `from yolo import ...` below can find the `yolo` package (presumably checked
# out one level above this script -- confirm against the project layout).
# Resolve relative to this file rather than the current working directory so
# the import works no matter where the script is launched from; Hydra already
# resolves CONFIG_PATH relative to this file, so both now agree. Fall back to
# the CWD when `__file__` is undefined (e.g. code pasted into a REPL).
try:
    _script_dir = Path(__file__).resolve().parent
except NameError:
    _script_dir = Path().resolve()
project_root = _script_dir.parent
# Avoid growing sys.path with duplicates if this module is re-executed.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
from yolo import (
FastModelLoader,
AugmentationComposer,
Config,
PostProcess,
create_converter,
)
# Hydra resolves this path relative to the file that calls `initialize`
# (i.e. this script), not the current working directory.
CONFIG_PATH = "./yolo_config"
CONFIG_NAME = "config"
MODEL = "v9-c"
# Prefer the first GPU, but fall back to the CPU when CUDA is unavailable
# instead of requesting a CUDA device that does not exist.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(DEVICE)

# Compose the inference config and build every pipeline component that
# consumes it while the Hydra context is active.
with initialize(config_path=CONFIG_PATH, version_base=None, job_name="caviar"):
    cfg: Config = compose(
        config_name=CONFIG_NAME,
        overrides=[
            "task=inference",
            f"model={MODEL}",
        ],
    )
    model = FastModelLoader(cfg).load_model(device)
    # No augmentations, only the resize/format steps driven by cfg.image_size
    # (assumed from the empty transform list -- confirm in AugmentationComposer).
    transform = AugmentationComposer([], cfg.image_size)
    converter = create_converter(
        cfg.model.name, model, cfg.model.anchor, cfg.image_size, device
    )
    post_process = PostProcess(converter, cfg.task.nms)

# Backward-compatible alias: code elsewhere may import the original
# (misspelled) module-level name.
post_proccess = post_process