Add DETR Example #47

Open · wants to merge 2 commits into main

6 changes: 6 additions & 0 deletions .gitmodules
@@ -1,3 +1,9 @@
[submodule "examples/quantization_aware_training/imagenet1k/deit/deit"]
path = examples/quantization_aware_training/imagenet1k/deit/deit
url = https://github.com/facebookresearch/deit.git
[submodule "examples/post_training_quantization/coco2017/DETR/detr"]
path = examples/post_training_quantization/coco2017/DETR/detr
url = https://github.com/facebookresearch/detr.git
[submodule "examples/quantization_aware_training/coco2017/DETR/detr"]
path = examples/quantization_aware_training/coco2017/DETR/detr
url = https://github.com/facebookresearch/detr.git
29 changes: 29 additions & 0 deletions examples/post_training_quantization/coco2017/DETR/README.md
@@ -0,0 +1,29 @@
# DETR PTQ example

## Preparation

The pretrained `DETR` model is the checkpoint from https://github.com/facebookresearch/detr. The example downloads the checkpoint automatically via `torch.hub.load`.
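
For reference, the download in `main.py` amounts to a single `torch.hub` call:

```python
import torch

# downloads and caches the official DETR-R50 checkpoint on first use
model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
```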

This example uses the COCO2017 train and validation datasets, which can be downloaded from http://cocodataset.org. The corresponding COCO API (e.g. `pycocotools`) must also be installed.

## Usage

```shell
python3 main.py qconfig.yaml --coco_path /path/to/coco
```
Since masks are not well supported by ONNX, the mask-related code has been removed and the batch size is fixed to 1. Dynamic axes for the ONNX export are not supported yet either.
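
For orientation, the core PTQ flow inside `main.py` is condensed below. `data_loader_calib` stands for the COCO train loader the script builds; the full script additionally rebuilds the model without masks via `build` before wrapping it:

```python
import torch
from sparsebit.quantization import QuantModel, parse_qconfig

device = "cuda"
# float model, obtained as in the Preparation section
model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)

qconfig = parse_qconfig("qconfig.yaml")                # quantization settings
qmodel = QuantModel(model, config=qconfig).to(device)  # wrap the float model

qmodel.eval()
with torch.no_grad():
    qmodel.prepare_calibration()                  # attach observers to the quantized ops
    for samples, _ in data_loader_calib:          # COCO train loader built by the script;
        qmodel(samples.tensors.to(device))        # the script stops after 16 images
    qmodel.calc_qparams()                         # turn observer stats into scales/zero-points
qmodel.set_quant(w_quant=True, a_quant=True)      # enable fake-quant for weights & activations
```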

## Metrics

| DETR-R50 | mAP | AP50 | AP75 | remarks |
|-|-|-|-|-|
| float | 0.421 | 0.623 | 0.443 | baseline |
| 8w8f | 0.332 | 0.588 | 0.320 | minmax observer |
| 8w8f | 0.404 | 0.612 | 0.421 | minmax observer, float w&f for the last 2 bbox embed layers |
| 8w8f | 0.384 | 0.598 | 0.402 | minmax observer + ACIQ Laplace observer for the last bbox embed layer |
| 8w8f | 0.398 | 0.609 | 0.420 | minmax observer + ACIQ Laplace observer for the last 2 bbox embed layers |

With TensorRT (fixed input shape, int8 & fp16 enabled), the quantized DETR reaches 118.334 QPS on an NVIDIA 2080Ti. For a detailed visualization, see
```shell
examples/post_training_quantization/coco2017/DETR/DETR_8w8f_visualization_mAP0395.svg
```
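
The strongest 8w8f row above keeps the last two bbox embed layers in float. A minimal sketch of one way to get that effect, assuming each quantized operator inside a `QuantModel` exposes the same `set_quant` toggle as the top-level model and that the DETR head keeps its upstream `bbox_embed.layers` naming (both are assumptions; check the sparsebit and detr sources for the exact hooks and module names):

```python
# hypothetical sketch: fall back to float for the last two layers
# of DETR's 3-layer bbox-embed MLP (layers.1 and layers.2)
for name, module in qmodel.named_modules():
    if ("bbox_embed.layers.1" in name or "bbox_embed.layers.2" in name) \
            and hasattr(module, "set_quant"):
        module.set_quant(w_quant=False, a_quant=False)
```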
1 change: 1 addition & 0 deletions examples/post_training_quantization/coco2017/DETR/detr
Submodule detr added at 8a144f
95 changes: 95 additions & 0 deletions examples/post_training_quantization/coco2017/DETR/evaluation.py
@@ -0,0 +1,95 @@
import torch
import os

import util.misc as utils
from datasets.coco_eval import CocoEvaluator
from datasets.panoptic_eval import PanopticEvaluator
> **@PeiqinSun** (Member, Nov 10, 2022): where are `datasets`?
>
> **Author** (Collaborator): `sys.path.append("./detr")` is added in main.py, so the `datasets` imported here is detr's `datasets` package.

@torch.no_grad()
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
    # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]

    panoptic_evaluator = None
    if 'panoptic' in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )

    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        sample = samples.tensors.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(sample)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        if 'segm' in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

        if panoptic_evaluator is not None:
            res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
            for i, target in enumerate(targets):
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"
                res_pano[i]["image_id"] = image_id
                res_pano[i]["file_name"] = file_name

            panoptic_evaluator.update(res_pano)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in postprocessors.keys():
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
    if panoptic_res is not None:
        stats['PQ_all'] = panoptic_res["All"]
        stats['PQ_th'] = panoptic_res["Things"]
        stats['PQ_st'] = panoptic_res["Stuff"]
    return stats, coco_evaluator
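
For context, `main.py` (the next file) invokes this entry point on the calibrated, quantized model:

```python
test_stats, coco_evaluator = evaluate(qmodel, criterion, postprocessors,
                                      data_loader_val, base_ds, device, args.output_dir)
```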
177 changes: 177 additions & 0 deletions examples/post_training_quantization/coco2017/DETR/main.py
@@ -0,0 +1,177 @@
import argparse
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import sys
sys.path.append("./detr")
import detr.util.misc as utils
from detr.datasets import get_coco_api_from_dataset
from val_transform_datasets import build_dataset
from model import build
import onnx
import onnx_graphsurgeon as gs

from sparsebit.quantization import QuantModel, parse_qconfig

from evaluation import evaluate

parser = argparse.ArgumentParser(description="PyTorch DETR post-training quantization on COCO2017")
parser.add_argument("qconfig", help="the path of quant config")
parser.add_argument(
    "-j",
    "--num_workers",
    default=2,
    type=int,
    metavar="N",
    help="number of data loading workers (default: 2)",
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=1,
    type=int,
    metavar="N",
    help="mini-batch size (default: 1); only batch size 1 is supported "
    "because the mask-related code has been removed",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)

# * Backbone
parser.add_argument('--backbone', default='resnet50', type=str,
                    help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true',
                    help="If true, we replace stride with dilation in the last convolutional block (DC5)")
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                    help="Type of positional embedding to use on top of the image features")


# * Transformer
parser.add_argument('--enc_layers', default=6, type=int,
                    help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int,
                    help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int,
                    help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int,
                    help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
                    help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
                    help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_queries', default=100, type=int,
                    help="Number of query slots")
parser.add_argument('--pre_norm', action='store_true')

# Loss
parser.add_argument('--aux_loss', dest='aux_loss', action='store_true',
                    help="Enables auxiliary decoding losses (loss at each layer)")
# * Matcher
parser.add_argument('--set_cost_class', default=1, type=float,
                    help="Class coefficient in the matching cost")
parser.add_argument('--set_cost_bbox', default=5, type=float,
                    help="L1 box coefficient in the matching cost")
parser.add_argument('--set_cost_giou', default=2, type=float,
                    help="giou box coefficient in the matching cost")
# * Loss coefficients
parser.add_argument('--mask_loss_coef', default=1, type=float)
parser.add_argument('--dice_loss_coef', default=1, type=float)
parser.add_argument('--bbox_loss_coef', default=5, type=float)
parser.add_argument('--giou_loss_coef', default=2, type=float)
parser.add_argument('--eos_coef', default=0.1, type=float,
                    help="Relative classification weight of the no-object class")

# configs for the COCO dataset
parser.add_argument('--dataset_file', default='coco')
parser.add_argument('--coco_path', type=str)
parser.add_argument('--masks', action='store_true',
                    help="Train segmentation head if the flag is provided")
parser.add_argument('--output_dir', default='',
                    help='path where to save, empty for no saving')

parser.add_argument('--device', default='cuda',
                    help='device to use for training / testing')

def main():
    args = parser.parse_args()
    device = args.device

    # get the pretrained model from https://github.com/facebookresearch/detr
    model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
    model, criterion, postprocessors = build(args, model)

    qconfig = parse_qconfig(args.qconfig)
    qmodel = QuantModel(model, config=qconfig).to(device)

    cudnn.benchmark = True

    dataset_val = build_dataset(image_set='val', args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    data_loader_val = torch.utils.data.DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                                  drop_last=False, collate_fn=utils.collate_fn,
                                                  num_workers=args.num_workers)
    base_ds = get_coco_api_from_dataset(dataset_val)

    dataset_calib = build_dataset(image_set='train', args=args)
    sampler_calib = torch.utils.data.RandomSampler(dataset_calib)
    data_loader_calib = torch.utils.data.DataLoader(dataset_calib, args.batch_size, sampler=sampler_calib,
                                                    drop_last=False, collate_fn=utils.collate_fn,
                                                    num_workers=args.num_workers)

    qmodel.eval()
    with torch.no_grad():
        qmodel.prepare_calibration()
        # forward the calibration set
        calibration_size = 16
        cur_size = 0
        for samples, _ in data_loader_calib:
            sample = samples.tensors.to(device)
            qmodel(sample)
            cur_size += args.batch_size
            if cur_size >= calibration_size:
                break
        qmodel.calc_qparams()
    qmodel.set_quant(w_quant=True, a_quant=True)

    test_stats, coco_evaluator = evaluate(qmodel, criterion, postprocessors,
                                          data_loader_val, base_ds, device, args.output_dir)

    qmodel.export_onnx(torch.randn(1, 3, 800, 1200), name="qDETR.onnx")

    # optional post-processing of the exported graph: patch the hard-coded
    # Reshape shapes so the ONNX model can run with a different batch size
    # graph = gs.import_onnx(onnx.load("qDETR.onnx"))
    # Reshapes = [node for node in graph.nodes if node.op == "Reshape"]
    # for node in Reshapes:
    #     if isinstance(node.inputs[1], gs.Constant):
    #         if node.inputs[1].values[1] == 7600:
    #             node.inputs[1].values[1] = 8
    #         elif node.inputs[1].values[1] == 950:
    #             node.inputs[1].values[1] = 1
    #         elif node.inputs[1].values[1] == 100:
    #             node.inputs[1].values[1] = 1
    #         elif node.inputs[1].values[1] == 800:
    #             node.inputs[1].values[1] = 8
    # onnx.save(gs.export_onnx(graph), "qDETR.onnx")


if __name__ == "__main__":
    main()