-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery refactored main branch #1
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,8 +45,7 @@ def get_export_import_copy(m): | |
buffer = io.BytesIO() | ||
torch.jit.save(m, buffer) | ||
buffer.seek(0) | ||
imported = torch.jit.load(buffer) | ||
return imported | ||
return torch.jit.load(buffer) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
m_import = get_export_import_copy(m) | ||
with freeze_rng_state(): | ||
|
@@ -98,40 +97,33 @@ class TestModel: | |
@staticmethod | ||
def _get_in_channels(width_multiple, use_p6): | ||
grow_widths = [256, 512, 768, 1024] if use_p6 else [256, 512, 1024] | ||
in_channels = [int(gw * width_multiple) for gw in grow_widths] | ||
return in_channels | ||
return [int(gw * width_multiple) for gw in grow_widths] | ||
Comment on lines
-101
to
+100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
@staticmethod | ||
def _get_strides(use_p6: bool): | ||
if use_p6: | ||
strides = [8, 16, 32, 64] | ||
else: | ||
strides = [8, 16, 32] | ||
return strides | ||
return [8, 16, 32, 64] if use_p6 else [8, 16, 32] | ||
Comment on lines
-106
to
+104
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
@staticmethod | ||
def _get_anchor_grids(use_p6: bool): | ||
if use_p6: | ||
anchor_grids = [ | ||
return ( | ||
[ | ||
[19, 27, 44, 40, 38, 94], | ||
[96, 68, 86, 152, 180, 137], | ||
[140, 301, 303, 264, 238, 542], | ||
[436, 615, 739, 380, 925, 792], | ||
] | ||
else: | ||
anchor_grids = [ | ||
if use_p6 | ||
else [ | ||
[10, 13, 16, 30, 33, 23], | ||
[30, 61, 62, 45, 59, 119], | ||
[116, 90, 156, 198, 373, 326], | ||
] | ||
return anchor_grids | ||
) | ||
Comment on lines
-114
to
+121
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def _compute_anchors(self, height, width, use_p6: bool): | ||
strides = self._get_strides(use_p6) | ||
anchors_num = len(strides) | ||
anchors_shape = [] | ||
for s in strides: | ||
anchors_shape.append((height // s, width // s)) | ||
anchors_shape = [(height // s, width // s) for s in strides] | ||
Comment on lines
-132
to
+126
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
return anchors_num, anchors_shape | ||
|
||
def _get_feature_shapes(self, height, width, width_multiple=0.5, use_p6=False): | ||
|
@@ -147,8 +139,7 @@ def _get_feature_maps(self, batch_size, height, width, width_multiple=0.5, use_p | |
width_multiple=width_multiple, | ||
use_p6=use_p6, | ||
) | ||
feature_maps = [torch.rand(batch_size, *f_shape) for f_shape in feature_shapes] | ||
return feature_maps | ||
return [torch.rand(batch_size, *f_shape) for f_shape in feature_shapes] | ||
Comment on lines
-150
to
+142
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def _get_head_outputs(self, batch_size, height, width, width_multiple=0.5, use_p6=False): | ||
feature_shapes = self._get_feature_shapes( | ||
|
@@ -160,9 +151,7 @@ def _get_head_outputs(self, batch_size, height, width, width_multiple=0.5, use_p | |
|
||
num_outputs = self.num_outputs | ||
head_shapes = [(batch_size, 3, *f_shape[1:], num_outputs) for f_shape in feature_shapes] | ||
head_outputs = [torch.rand(*h_shape) for h_shape in head_shapes] | ||
|
||
return head_outputs | ||
return [torch.rand(*h_shape) for h_shape in head_shapes] | ||
Comment on lines
-163
to
+154
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def _init_test_backbone_with_pan( | ||
self, | ||
|
@@ -176,14 +165,13 @@ def _init_test_backbone_with_pan( | |
backbone_name = f"darknet_{model_size}_{version.replace('.', '_')}" | ||
backbone_arch = eval(f"darknet_{'tan' if use_tan else 'pan'}_backbone") | ||
assert backbone_arch in [darknet_pan_backbone, darknet_tan_backbone] | ||
model = backbone_arch( | ||
return backbone_arch( | ||
backbone_name, | ||
depth_multiple, | ||
width_multiple, | ||
version=version, | ||
use_p6=use_p6, | ||
) | ||
return model | ||
Comment on lines
-179
to
-186
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
@pytest.mark.parametrize( | ||
"depth_multiple, width_multiple, version, use_p6, use_tan", | ||
|
@@ -226,8 +214,7 @@ def test_backbone_with_pan( | |
def _init_test_anchor_generator(self, use_p6=False): | ||
strides = self._get_strides(use_p6) | ||
anchor_grids = self._get_anchor_grids(use_p6) | ||
anchor_generator = AnchorGenerator(strides, anchor_grids) | ||
return anchor_generator | ||
return AnchorGenerator(strides, anchor_grids) | ||
Comment on lines
-229
to
+217
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
@pytest.mark.parametrize( | ||
"width_multiple, use_p6", | ||
|
@@ -256,8 +243,7 @@ def _init_test_yolo_head(self, width_multiple=0.5, use_p6=False): | |
num_anchors = len(strides) | ||
num_classes = self.num_classes | ||
|
||
box_head = YOLOHead(in_channels, num_anchors, strides, num_classes) | ||
return box_head | ||
return YOLOHead(in_channels, num_anchors, strides, num_classes) | ||
Comment on lines
-259
to
+246
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def test_yolo_head(self): | ||
N, H, W = 4, 416, 352 | ||
|
@@ -277,8 +263,7 @@ def _init_test_postprocessors(self, strides): | |
score_thresh = 0.5 | ||
nms_thresh = 0.45 | ||
detections_per_img = 100 | ||
postprocessors = PostProcess(strides, score_thresh, nms_thresh, detections_per_img) | ||
return postprocessors | ||
return PostProcess(strides, score_thresh, nms_thresh, detections_per_img) | ||
Comment on lines
-280
to
+266
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
@pytest.mark.parametrize("use_p6", [False, True]) | ||
def test_postprocessors(self, use_p6): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,8 +8,7 @@ class TestAnchorGenerator: | |
|
||
def get_features(self, images): | ||
s0, s1 = images.shape[-2:] | ||
features = [torch.rand(2, 8, s0 // 5, s1 // 5)] | ||
return features | ||
return [torch.rand(2, 8, s0 // 5, s1 // 5)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def test_anchor_generator(self): | ||
images = torch.rand(2, 3, 10, 10) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,7 @@ def run_model(self, model, inputs_list): | |
# validate the exported model with onnx runtime | ||
for test_inputs in inputs_list: | ||
with torch.no_grad(): | ||
if isinstance(test_inputs, Tensor) or isinstance(test_inputs, list): | ||
if isinstance(test_inputs, (Tensor, list)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
test_inputs = (test_inputs,) | ||
test_outputs = model(*test_inputs) | ||
if isinstance(test_outputs, Tensor): | ||
|
@@ -55,15 +55,13 @@ def ort_validate(self, onnx_io, inputs, outputs): | |
# Inference on ONNX Runtime | ||
ort_outs = y_runtime.predict(inputs) | ||
|
||
for i in range(0, len(outputs)): | ||
for i in range(len(outputs)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05) | ||
|
||
def get_image(self, img_name): | ||
|
||
img_path = Path(__file__).parent.resolve() / "assets" / img_name | ||
image = read_image(str(img_path)) / 255 | ||
|
||
return image | ||
return read_image(str(img_path)) / 255 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def get_test_images(self): | ||
return self.get_image("bus.jpg"), self.get_image("zidane.jpg") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -168,7 +168,7 @@ def evaluate(model, data_loader, coco_evaluator, device, print_freq, use_wandb, | |
) | ||
header = "Test:" | ||
for images, targets in metric_logger.log_every(data_loader, print_freq, header): | ||
images = list(image.to(device) for image in images) | ||
images = [image.to(device) for image in images] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
if torch.cuda.is_available(): | ||
torch.cuda.synchronize() | ||
|
@@ -185,8 +185,7 @@ def evaluate(model, data_loader, coco_evaluator, device, print_freq, use_wandb, | |
|
||
# gather the stats from all processes | ||
metric_logger.synchronize_between_processes() | ||
results = coco_evaluator.compute() | ||
return results | ||
return coco_evaluator.compute() | ||
|
||
|
||
def cli_main(): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -113,8 +113,7 @@ def __init__(self, message, exc=None): | |
|
||
def run_clang_format_diff_wrapper(args, file): | ||
try: | ||
ret = run_clang_format_diff(args, file) | ||
return ret | ||
return run_clang_format_diff(args, file) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
except DiffError: | ||
raise | ||
except Exception as e: | ||
|
@@ -293,7 +292,7 @@ def main(): | |
colored_stdout = sys.stdout.isatty() | ||
colored_stderr = sys.stderr.isatty() | ||
|
||
version_invocation = [args.clang_format_executable, str("--version")] | ||
version_invocation = [args.clang_format_executable, "--version"] | ||
Comment on lines
-296
to
+295
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
try: | ||
subprocess.check_call(version_invocation, stdout=DEVNULL) | ||
except subprocess.CalledProcessError as e: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,15 +24,14 @@ def create_small_table(small_dict): | |
str: the table as a string. | ||
""" | ||
keys, values = tuple(zip(*small_dict.items())) | ||
table = tabulate( | ||
return tabulate( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
[values], | ||
headers=keys, | ||
tablefmt="pipe", | ||
floatfmt=".3f", | ||
stralign="center", | ||
numalign="center", | ||
) | ||
return table | ||
|
||
|
||
def get_coco_api_from_dataset(dataset): | ||
|
@@ -65,8 +64,8 @@ def prepare_coco128( | |
data_path.mkdir(parents=True, exist_ok=True) | ||
|
||
zip_path = data_path / "coco128.zip" | ||
coco128_url = "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip" | ||
if not zip_path.is_file(): | ||
coco128_url = "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
logger.info(f"Downloading coco128 datasets form {coco128_url}") | ||
torch.hub.download_url_to_file(coco128_url, zip_path, hash_prefix="a67d2887") | ||
|
||
|
@@ -106,13 +105,11 @@ def get_dataloader(data_root: str, mode: str = "val", batch_size: int = 4): | |
# We adopt the sequential sampler in order to repeat the experiment | ||
sampler = torch.utils.data.SequentialSampler(dataset) | ||
|
||
loader = torch.utils.data.DataLoader( | ||
return torch.utils.data.DataLoader( | ||
Comment on lines
-109
to
+108
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
dataset, | ||
batch_size, | ||
sampler=sampler, | ||
drop_last=False, | ||
collate_fn=collate_fn, | ||
num_workers=0, | ||
) | ||
|
||
return loader |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,8 +64,7 @@ def __call__(self, image, target): | |
if anno and "keypoints" in anno[0]: | ||
keypoints = [obj["keypoints"] for obj in anno] | ||
keypoints = torch.as_tensor(keypoints, dtype=torch.float32) | ||
num_keypoints = keypoints.shape[0] | ||
if num_keypoints: | ||
if num_keypoints := keypoints.shape[0]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
keypoints = keypoints.view(num_keypoints, -1, 3) | ||
|
||
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) | ||
|
@@ -76,9 +75,7 @@ def __call__(self, image, target): | |
if keypoints is not None: | ||
keypoints = keypoints[keep] | ||
|
||
target = {} | ||
target["boxes"] = boxes | ||
target["labels"] = classes | ||
target = {"boxes": boxes, "labels": classes} | ||
if self.return_masks: | ||
target["masks"] = masks | ||
target["image_id"] = image_id | ||
|
@@ -109,7 +106,6 @@ def convert_coco_poly_to_mask(segmentations, height, width): | |
mask = mask.any(dim=2) | ||
masks.append(mask) | ||
if masks: | ||
masks = torch.stack(masks, dim=0) | ||
return torch.stack(masks, dim=0) | ||
else: | ||
masks = torch.zeros((0, height, width), dtype=torch.uint8) | ||
return masks | ||
return torch.zeros((0, height, width), dtype=torch.uint8) | ||
Comment on lines
-112
to
+111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,7 +61,7 @@ def __init__( | |
dist_sync_fn=dist_sync_fn, | ||
) | ||
self._logger = logging.getLogger(__name__) | ||
if isinstance(coco_gt, str) or isinstance(coco_gt, PosixPath): | ||
if isinstance(coco_gt, (str, PosixPath)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
with contextlib.redirect_stdout(io.StringIO()): | ||
coco_gt = COCO(coco_gt) | ||
elif isinstance(coco_gt, COCO): | ||
|
@@ -70,10 +70,10 @@ def __init__( | |
raise NotImplementedError(f"Currently not supports type {type(coco_gt)}") | ||
|
||
self.coco_gt = coco_gt | ||
if eval_type == "yolov5": | ||
self.category_id_maps = coco_gt.getCatIds() | ||
elif eval_type == "torchvision": | ||
if eval_type == "torchvision": | ||
self.category_id_maps = list(range(coco_gt.getCatIds()[-1] + 1)) | ||
elif eval_type == "yolov5": | ||
self.category_id_maps = coco_gt.getCatIds() | ||
else: | ||
raise NotImplementedError(f"Currently not supports eval type {eval_type}") | ||
|
||
|
@@ -116,8 +116,7 @@ def compute(self): | |
# Summarize | ||
coco_eval.summarize() | ||
|
||
results = self.derive_coco_results() | ||
return results | ||
return self.derive_coco_results() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def derive_coco_results(self, class_names: Optional[List[str]] = None): | ||
""" | ||
|
@@ -149,7 +148,10 @@ def derive_coco_results(self, class_names: Optional[List[str]] = None): | |
metric: float(self.coco_eval.stats[idx] * 100 if self.coco_eval.stats[idx] >= 0 else "nan") | ||
for idx, metric in enumerate(metrics) | ||
} | ||
self._logger.info(f"Evaluation results for {self.iou_type}:\n" + create_small_table(results)) | ||
self._logger.info( | ||
f"Evaluation results for {self.iou_type}:\n{create_small_table(results)}" | ||
) | ||
|
||
Comment on lines
-152
to
+154
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
if not np.isfinite(sum(results.values())): | ||
self._logger.info("Some metrics cannot be computed and is shown as NaN.") | ||
|
@@ -181,9 +183,9 @@ def derive_coco_results(self, class_names: Optional[List[str]] = None): | |
headers=["category", "AP"] * (N_COLS // 2), | ||
numalign="left", | ||
) | ||
self._logger.info(f"Per-category {self.iou_type} AP:\n" + table) | ||
self._logger.info(f"Per-category {self.iou_type} AP:\n{table}") | ||
|
||
results.update({"AP-" + name: ap for name, ap in results_per_category}) | ||
results |= {f"AP-{name}": ap for name, ap in results_per_category} | ||
return results | ||
|
||
def prepare(self, predictions, iou_type): | ||
|
@@ -229,10 +231,7 @@ def merge(img_ids, eval_imgs): | |
for p in all_img_ids: | ||
merged_img_ids.extend(p) | ||
|
||
merged_eval_imgs = [] | ||
for p in all_eval_imgs: | ||
merged_eval_imgs.append(p) | ||
|
||
merged_eval_imgs = list(all_eval_imgs) | ||
Comment on lines
-232
to
+234
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
merged_img_ids = np.array(merged_img_ids) | ||
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) | ||
|
||
|
@@ -285,7 +284,7 @@ def evaluate(self): | |
# loop through images, area range, max detection number | ||
catIds = p.catIds if p.useCats else [-1] | ||
|
||
if p.iouType == "segm" or p.iouType == "bbox": | ||
if p.iouType in ["segm", "bbox"]: | ||
Comment on lines
-288
to
+287
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
computeIoU = self.computeIoU | ||
elif p.iouType == "keypoints": | ||
computeIoU = self.computeOks | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines
30-30
refactored with the following changes:use-fstring-for-concatenation
)