Skip to content

Commit

Permalink
add tch
Browse files Browse the repository at this point in the history
  • Loading branch information
kingzcheung committed Aug 16, 2023
0 parents commit 2b50c66
Show file tree
Hide file tree
Showing 10 changed files with 1,176 additions and 0 deletions.
29 changes: 29 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
[source.crates-io]
replace-with = 'rsproxy'

[source.rsproxy]
registry = "https://rsproxy.cn/crates.io-index"

[registries.rsproxy]
index = "https://rsproxy.cn/crates.io-index"

# 中国科学技术大学
[source.ustc]
registry = "https://mirrors.ustc.edu.cn/crates.io-index"

# 上海交通大学
[source.sjtu]
registry = "https://mirrors.sjtug.sjtu.edu.cn/git/crates.io-index/"

[net]
git-fetch-with-cli = true

[build]
rustflags = [
# "-L", "./libs",
"-L","/opt/homebrew/Cellar/libtorch/2.0.1"
]

[env]
LIBTORCH = "/opt/homebrew/Cellar/libtorch/2.0.1"
LD_LIBRARY_PATH = "/opt/homebrew/Cellar/libtorch/2.0.1/lib"
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/target
/Cargo.lock
/testdata
12 changes: 12 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"cSpell.words": [
"Bbox",
"imageproc",
"npreds",
"preprocess"
],
"rust-analyzer.linkedProjects": [
"./Cargo.toml"
],
"rust-analyzer.showUnlinkedFileNotification": false
}
19 changes: 19 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "yolov8"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tch = { version = "0.13.0", optional = true }
image = "0.24.0"
rusttype = "0.9.3"
imageproc = "0.23.0"
ndarray = { version = "0.15.6", optional = true }
ort = { version = "1.15.2", optional = true }

[features]
default = ["onnx"]
onnx = ["ndarray", "ort"]
full = ["tch","onnx"]
674 changes: 674 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# yolov8

使用 yolov8 模型推理, 支持两种方式:

1. libtorch 推理
2. onnxruntime 推理

## 环境要求
### libtorch 推理
1. 安装 libtorch
```shell
# 下载 libtorch
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.8.1%2Bcpu.zip
# 解压
unzip libtorch-cxx11-abi-shared-with-deps-1.8.1+cpu.zip
# 安装
cd libtorch
sudo cp lib/* /usr/lib/
```
72 changes: 72 additions & 0 deletions src/bbox.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#[derive(Debug, Clone, Copy)]
pub struct Bbox {
pub xmin: f64,
pub ymin: f64,
pub xmax: f64,
pub ymax: f64,
pub confidence: f64,
pub cls_index: i64,
}

impl Bbox {
pub fn name(&self, names: &[String]) -> String {
names[self.cls_index as usize].clone()
}
}

// Function calculates "Intersection-over-union" coefficient for specified two boxes
// https://pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/.
// Returns Intersection over union ratio as a float number
pub fn iou(box1: &Bbox, box2: &Bbox) -> f64 {
intersection(box1, box2) / union(box1, box2)
}

// Function calculates union area of two boxes
// Returns Area of the boxes union as a float number
pub fn union(box1: &Bbox, box2: &Bbox) -> f64 {
let Bbox {
xmin: box1_x1,
ymin: box1_y1,
xmax: box1_x2,
ymax: box1_y2,
cls_index: _,
confidence: _,
} = *box1;
let Bbox {
xmin: box2_x1,
ymin: box2_y1,
xmax: box2_x2,
ymax: box2_y2,
cls_index: _,
confidence: _,
} = *box2;
let box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1);
let box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1);
box1_area + box2_area - intersection(box1, box2)
}

// Function calculates intersection area of two boxes
// Returns Area of intersection of the boxes as a float number
pub fn intersection(box1: &Bbox, box2: &Bbox) -> f64 {
let Bbox {
xmin: box1_x1,
ymin: box1_y1,
xmax: box1_x2,
ymax: box1_y2,
cls_index: _,
confidence: _,
} = *box1;
let Bbox {
xmin: box2_x1,
ymin: box2_y1,
xmax: box2_x2,
ymax: box2_y2,
cls_index: _,
confidence: _,
} = *box2;
let x1 = box1_x1.max(box2_x1);
let y1 = box1_y1.max(box2_y1);
let x2 = box1_x2.min(box2_x2);
let y2 = box1_y2.min(box2_y2);
(x2 - x1) * (y2 - y1)
}
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub mod bbox;
#[cfg(feature = "tch")]
pub mod tch;
#[cfg(feature = "onnx")]
pub mod onnx;
131 changes: 131 additions & 0 deletions src/onnx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use std::{
path::Path,
sync::Arc, time::Instant,
};

use crate::bbox::{iou, Bbox};
use image::{imageops::FilterType, GenericImageView, DynamicImage};
use ndarray::{s, Array, Axis, IxDyn};
use ort::{Environment, Session, SessionBuilder,Value};

pub struct YOLOv8 {
model: Session,
}

impl YOLOv8 {
pub fn new<P>(onnx_file: P) -> Result<YOLOv8, ort::OrtError>
where
P: AsRef<Path>,
{
let env = Arc::new(Environment::builder().with_name("YOLOv8").build()?);

let model = SessionBuilder::new(&env)?.with_model_from_file(onnx_file)?;
Ok(Self { model })
}

pub fn predict(&self, image:DynamicImage) -> Result<Vec<Bbox>, ort::OrtError> {
let (input, img_width, img_height) = self.prepare_input(image);
let start_time = Instant::now();
let output = self.run_model(input)?;
println!("onnx inference time:{} ms", start_time.elapsed().as_millis());
let res = self.process_output(output, img_width, img_height);
Ok(res)
}

fn prepare_input(&self, img: DynamicImage) -> (Array<f32, IxDyn>, u32, u32) {
// let img: image::DynamicImage = image::load_from_memory_with_format(&buf, image::ImageFormat::Jpeg).unwrap();
let (img_width, img_height) = (img.width(), img.height());
let img = img.resize_exact(640, 640, FilterType::CatmullRom);
let mut input = Array::zeros((1, 3, 640, 640)).into_dyn();
for pixel in img.pixels() {
let x = pixel.0 as usize;
let y = pixel.1 as usize;
let [r, g, b, _] = pixel.2 .0;
input[[0, 0, y, x]] = (r as f32) / 255.0;
input[[0, 1, y, x]] = (g as f32) / 255.0;
input[[0, 2, y, x]] = (b as f32) / 255.0;
}
(input, img_width, img_height)
}
fn run_model(&self, input: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ort::OrtError> {
let input_as_values = &input.as_standard_layout();
let model_inputs = vec![Value::from_array(self.model.allocator(), input_as_values)?];
let outputs = self.model.run(model_inputs)?;
let output = outputs
.get(0)
.unwrap()
.try_extract::<f32>()?
.view()
.t()
.into_owned();
Ok(output)
}

#[allow(clippy::manual_retain)]
fn process_output(
&self,
output: Array<f32, IxDyn>,
img_width: u32,
img_height: u32,
) -> Vec<Bbox> {
let mut boxes = Vec::new();
let output = output.slice(s![.., .., 0]);
for row in output.axis_iter(Axis(0)) {
let row: Vec<_> = row.iter().copied().collect();
let (class_id, prob) = row
.iter()
.skip(4)
.enumerate()
.map(|(index, value)| (index, *value))
.reduce(|accum, row| if row.1 > accum.1 { row } else { accum })
.unwrap();
if prob < 0.5 {
continue;
}
// let label = YOLO_CLASSES[class_id];
let xc = row[0] / 640.0 * (img_width as f32);
let yc = row[1] / 640.0 * (img_height as f32);
let w = row[2] / 640.0 * (img_width as f32);
let h = row[3] / 640.0 * (img_height as f32);
let x1 = xc - w / 2.0;
let x2 = xc + w / 2.0;
let y1 = yc - h / 2.0;
let y2 = yc + h / 2.0;
// (x1,y1,x2,y2,label,prob)
boxes.push(Bbox {
xmin: x1 as f64,
ymin: y1 as f64,
xmax: x2 as f64,
ymax: y2 as f64,
confidence: prob as f64,
cls_index: class_id as i64,
});
}

boxes.sort_by(|box1, box2| box2.confidence.total_cmp(&box1.confidence));
let mut result = Vec::new();
while !boxes.is_empty() {
result.push(boxes[0]);
boxes = boxes
.iter()
.filter(|box1| iou(&boxes[0], box1) < 0.7)
.copied()
.collect()
}
result
}
}

#[cfg(test)]
mod test {

#[test]
fn predict() {
let onnx_file = "testdata/best.onnx";
let yolo: super::YOLOv8 = super::YOLOv8::new(onnx_file).unwrap();
let img = include_bytes!("../testdata/testssss.jpg");
let image = image::load_from_memory_with_format(img, image::ImageFormat::Jpeg).unwrap();
let res = yolo.predict(image).unwrap();
dbg!(res);
}
}
Loading

0 comments on commit 2b50c66

Please sign in to comment.