Support PyTorch1.11 and OpenCV-4.5.5 #460

Open · wants to merge 2 commits into master
16 changes: 10 additions & 6 deletions Dockerfile
@@ -1,18 +1,22 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
FROM nvcr.io/nvidia/pytorch:21.10-py3
FROM nvcr.io/nvidia/pytorch:22.04-py3
#FROM nvcr.io/nvidia/pytorch:21.04-py3

# Install linux packages
RUN apt update && apt install -y zip htop screen libgl1-mesa-glx
RUN apt update && apt install -y zip htop tmux libgl1-mesa-glx

# Install python dependencies
COPY requirements.txt .
RUN python -m pip install --upgrade pip
RUN pip uninstall -y nvidia-tensorboard nvidia-tensorboard-plugin-dlprof
RUN pip install --no-cache -r requirements.txt coremltools onnx gsutil notebook wandb>=0.12.2
RUN pip install --no-cache -U torch torchvision numpy Pillow
# RUN pip install --no-cache torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
# RUN pip uninstall -y nvidia-tensorboard nvidia-tensorboard-plugin-dlprof
RUN pip install --no-cache -r requirements.txt
RUN pip install --no-cache coremltools onnx gsutil notebook "wandb>=0.12.2"
RUN pip install --no-cache -U numpy Pillow
RUN pip install --no-cache torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
#RUN pip install --no-cache torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
RUN pip install --no-cache roboflow shapely

# Create working directory
RUN mkdir -p /usr/src/app
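As a quick sanity check that the cu116 wheels installed above take precedence over the PyTorch build shipped in the NGC base image, something like the following can be run inside the built container (a sketch; the exact version strings depend on which wheels resolve at build time):

```
# Run inside the built container: confirms the pip-installed cu116 wheels are
# the ones actually imported, not the PyTorch build from the NGC base image.
import torch
import torchvision

print(torch.__version__)          # expected to start with 1.12.1+cu116
print(torchvision.__version__)    # expected to start with 0.13.1+cu116
print(torch.version.cuda)         # expected 11.6
print(torch.cuda.is_available())  # True when a GPU is visible to the container
```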
16 changes: 8 additions & 8 deletions docs/install.md
@@ -2,20 +2,20 @@
## Requirements
* Linux **(Recommended)**, Windows **(not recommended; please refer to this [issue](https://github.com/hukaixuan19970627/yolov5_obb/issues/224) if you have difficulty generating utils/nms_rotated_ext.cpython-XX-XX-XX-XX.so)**
* Python 3.7+
* PyTorch ≥ 1.7
* CUDA 9.0 or higher
* PyTorch ≥ 1.11
* CUDA 11.3 or higher

I have tested the following OS and software versions:
* OS:Ubuntu 16.04/18.04
* CUDA: 10.0/10.1/10.2/11.3
* OS:Ubuntu 20.04
* CUDA: 11.3/11.6

## Install
**CUDA Driver Version ≥ CUDA Toolkit Version (runtime version) = torch.version.cuda**

a. Create a conda virtual environment and activate it, e.g.,
```
conda create -n Py39_Torch1.10_cu11.3 python=3.9 -y
source activate Py39_Torch1.10_cu11.3
conda create -n Py39_Torch1.11_cu11.3 python=3.9 -y
source activate Py39_Torch1.11_cu11.3
```
b. Make sure your CUDA runtime API version is ≤ the CUDA driver version (for example, 11.3 ≤ 11.4).
```
@@ -24,7 +24,7 @@ nvidia-smi
```
c. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/). Make sure the cudatoolkit version matches the CUDA runtime API version, e.g.,
```
pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip3 install --no-cache torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
nvcc -V
python
>>> import torch
@@ -33,7 +33,7 @@ python
```
d. Clone the yolov5-obb repository.
```
git clone https://github.com/hukaixuan19970627/yolov5_obb.git
git clone https://github.com/ohashi/yolov5_obb.git
cd yolov5_obb
```
e. Install yolov5-obb.
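One way to confirm the version relationship stated above (driver CUDA version ≥ runtime version = torch.version.cuda) after finishing steps a–e is a short check like the following (a sketch; it assumes nvidia-smi is on PATH, and the printed versions depend on the local setup):

```
# Sanity check for the version chain: driver CUDA version (nvidia-smi)
# >= runtime CUDA version that PyTorch was built against (torch.version.cuda).
import subprocess
import torch

print("torch:", torch.__version__)                # e.g. 1.12.1+cu116
print("torch.version.cuda:", torch.version.cuda)  # runtime version, e.g. 11.6
print("cuda available:", torch.cuda.is_available())
# nvidia-smi reports the driver version and the highest CUDA version it supports
print(subprocess.check_output(["nvidia-smi"], text=True))
```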
8 changes: 5 additions & 3 deletions requirements.txt
@@ -3,13 +3,15 @@
# Base ----------------------------------------
matplotlib>=3.2.2
numpy>=1.18.5
opencv-python>=4.5.4
opencv-python==4.5.5.64
opencv-contrib-python==4.5.5.64
opencv-python-headless==4.5.5.64
Pillow>=7.1.2
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
torch>=1.7.0
torchvision>=0.8.1
#torch>=1.7.0
#torchvision>=0.8.1
tqdm>=4.41.0

# Logging -------------------------------------
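All three opencv-* pins above provide the same cv2 module (contrib adds the extra modules, headless drops the GUI bindings), so whichever wheel is installed last is the one that gets imported; a minimal check that it is the intended 4.5.5 build:

```
# Confirms which OpenCV build the pinned wheels actually expose as cv2.
import cv2
print(cv2.__version__)  # expected: 4.5.5
```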
16 changes: 8 additions & 8 deletions utils/loss.py
@@ -211,7 +211,7 @@ def build_targets(self, p, targets):
# ttheta, tgaussian_theta = [], []
tgaussian_theta = []
# gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
feature_wh = torch.ones(2, device=targets.device) # feature_wh
feature_wh = torch.ones(2, device=targets.device).long() # feature_wh
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
# targets (tensor): (n_gt_all_batch, c) -> (na, n_gt_all_batch, c) -> (na, n_gt_all_batch, c+1)
# targets (tensor): (na, n_gt_all_batch, [img_index, clsid, cx, cy, l, s, theta, gaussian_θ_labels, anchor_index]])
@@ -226,7 +226,7 @@
for i in range(self.nl):
anchors = self.anchors[i]
# gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain=[1, 1, w, h, w, h, 1, 1]
feature_wh[0:2] = torch.tensor(p[i].shape)[[3, 2]] # xyxy gain=[w_f, h_f]
feature_wh[0:2] = torch.tensor(p[i].shape)[[3, 2]].long() # xyxy gain=[w_f, h_f]

# Match targets to anchors
# t = targets * gain # xywh featuremap pixel
@@ -240,9 +240,9 @@
t = t[j] # filter; Tensor.size(n_filter1, c+1)

# Offsets
gxy = t[:, 2:4] # grid xy; (n_filter1, 2)
gxy = t[:, 2:4].long() # grid xy; (n_filter1, 2)
# gxi = gain[[2, 3]] - gxy # inverse
gxi = feature_wh[[0, 1]] - gxy # inverse
gxi = (feature_wh[[0, 1]] - gxy).long() # inverse
j, k = ((gxy % 1 < g) & (gxy > 1)).T
l, m = ((gxi % 1 < g) & (gxi > 1)).T
j = torch.stack((torch.ones_like(j), j, k, l, m)) # (5, n_filter1)
@@ -254,17 +254,17 @@

# Define, t (tensor): (n_filter2, [img_index, clsid, cx, cy, l, s, theta, gaussian_θ_labels, anchor_index])
b, c = t[:, :2].long().T # image, class; (n_filter2)
gxy = t[:, 2:4] # grid xy
gwh = t[:, 4:6] # grid wh
gxy = t[:, 2:4].long() # grid xy
gwh = t[:, 4:6].long() # grid wh
# theta = t[:, 6]
gaussian_theta_labels = t[:, 7:-1]
gaussian_theta_labels = t[:, 7:-1].long()
gij = (gxy - offsets).long()
gi, gj = gij.T # grid xy indices

# Append
a = t[:, -1].long() # anchor indices (cast to int)
# indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices
indices.append((b, a, gj.clamp_(0, feature_wh[1] - 1), gi.clamp_(0, feature_wh[0] - 1))) # image, anchor, grid indices
indices.append((b, a, gj.clamp_(0, feature_wh[1] - 1).long(), gi.clamp_(0, feature_wh[0] - 1).long())) # image, anchor, grid indices
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
anch.append(anchors[a]) # anchors
tcls.append(c) # class
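For context on the .long() casts added in this file: they turn the feature-map sizes and grid coordinates into integer tensors, and the final clamp-and-index step needs integer indices. A standalone sketch of that last step (the values below are made up for illustration and are not taken from the repo):

```
# Illustration only: grid-cell indices must be integer (long) tensors before
# they can be clamped to the grid and used to index a feature map.
import torch

feature_wh = torch.tensor([80, 80])   # feature-map width/height in cells
gxy = torch.tensor([[13.7, 42.2]])    # target centres in grid units (float)
gij = gxy.long()                      # truncate to integer cell indices
gi, gj = gij.T                        # x and y cell indices
w, h = int(feature_wh[0]), int(feature_wh[1])
gj = gj.clamp(0, h - 1)               # keep indices inside the grid
gi = gi.clamp(0, w - 1)
fmap = torch.zeros(h, w)              # dummy (h, w) feature map
print(fmap[gj, gi])                   # advanced indexing requires integer indices
```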
21 changes: 10 additions & 11 deletions utils/nms_rotated/src/poly_nms_cuda.cu
@@ -1,8 +1,7 @@
#include <ATen/ATen.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
#include <c10/cuda/CUDACachingAllocator.h>

#include <vector>
#include <iostream>
@@ -188,7 +187,7 @@ __global__ void poly_nms_kernel(const int n_polys, const float nms_overlap_thres
t |= 1ULL << i;
}
}
const int col_blocks = THCCeilDiv(n_polys, threadsPerBlock);
const int col_blocks = at::ceil_div(n_polys, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
@@ -206,26 +205,26 @@ at::Tensor poly_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {

int boxes_num = boxes.size(0);

const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);

scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();

THCState *state = at::globalContext().lazyInitCUDA();
at::globalContext().lazyInitCUDA();

unsigned long long* mask_dev = NULL;

mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));

dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
THCCeilDiv(boxes_num, threadsPerBlock));
dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
at::ceil_div(boxes_num, threadsPerBlock));
dim3 threads(threadsPerBlock);
poly_nms_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(boxes_num,
nms_overlap_thresh,
boxes_dev,
mask_dev);

std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
THCudaCheck(cudaMemcpyAsync(
C10_CUDA_CHECK(cudaMemcpyAsync(
&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
@@ -253,7 +252,7 @@ at::Tensor poly_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
}
}

THCudaFree(state, mask_dev);
c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);

return order_t.index({
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
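Taken together, the changes in this file replace the removed THC helpers with their ATen/c10 equivalents: THCCeilDiv becomes at::ceil_div (from <ATen/ceil_div.h>), THCudaMalloc/THCudaFree become c10::cuda::CUDACachingAllocator::raw_alloc/raw_delete, and THCudaCheck becomes C10_CUDA_CHECK. This is what lets the nms_rotated CUDA extension build against PyTorch 1.11 and newer, where the THC headers are no longer shipped.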