-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add DETR Example #47
Open
Jiang-Stan
wants to merge
2
commits into
megvii-research:main
Choose a base branch
from
Jiang-Stan:support_detr_example
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add DETR Example #47
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
[submodule "examples/quantization_aware_training/imagenet1k/deit/deit"] | ||
path = examples/quantization_aware_training/imagenet1k/deit/deit | ||
url = https://github.com/facebookresearch/deit.git | ||
[submodule "examples/post_training_quantization/coco2017/DETR/detr"] | ||
path = examples/post_training_quantization/coco2017/DETR/detr | ||
url = https://github.com/facebookresearch/detr.git | ||
[submodule "examples/quantization_aware_training/coco2017/DETR/detr"] | ||
path = examples/quantization_aware_training/coco2017/DETR/detr | ||
url = https://github.com/facebookresearch/detr.git |
4,001 changes: 4,001 additions & 0 deletions
4,001
...es/post_training_quantization/coco2017/DETR/DETR_8w8f_visualization_mAP0399.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 29 additions & 0 deletions
29
examples/post_training_quantization/coco2017/DETR/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# DETR PTQ example | ||
|
||
## preparation | ||
|
||
The `DETR` pretrained model is the checkpoint from https://github.com/facebookresearch/detr . The example will automatically download the checkpoint using `torch.hub.load`. | ||
|
||
The datasets used in this example are train dataset and validation dataset of COCO2017. They can be downloaded from http://cocodataset.org. also the relative cocoapi should be installed. | ||
|
||
## Usage | ||
|
||
```shell | ||
python3 main.py qconfig.yaml --coco_path /path/to/coco | ||
``` | ||
Since mask is not well supported by onnx, we removed mask-related codes and assign the batch size to be 1 only. Dynamic_axes for onnx is also not supported yet. | ||
|
||
## Metrics | ||
|
||
|DETR-R50|mAPc|AP50|AP75| remarks| | ||
|-|-|-|-|-| | ||
|float|0.421 | 0.623 | 0.443 | baseline | ||
|8w8f|0.332|0.588|0.320| minmax observer| | ||
|8w8f|0.404|0.612|0.421| minmax observer, float w&f for last 2 bbox embed layers| | ||
|8w8f|0.384|0.598|0.402| minmax observer, apply aciq laplace observer for last bbox embed layer| | ||
|8w8f|0.398|0.609|0.420| minmax observer, apply aciq laplace observer for last 2 bbox embed layer| | ||
|
||
TRT DETR w/ fixed input shape, enable int8&fp16 QPS: 118.334 on Nvidia 2080Ti. For detailed visualization, please refer to | ||
```shell | ||
examples/post_training_quantization/coco2017/DETR/DETR_8w8f_visualization_mAP0395.svg | ||
``` |
95 changes: 95 additions & 0 deletions
95
examples/post_training_quantization/coco2017/DETR/evaluation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import torch | ||
import os | ||
|
||
import util.misc as utils | ||
from datasets.coco_eval import CocoEvaluator | ||
from datasets.panoptic_eval import PanopticEvaluator | ||
|
||
|
||
|
||
|
||
@torch.no_grad() | ||
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir): | ||
model.eval() | ||
criterion.eval() | ||
|
||
metric_logger = utils.MetricLogger(delimiter=" ") | ||
metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) | ||
header = 'Test:' | ||
|
||
iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) | ||
coco_evaluator = CocoEvaluator(base_ds, iou_types) | ||
# coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] | ||
|
||
panoptic_evaluator = None | ||
if 'panoptic' in postprocessors.keys(): | ||
panoptic_evaluator = PanopticEvaluator( | ||
data_loader.dataset.ann_file, | ||
data_loader.dataset.ann_folder, | ||
output_dir=os.path.join(output_dir, "panoptic_eval"), | ||
) | ||
|
||
for samples, targets in metric_logger.log_every(data_loader, 10, header): | ||
sample = samples.tensors.to(device) | ||
targets = [{k: v.to(device) for k, v in t.items()} for t in targets] | ||
|
||
outputs = model(sample) | ||
loss_dict = criterion(outputs, targets) | ||
weight_dict = criterion.weight_dict | ||
|
||
# reduce losses over all GPUs for logging purposes | ||
loss_dict_reduced = utils.reduce_dict(loss_dict) | ||
loss_dict_reduced_scaled = {k: v * weight_dict[k] | ||
for k, v in loss_dict_reduced.items() if k in weight_dict} | ||
loss_dict_reduced_unscaled = {f'{k}_unscaled': v | ||
for k, v in loss_dict_reduced.items()} | ||
metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), | ||
**loss_dict_reduced_scaled, | ||
**loss_dict_reduced_unscaled) | ||
metric_logger.update(class_error=loss_dict_reduced['class_error']) | ||
|
||
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) | ||
results = postprocessors['bbox'](outputs, orig_target_sizes) | ||
if 'segm' in postprocessors.keys(): | ||
target_sizes = torch.stack([t["size"] for t in targets], dim=0) | ||
results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) | ||
res = {target['image_id'].item(): output for target, output in zip(targets, results)} | ||
if coco_evaluator is not None: | ||
coco_evaluator.update(res) | ||
|
||
if panoptic_evaluator is not None: | ||
res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes) | ||
for i, target in enumerate(targets): | ||
image_id = target["image_id"].item() | ||
file_name = f"{image_id:012d}.png" | ||
res_pano[i]["image_id"] = image_id | ||
res_pano[i]["file_name"] = file_name | ||
|
||
panoptic_evaluator.update(res_pano) | ||
|
||
# gather the stats from all processes | ||
metric_logger.synchronize_between_processes() | ||
print("Averaged stats:", metric_logger) | ||
if coco_evaluator is not None: | ||
coco_evaluator.synchronize_between_processes() | ||
if panoptic_evaluator is not None: | ||
panoptic_evaluator.synchronize_between_processes() | ||
|
||
# accumulate predictions from all images | ||
if coco_evaluator is not None: | ||
coco_evaluator.accumulate() | ||
coco_evaluator.summarize() | ||
panoptic_res = None | ||
if panoptic_evaluator is not None: | ||
panoptic_res = panoptic_evaluator.summarize() | ||
stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} | ||
if coco_evaluator is not None: | ||
if 'bbox' in postprocessors.keys(): | ||
stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() | ||
if 'segm' in postprocessors.keys(): | ||
stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist() | ||
if panoptic_res is not None: | ||
stats['PQ_all'] = panoptic_res["All"] | ||
stats['PQ_th'] = panoptic_res["Things"] | ||
stats['PQ_st'] = panoptic_res["Stuff"] | ||
return stats, coco_evaluator |
177 changes: 177 additions & 0 deletions
177
examples/post_training_quantization/coco2017/DETR/main.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import argparse | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.parallel | ||
import torch.backends.cudnn as cudnn | ||
import torch.optim | ||
import torch.utils.data | ||
import torch.utils.data.distributed | ||
import detr.util.misc as utils | ||
import sys | ||
sys.path.append("./detr") | ||
from detr.datasets import get_coco_api_from_dataset | ||
from val_transform_datasets import build_dataset | ||
from model import build | ||
import onnx | ||
import onnx_graphsurgeon as gs | ||
|
||
from sparsebit.quantization import QuantModel, parse_qconfig | ||
|
||
from evaluation import evaluate | ||
|
||
parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") | ||
parser.add_argument("qconfig", help="the path of quant config") | ||
parser.add_argument( | ||
"-a", | ||
"--arch", | ||
metavar="ARCH", | ||
default="deit_tiny_patch16_224", | ||
help="ViT model architecture. (default: deit_tiny)", | ||
) | ||
parser.add_argument( | ||
"-j", | ||
"--num_workers", | ||
default=2, | ||
type=int, | ||
metavar="N", | ||
help="number of data loading workers (default: 4)", | ||
) | ||
parser.add_argument( | ||
"-b", | ||
"--batch-size", | ||
default=1, | ||
type=int, | ||
metavar="N", | ||
help="mini-batch size (default: 64), this is the total " | ||
"batch size of all GPUs on the current node when " | ||
"using Data Parallel or Distributed Data Parallel", | ||
) | ||
parser.add_argument( | ||
"-p", | ||
"--print-freq", | ||
default=10, | ||
type=int, | ||
metavar="N", | ||
help="print frequency (default: 10)", | ||
) | ||
|
||
# * Backbone | ||
parser.add_argument('--backbone', default='resnet50', type=str, | ||
help="Name of the convolutional backbone to use") | ||
parser.add_argument('--dilation', action='store_true', | ||
help="If true, we replace stride with dilation in the last convolutional block (DC5)") | ||
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), | ||
help="Type of positional embedding to use on top of the image features") | ||
|
||
|
||
# * Transformer | ||
parser.add_argument('--enc_layers', default=6, type=int, | ||
help="Number of encoding layers in the transformer") | ||
parser.add_argument('--dec_layers', default=6, type=int, | ||
help="Number of decoding layers in the transformer") | ||
parser.add_argument('--dim_feedforward', default=2048, type=int, | ||
help="Intermediate size of the feedforward layers in the transformer blocks") | ||
parser.add_argument('--hidden_dim', default=256, type=int, | ||
help="Size of the embeddings (dimension of the transformer)") | ||
parser.add_argument('--dropout', default=0.1, type=float, | ||
help="Dropout applied in the transformer") | ||
parser.add_argument('--nheads', default=8, type=int, | ||
help="Number of attention heads inside the transformer's attentions") | ||
parser.add_argument('--num_queries', default=100, type=int, | ||
help="Number of query slots") | ||
parser.add_argument('--pre_norm', action='store_true') | ||
|
||
# Loss | ||
parser.add_argument('--aux_loss', dest='aux_loss', action='store_true', | ||
help="Enables auxiliary decoding losses (loss at each layer)") | ||
# * Matcher | ||
parser.add_argument('--set_cost_class', default=1, type=float, | ||
help="Class coefficient in the matching cost") | ||
parser.add_argument('--set_cost_bbox', default=5, type=float, | ||
help="L1 box coefficient in the matching cost") | ||
parser.add_argument('--set_cost_giou', default=2, type=float, | ||
help="giou box coefficient in the matching cost") | ||
# * Loss coefficients | ||
parser.add_argument('--mask_loss_coef', default=1, type=float) | ||
parser.add_argument('--dice_loss_coef', default=1, type=float) | ||
parser.add_argument('--bbox_loss_coef', default=5, type=float) | ||
parser.add_argument('--giou_loss_coef', default=2, type=float) | ||
parser.add_argument('--eos_coef', default=0.1, type=float, | ||
help="Relative classification weight of the no-object class") | ||
|
||
#configs for coco dataset | ||
parser.add_argument('--dataset_file', default='coco') | ||
parser.add_argument('--coco_path', type=str) | ||
parser.add_argument('--masks', action='store_true', | ||
help="Train segmentation head if the flag is provided") | ||
parser.add_argument('--output_dir', default='', | ||
help='path where to save, empty for no saving') | ||
|
||
parser.add_argument('--device', default='cuda', | ||
help='device to use for training / testing') | ||
|
||
def main(): | ||
args = parser.parse_args() | ||
device = args.device | ||
|
||
# get pretrained model from https://github.com/facebookresearch/detr | ||
model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True) | ||
model, criterion, postprocessors = build(args, model) | ||
|
||
qconfig = parse_qconfig(args.qconfig) | ||
qmodel = QuantModel(model, config=qconfig).to(device) | ||
|
||
cudnn.benchmark = True | ||
|
||
dataset_val = build_dataset(image_set='val', args=args) | ||
sampler_val = torch.utils.data.SequentialSampler(dataset_val) | ||
data_loader_val = torch.utils.data.DataLoader(dataset_val, args.batch_size, sampler=sampler_val, | ||
drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) | ||
base_ds = get_coco_api_from_dataset(dataset_val) | ||
|
||
dataset_calib = build_dataset(image_set='train', args=args) | ||
sampler_calib = torch.utils.data.RandomSampler(dataset_calib) | ||
data_loader_calib = torch.utils.data.DataLoader(dataset_calib, args.batch_size, sampler=sampler_calib, | ||
drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) | ||
|
||
|
||
qmodel.eval() | ||
with torch.no_grad(): | ||
qmodel.prepare_calibration() | ||
# forward calibration-set | ||
calibration_size = 16 | ||
cur_size = 0 | ||
for samples, _ in data_loader_calib: | ||
sample = samples.tensors.to(device) | ||
qmodel(sample) | ||
cur_size += args.batch_size | ||
if cur_size >= calibration_size: | ||
break | ||
qmodel.calc_qparams() | ||
qmodel.set_quant(w_quant=True, a_quant=True) | ||
|
||
test_stats, coco_evaluator = evaluate(qmodel, criterion, postprocessors, | ||
data_loader_val, base_ds, device, args.output_dir) | ||
|
||
qmodel.export_onnx(torch.randn(1, 3, 800, 1200), name="qDETR.onnx") | ||
|
||
# graph = gs.import_onnx(onnx.load("qDETR.onnx")) | ||
# Reshapes = [node for node in graph.nodes if node.op == "Reshape"] | ||
# for node in Reshapes: | ||
# if isinstance(node.inputs[1], gs.Constant): | ||
# if node.inputs[1].values[1]==7600: | ||
# node.inputs[1].values[1] = 8 | ||
# elif node.inputs[1].values[1]==950: | ||
# node.inputs[1].values[1] = 1 | ||
# elif node.inputs[1].values[1]==100: | ||
# node.inputs[1].values[1] = 1 | ||
# elif node.inputs[1].values[1]==800: | ||
# node.inputs[1].values[1] = 8 | ||
|
||
# onnx.save(gs.export_onnx(graph), "qDETR.onnx") | ||
|
||
|
||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where are datasets?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sys.path.append("./detr")
is added to main.py, sodatasets
used here isdetr.datasets