diff --git a/eval.py b/eval.py index b4a60ad7..68311a1a 100644 --- a/eval.py +++ b/eval.py @@ -116,10 +116,9 @@ def main(cfg, gpu): arch=cfg.MODEL.arch_decoder.lower(), fc_dim=cfg.MODEL.fc_dim, num_class=cfg.DATASET.num_class, - weights=cfg.MODEL.weights_decoder, - use_softmax=True) + weights=cfg.MODEL.weights_decoder) - crit = nn.NLLLoss(ignore_index=-1) + crit = nn.CrossEntropyLoss(ignore_index=-1) segmentation_module = SegmentationModule(net_encoder, net_decoder, crit) diff --git a/eval_multipro.py b/eval_multipro.py index db328cd3..e133115b 100644 --- a/eval_multipro.py +++ b/eval_multipro.py @@ -106,10 +106,9 @@ def worker(cfg, gpu_id, start_idx, end_idx, result_queue): arch=cfg.MODEL.arch_decoder.lower(), fc_dim=cfg.MODEL.fc_dim, num_class=cfg.DATASET.num_class, - weights=cfg.MODEL.weights_decoder, - use_softmax=True) + weights=cfg.MODEL.weights_decoder) - crit = nn.NLLLoss(ignore_index=-1) + crit = nn.CrossEntropyLoss(ignore_index=-1) segmentation_module = SegmentationModule(net_encoder, net_decoder, crit) diff --git a/mit_semseg/models/models.py b/mit_semseg/models/models.py index 793d2bd7..77e3d4a0 100644 --- a/mit_semseg/models/models.py +++ b/mit_semseg/models/models.py @@ -29,22 +29,19 @@ def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None): def forward(self, feed_dict, *, segSize=None): # training if segSize is None: - if self.deep_sup_scale is not None: # use deep supervision technique - (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) - else: - pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) + pred_dict = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) - loss = self.crit(pred, feed_dict['seg_label']) + loss = self.crit(pred_dict['logits'], feed_dict['seg_label']) if self.deep_sup_scale is not None: - loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label']) + loss_deepsup = self.crit(pred_dict['deepsup_logits'], feed_dict['seg_label']) loss = loss + loss_deepsup * self.deep_sup_scale - acc = self.pixel_acc(pred, feed_dict['seg_label']) + acc = self.pixel_acc(pred_dict['logits'], feed_dict['seg_label']) return loss, acc # inference else: - pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize) - return pred + pred_dict = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize) + return pred_dict['logits'] class ModelBuilder: @@ -112,39 +109,33 @@ def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''): @staticmethod def build_decoder(arch='ppm_deepsup', fc_dim=512, num_class=150, - weights='', use_softmax=False): + weights=''): arch = arch.lower() if arch == 'c1_deepsup': net_decoder = C1DeepSup( num_class=num_class, - fc_dim=fc_dim, - use_softmax=use_softmax) + fc_dim=fc_dim) elif arch == 'c1': net_decoder = C1( num_class=num_class, - fc_dim=fc_dim, - use_softmax=use_softmax) + fc_dim=fc_dim) elif arch == 'ppm': net_decoder = PPM( num_class=num_class, - fc_dim=fc_dim, - use_softmax=use_softmax) + fc_dim=fc_dim) elif arch == 'ppm_deepsup': net_decoder = PPMDeepsup( num_class=num_class, - fc_dim=fc_dim, - use_softmax=use_softmax) + fc_dim=fc_dim) elif arch == 'upernet_lite': net_decoder = UPerNet( num_class=num_class, fc_dim=fc_dim, - use_softmax=use_softmax, fpn_dim=256) elif arch == 'upernet': net_decoder = UPerNet( num_class=num_class, fc_dim=fc_dim, - use_softmax=use_softmax, fpn_dim=512) else: raise Exception('Architecture undefined!') @@ -325,10 +316,8 @@ def forward(self, x, return_feature_maps=False): # last conv, deep supervision class C1DeepSup(nn.Module): - def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + def __init__(self, num_class=150, fc_dim=2048): super(C1DeepSup, self).__init__() - self.use_softmax = use_softmax - self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) @@ -342,29 +331,26 @@ def forward(self, conv_out, segSize=None): x = self.cbr(conv5) x = self.conv_last(x) - if self.use_softmax: # is True during inference + if segSize is not None: x = nn.functional.interpolate( x, size=segSize, mode='bilinear', align_corners=False) - x = nn.functional.softmax(x, dim=1) - return x # deep sup conv4 = conv_out[-2] - _ = self.cbr_deepsup(conv4) - _ = self.conv_last_deepsup(_) + deepsup = self.cbr_deepsup(conv4) + deepsup = self.conv_last_deepsup(deepsup) - x = nn.functional.log_softmax(x, dim=1) - _ = nn.functional.log_softmax(_, dim=1) + if segSize is not None: + deepsup = nn.functional.interpolate( + deepsup, size=segSize, mode='bilinear', align_corners=False) - return (x, _) + return dict(logits=x, deepsup_logits=deepsup) # last conv class C1(nn.Module): - def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + def __init__(self, num_class=150, fc_dim=2048): super(C1, self).__init__() - self.use_softmax = use_softmax - self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) # last conv @@ -375,23 +361,18 @@ def forward(self, conv_out, segSize=None): x = self.cbr(conv5) x = self.conv_last(x) - if self.use_softmax: # is True during inference + if segSize is not None: x = nn.functional.interpolate( x, size=segSize, mode='bilinear', align_corners=False) - x = nn.functional.softmax(x, dim=1) - else: - x = nn.functional.log_softmax(x, dim=1) - return x + return dict(logits=x) # pyramid pooling class PPM(nn.Module): def __init__(self, num_class=150, fc_dim=4096, - use_softmax=False, pool_scales=(1, 2, 3, 6)): + pool_scales=(1, 2, 3, 6)): super(PPM, self).__init__() - self.use_softmax = use_softmax - self.ppm = [] for scale in pool_scales: self.ppm.append(nn.Sequential( @@ -425,22 +406,18 @@ def forward(self, conv_out, segSize=None): x = self.conv_last(ppm_out) - if self.use_softmax: # is True during inference + if segSize is not None: x = nn.functional.interpolate( x, size=segSize, mode='bilinear', align_corners=False) - x = nn.functional.softmax(x, dim=1) - else: - x = nn.functional.log_softmax(x, dim=1) - return x + + return dict(logits=x) # pyramid pooling, deep supervision class PPMDeepsup(nn.Module): def __init__(self, num_class=150, fc_dim=4096, - use_softmax=False, pool_scales=(1, 2, 3, 6)): + pool_scales=(1, 2, 3, 6)): super(PPMDeepsup, self).__init__() - self.use_softmax = use_softmax - self.ppm = [] for scale in pool_scales: self.ppm.append(nn.Sequential( @@ -477,31 +454,29 @@ def forward(self, conv_out, segSize=None): x = self.conv_last(ppm_out) - if self.use_softmax: # is True during inference + if segSize is not None: x = nn.functional.interpolate( x, size=segSize, mode='bilinear', align_corners=False) - x = nn.functional.softmax(x, dim=1) - return x # deep sup conv4 = conv_out[-2] - _ = self.cbr_deepsup(conv4) - _ = self.dropout_deepsup(_) - _ = self.conv_last_deepsup(_) + deepsup = self.cbr_deepsup(conv4) + deepsup = self.dropout_deepsup(deepsup) + deepsup = self.conv_last_deepsup(deepsup) - x = nn.functional.log_softmax(x, dim=1) - _ = nn.functional.log_softmax(_, dim=1) + if segSize is not None: + deepsup = nn.functional.interpolate( + deepsup, size=segSize, mode='bilinear', align_corners=False) - return (x, _) + return dict(logits=x, deepsup_logits=deepsup) # upernet class UPerNet(nn.Module): def __init__(self, num_class=150, fc_dim=4096, - use_softmax=False, pool_scales=(1, 2, 3, 6), + pool_scales=(1, 2, 3, 6), fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256): super(UPerNet, self).__init__() - self.use_softmax = use_softmax # PPM Module self.ppm_pooling = [] @@ -575,12 +550,8 @@ def forward(self, conv_out, segSize=None): fusion_out = torch.cat(fusion_list, 1) x = self.conv_last(fusion_out) - if self.use_softmax: # is True during inference + if segSize is not None: x = nn.functional.interpolate( x, size=segSize, mode='bilinear', align_corners=False) - x = nn.functional.softmax(x, dim=1) - return x - - x = nn.functional.log_softmax(x, dim=1) - return x + return dict(logits=x) diff --git a/test.py b/test.py index a0a2eec2..494c4e09 100644 --- a/test.py +++ b/test.py @@ -103,10 +103,9 @@ def main(cfg, gpu): arch=cfg.MODEL.arch_decoder, fc_dim=cfg.MODEL.fc_dim, num_class=cfg.DATASET.num_class, - weights=cfg.MODEL.weights_decoder, - use_softmax=True) + weights=cfg.MODEL.weights_decoder) - crit = nn.NLLLoss(ignore_index=-1) + crit = nn.CrossEntropyLoss(ignore_index=-1) segmentation_module = SegmentationModule(net_encoder, net_decoder, crit) diff --git a/train.py b/train.py index cfe3b3ee..a602a7fa 100644 --- a/train.py +++ b/train.py @@ -151,7 +151,7 @@ def main(cfg, gpus): num_class=cfg.DATASET.num_class, weights=cfg.MODEL.weights_decoder) - crit = nn.NLLLoss(ignore_index=-1) + crit = nn.CrossEntropyLoss(ignore_index=-1) if cfg.MODEL.arch_decoder.endswith('deepsup'): segmentation_module = SegmentationModule(