forked from AndreyGermanov/yolov8_onnx_rust
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 2b50c66
Showing
10 changed files
with
1,176 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
[source.crates-io] | ||
replace-with = 'rsproxy' | ||
|
||
[source.rsproxy] | ||
registry = "https://rsproxy.cn/crates.io-index" | ||
|
||
[registries.rsproxy] | ||
index = "https://rsproxy.cn/crates.io-index" | ||
|
||
# 中国科学技术大学 | ||
[source.ustc] | ||
registry = "https://mirrors.ustc.edu.cn/crates.io-index" | ||
|
||
# 上海交通大学 | ||
[source.sjtu] | ||
registry = "https://mirrors.sjtug.sjtu.edu.cn/git/crates.io-index/" | ||
|
||
[net] | ||
git-fetch-with-cli = true | ||
|
||
[build] | ||
rustflags = [ | ||
# "-L", "./libs", | ||
"-L","/opt/homebrew/Cellar/libtorch/2.0.1" | ||
] | ||
|
||
[env] | ||
LIBTORCH = "/opt/homebrew/Cellar/libtorch/2.0.1" | ||
LD_LIBRARY_PATH = "/opt/homebrew/Cellar/libtorch/2.0.1/lib" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
/target | ||
/Cargo.lock | ||
/testdata |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"cSpell.words": [ | ||
"Bbox", | ||
"imageproc", | ||
"npreds", | ||
"preprocess" | ||
], | ||
"rust-analyzer.linkedProjects": [ | ||
"./Cargo.toml" | ||
], | ||
"rust-analyzer.showUnlinkedFileNotification": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
[package] | ||
name = "yolov8" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
tch = { version = "0.13.0", optional = true } | ||
image = "0.24.0" | ||
rusttype = "0.9.3" | ||
imageproc = "0.23.0" | ||
ndarray = { version = "0.15.6", optional = true } | ||
ort = { version = "1.15.2", optional = true } | ||
|
||
[features] | ||
default = ["onnx"] | ||
onnx = ["ndarray", "ort"] | ||
full = ["tch","onnx"] |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# yolov8 | ||
|
||
使用 yolov8 模型推理, 支持两种方式: | ||
|
||
1. libtorch 推理 | ||
2. onnxruntime 推理 | ||
|
||
## 环境要求 | ||
### libtorch 推理 | ||
1. 安装 libtorch | ||
```shell | ||
# 下载 libtorch | ||
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.8.1%2Bcpu.zip | ||
# 解压 | ||
unzip libtorch-cxx11-abi-shared-with-deps-1.8.1+cpu.zip | ||
# 安装 | ||
cd libtorch | ||
sudo cp lib/* /usr/lib/ | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#[derive(Debug, Clone, Copy)] | ||
pub struct Bbox { | ||
pub xmin: f64, | ||
pub ymin: f64, | ||
pub xmax: f64, | ||
pub ymax: f64, | ||
pub confidence: f64, | ||
pub cls_index: i64, | ||
} | ||
|
||
impl Bbox { | ||
pub fn name(&self, names: &[String]) -> String { | ||
names[self.cls_index as usize].clone() | ||
} | ||
} | ||
|
||
// Function calculates "Intersection-over-union" coefficient for specified two boxes | ||
// https://pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/. | ||
// Returns Intersection over union ratio as a float number | ||
pub fn iou(box1: &Bbox, box2: &Bbox) -> f64 { | ||
intersection(box1, box2) / union(box1, box2) | ||
} | ||
|
||
// Function calculates union area of two boxes | ||
// Returns Area of the boxes union as a float number | ||
pub fn union(box1: &Bbox, box2: &Bbox) -> f64 { | ||
let Bbox { | ||
xmin: box1_x1, | ||
ymin: box1_y1, | ||
xmax: box1_x2, | ||
ymax: box1_y2, | ||
cls_index: _, | ||
confidence: _, | ||
} = *box1; | ||
let Bbox { | ||
xmin: box2_x1, | ||
ymin: box2_y1, | ||
xmax: box2_x2, | ||
ymax: box2_y2, | ||
cls_index: _, | ||
confidence: _, | ||
} = *box2; | ||
let box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1); | ||
let box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1); | ||
box1_area + box2_area - intersection(box1, box2) | ||
} | ||
|
||
// Function calculates intersection area of two boxes | ||
// Returns Area of intersection of the boxes as a float number | ||
pub fn intersection(box1: &Bbox, box2: &Bbox) -> f64 { | ||
let Bbox { | ||
xmin: box1_x1, | ||
ymin: box1_y1, | ||
xmax: box1_x2, | ||
ymax: box1_y2, | ||
cls_index: _, | ||
confidence: _, | ||
} = *box1; | ||
let Bbox { | ||
xmin: box2_x1, | ||
ymin: box2_y1, | ||
xmax: box2_x2, | ||
ymax: box2_y2, | ||
cls_index: _, | ||
confidence: _, | ||
} = *box2; | ||
let x1 = box1_x1.max(box2_x1); | ||
let y1 = box1_y1.max(box2_y1); | ||
let x2 = box1_x2.min(box2_x2); | ||
let y2 = box1_y2.min(box2_y2); | ||
(x2 - x1) * (y2 - y1) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pub mod bbox; | ||
#[cfg(feature = "tch")] | ||
pub mod tch; | ||
#[cfg(feature = "onnx")] | ||
pub mod onnx; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
use std::{ | ||
path::Path, | ||
sync::Arc, time::Instant, | ||
}; | ||
|
||
use crate::bbox::{iou, Bbox}; | ||
use image::{imageops::FilterType, GenericImageView, DynamicImage}; | ||
use ndarray::{s, Array, Axis, IxDyn}; | ||
use ort::{Environment, Session, SessionBuilder,Value}; | ||
|
||
pub struct YOLOv8 { | ||
model: Session, | ||
} | ||
|
||
impl YOLOv8 { | ||
pub fn new<P>(onnx_file: P) -> Result<YOLOv8, ort::OrtError> | ||
where | ||
P: AsRef<Path>, | ||
{ | ||
let env = Arc::new(Environment::builder().with_name("YOLOv8").build()?); | ||
|
||
let model = SessionBuilder::new(&env)?.with_model_from_file(onnx_file)?; | ||
Ok(Self { model }) | ||
} | ||
|
||
pub fn predict(&self, image:DynamicImage) -> Result<Vec<Bbox>, ort::OrtError> { | ||
let (input, img_width, img_height) = self.prepare_input(image); | ||
let start_time = Instant::now(); | ||
let output = self.run_model(input)?; | ||
println!("onnx inference time:{} ms", start_time.elapsed().as_millis()); | ||
let res = self.process_output(output, img_width, img_height); | ||
Ok(res) | ||
} | ||
|
||
fn prepare_input(&self, img: DynamicImage) -> (Array<f32, IxDyn>, u32, u32) { | ||
// let img: image::DynamicImage = image::load_from_memory_with_format(&buf, image::ImageFormat::Jpeg).unwrap(); | ||
let (img_width, img_height) = (img.width(), img.height()); | ||
let img = img.resize_exact(640, 640, FilterType::CatmullRom); | ||
let mut input = Array::zeros((1, 3, 640, 640)).into_dyn(); | ||
for pixel in img.pixels() { | ||
let x = pixel.0 as usize; | ||
let y = pixel.1 as usize; | ||
let [r, g, b, _] = pixel.2 .0; | ||
input[[0, 0, y, x]] = (r as f32) / 255.0; | ||
input[[0, 1, y, x]] = (g as f32) / 255.0; | ||
input[[0, 2, y, x]] = (b as f32) / 255.0; | ||
} | ||
(input, img_width, img_height) | ||
} | ||
fn run_model(&self, input: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ort::OrtError> { | ||
let input_as_values = &input.as_standard_layout(); | ||
let model_inputs = vec![Value::from_array(self.model.allocator(), input_as_values)?]; | ||
let outputs = self.model.run(model_inputs)?; | ||
let output = outputs | ||
.get(0) | ||
.unwrap() | ||
.try_extract::<f32>()? | ||
.view() | ||
.t() | ||
.into_owned(); | ||
Ok(output) | ||
} | ||
|
||
#[allow(clippy::manual_retain)] | ||
fn process_output( | ||
&self, | ||
output: Array<f32, IxDyn>, | ||
img_width: u32, | ||
img_height: u32, | ||
) -> Vec<Bbox> { | ||
let mut boxes = Vec::new(); | ||
let output = output.slice(s![.., .., 0]); | ||
for row in output.axis_iter(Axis(0)) { | ||
let row: Vec<_> = row.iter().copied().collect(); | ||
let (class_id, prob) = row | ||
.iter() | ||
.skip(4) | ||
.enumerate() | ||
.map(|(index, value)| (index, *value)) | ||
.reduce(|accum, row| if row.1 > accum.1 { row } else { accum }) | ||
.unwrap(); | ||
if prob < 0.5 { | ||
continue; | ||
} | ||
// let label = YOLO_CLASSES[class_id]; | ||
let xc = row[0] / 640.0 * (img_width as f32); | ||
let yc = row[1] / 640.0 * (img_height as f32); | ||
let w = row[2] / 640.0 * (img_width as f32); | ||
let h = row[3] / 640.0 * (img_height as f32); | ||
let x1 = xc - w / 2.0; | ||
let x2 = xc + w / 2.0; | ||
let y1 = yc - h / 2.0; | ||
let y2 = yc + h / 2.0; | ||
// (x1,y1,x2,y2,label,prob) | ||
boxes.push(Bbox { | ||
xmin: x1 as f64, | ||
ymin: y1 as f64, | ||
xmax: x2 as f64, | ||
ymax: y2 as f64, | ||
confidence: prob as f64, | ||
cls_index: class_id as i64, | ||
}); | ||
} | ||
|
||
boxes.sort_by(|box1, box2| box2.confidence.total_cmp(&box1.confidence)); | ||
let mut result = Vec::new(); | ||
while !boxes.is_empty() { | ||
result.push(boxes[0]); | ||
boxes = boxes | ||
.iter() | ||
.filter(|box1| iou(&boxes[0], box1) < 0.7) | ||
.copied() | ||
.collect() | ||
} | ||
result | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
|
||
#[test] | ||
fn predict() { | ||
let onnx_file = "testdata/best.onnx"; | ||
let yolo: super::YOLOv8 = super::YOLOv8::new(onnx_file).unwrap(); | ||
let img = include_bytes!("../testdata/testssss.jpg"); | ||
let image = image::load_from_memory_with_format(img, image::ImageFormat::Jpeg).unwrap(); | ||
let res = yolo.predict(image).unwrap(); | ||
dbg!(res); | ||
} | ||
} |
Oops, something went wrong.