Skip to content

Commit

Permalink
Makes Pretrained consistent (#47)
Browse files Browse the repository at this point in the history
* makes pretrained consistent

* fixes tests
  • Loading branch information
oke-aditya authored Nov 18, 2020
1 parent 06dc0ac commit 8f6fd6b
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 14 deletions.
2 changes: 1 addition & 1 deletion quickvision/models/detection/detr/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def forward(self, images):
return self.model(images)


def create_vision_detr(num_classes, num_queries, backbone):
def create_vision_detr(num_classes: int, num_queries: int, backbone):
"""
Creates Detr Model for Object Detection
Args:
Expand Down
9 changes: 4 additions & 5 deletions quickvision/models/detection/faster_rcnn/lightning_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,23 @@ class lit_frcnn(pl.LightningModule):
"""

def __init__(self, learning_rate: float = 0.0001, num_classes: int = 91,
pretrained: bool = False, backbone: str = None, fpn: bool = True,
pretrained_backbone: bool = True, trainable_backbone_layers: int = 3,
backbone: str = None, fpn: bool = True,
pretrained_backbone: str = None, trainable_backbone_layers: int = 3,
**kwargs, ):
"""
Args:
learning_rate: the learning rate
num_classes: number of detection classes (including background)
pretrained: if true, returns a model pre-trained on COCO train2017
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
pretrained_backbone (str): if "imagenet", returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block
"""
super().__init__()
self.learning_rate = learning_rate
self.num_classes = num_classes
self.backbone = backbone
if backbone is None:
self.model = fasterrcnn_resnet50_fpn(pretrained=pretrained,
pretrained_backbone=pretrained_backbone,
self.model = fasterrcnn_resnet50_fpn(pretrained=True,
trainable_backbone_layers=trainable_backbone_layers,)

in_features = self.model.roi_heads.box_predictor.cls_score.in_features
Expand Down
11 changes: 5 additions & 6 deletions quickvision/models/detection/retinanet/lightning_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,23 @@ class lit_retinanet(pl.LightningModule):
"""

def __init__(self, learning_rate: float = 0.0001, num_classes: int = 91,
pretrained: bool = False, backbone: str = None, fpn: bool = True,
pretrained_backbone: bool = True, trainable_backbone_layers: int = 3,
replace_head: bool = True, **kwargs, ):
backbone: str = None, fpn: bool = True,
pretrained_backbone: str = None, trainable_backbone_layers: int = 3,
**kwargs, ):
"""
Args:
learning_rate: the learning rate
num_classes: number of detection classes (including background)
pretrained: if true, returns a model pre-trained on COCO train2017
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
pretrained_backbone (str): if "imagenet", returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block
"""
super().__init__()
self.learning_rate = learning_rate
self.num_classes = num_classes
self.backbone = backbone
if backbone is None:
self.model = retinanet_resnet50_fpn(pretrained=pretrained,
pretrained_backbone=pretrained_backbone, **kwargs)
self.model = retinanet_resnet50_fpn(pretrained=True, **kwargs)

self.model.head = RetinaNetHead(in_channels=self.model.backbone.out_channels,
num_anchors=self.model.head.classification_head.num_anchors,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_frcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_lit_cnn_cuda(self):
self.assertTrue(flag)

def test_lit_forward(self):
model = faster_rcnn.lit_frcnn(num_classes=3, pretrained=None, pretrained_backbone=False)
model = faster_rcnn.lit_frcnn(num_classes=3, pretrained_backbone=False)
image = torch.rand(1, 3, 400, 400)
out = model(image)
self.assertIsInstance(out, list)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_lit_cnn_cuda(self):
self.assertTrue(flag)

def test_lit_forward(self):
model = retinanet.lit_retinanet(num_classes=3, pretrained=None, pretrained_backbone=False)
model = retinanet.lit_retinanet(num_classes=3, pretrained_backbone=False)
image = torch.rand(1, 3, 400, 400)
out = model(image)
self.assertIsInstance(out, list)
Expand Down

0 comments on commit 8f6fd6b

Please sign in to comment.