diff --git a/config/model/gdl_segformer.yaml b/config/model/gdl_segformer.yaml
new file mode 100644
index 00000000..e481fb25
--- /dev/null
+++ b/config/model/gdl_segformer.yaml
@@ -0,0 +1,4 @@
+# @package _global_
+model:
+  _target_: models.segformer.SegFormer
+  encoder: "mit_b2"
\ No newline at end of file
diff --git a/models/segformer.py b/models/segformer.py
new file mode 100644
index 00000000..a61868da
--- /dev/null
+++ b/models/segformer.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import segmentation_models_pytorch as smp
+
+
+class MLP(nn.Module):
+    """
+    Linear Embedding
+    """
+    def __init__(self, input_dim=2048, embed_dim=768):
+        super().__init__()
+        self.proj = nn.Linear(input_dim, embed_dim)
+
+    def forward(self, x):
+        x = x.flatten(2).transpose(1, 2).contiguous()
+        x = self.proj(x)
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(self, encoder="mit_b2",
+                 in_channels=[64, 128, 320, 512],
+                 feature_strides=[4, 8, 16, 32],
+                 embedding_dim=768,
+                 num_classes=1, dropout_ratio=0.1):
+        super(Decoder, self).__init__()
+        if encoder == "mit_b0":
+            in_channels = [32, 64, 160, 256]
+        if encoder in ("mit_b0", "mit_b1"):  # fix: `== "mit_b0" or "mit_b1"` was always truthy
+            embedding_dim = 256
+        assert len(feature_strides) == len(in_channels)
+        assert min(feature_strides) == feature_strides[0]
+
+        self.num_classes = num_classes
+        self.in_channels = in_channels
+        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
+
+        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
+        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
+        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
+        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
+
+        self.linear_fuse = nn.Sequential(
+            nn.Conv2d(in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1, bias=False),
+            nn.BatchNorm2d(embedding_dim), nn.ReLU(inplace=True))
+        self.dropout = nn.Dropout2d(dropout_ratio)
+
+        self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
+
+    def forward(self, x):
+        c1, c2, c3, c4 = x
+        n, _, h, w = c4.shape
+
+        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous()
+        _c4 = F.interpolate(input=_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
+
+        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous()
+        _c3 = F.interpolate(input=_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
+
+        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous()
+        _c2 = F.interpolate(input=_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
+
+        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous()
+        _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
+
+        x = self.dropout(_c)
+        x = self.linear_pred(x)
+
+        return x
+
+
+class SegFormer(nn.Module):
+    def __init__(self, encoder, in_channels, classes) -> None:
+        super().__init__()
+        self.encoder = smp.encoders.get_encoder(name=encoder, in_channels=in_channels, depth=5, drop_path_rate=0.1)
+        self.decoder = Decoder(encoder=encoder, num_classes=classes)
+
+    def forward(self, img):
+        x = self.encoder(img)[2:]
+        x = self.decoder(x)
+        x = F.interpolate(input=x, size=img.shape[2:], scale_factor=None, mode='bilinear', align_corners=False)
+        return x
\ No newline at end of file
diff --git a/tests/model/test_models.py b/tests/model/test_models.py
index 73f5c176..05db8129 100644
--- a/tests/model/test_models.py
+++ b/tests/model/test_models.py
@@ -26,11 +26,11 @@ def test_net(self) -> None:
         hconf = HydraConfig()
         hconf.set_config(cfg)
         del cfg.loss.is_binary  # prevent exception at instantiation
-        rand_img = torch.rand((2, 4, 64, 64))
+        rand_img = torch.rand((2, 3, 64, 64))
         print(cfg.model._target_)
         model = define_model_architecture(
             net_params=cfg.model,
-            in_channels=4,
+            in_channels=3,
             out_classes=4,
         )
         output = model(rand_img)
@@ -41,7 +41,7 @@ class TestReadCheckpoint(object):
     """
    Tests reading a checkpoint saved outside GDL into memory
     """
-    var = 4
+    var = 3
    dummy_model = models.unet.UNetSmall(classes=var, in_channels=var)
    dummy_optimizer = instantiate({'_target_': 'torch.optim.Adam'}, params=dummy_model.parameters())
    filename = "test.pth.tar"
@@ -80,7 +80,7 @@ class TestDefineModelMultigpu(object):
     """
    Tests defining model architecture with weights from provided checkpoint
    and pushing to multiple devices if possible
     """
-    dummy_model = unet.UNet(4, 4, True, 0.5)
+    dummy_model = unet.UNet(4, 3, True, 0.5)
    filename = "test.pth.tar"
    torch.save(dummy_model.state_dict(), filename)
@@ -92,7 +92,7 @@ class TestDefineModelMultigpu(object):
         checkpoint = read_checkpoint(filename)
         model = define_model(
             net_params={'_target_': 'models.unet.UNet'},
-            in_channels=4,
+            in_channels=3,
             out_classes=4,
             main_device=device,
             devices=list(gpu_devices_dict.keys()),
diff --git a/tests/test_tiling_segmentation.py b/tests/test_tiling_segmentation.py
index f03b0bf1..18478aec 100644
--- a/tests/test_tiling_segmentation.py
+++ b/tests/test_tiling_segmentation.py
@@ -170,14 +170,20 @@ def test_tiling_segmentation_parallel(self):
         }
         cfg = DictConfig(cfg)
         tiling(cfg)
-        out_labels = [
-            (Path(f"{data_dir}/{proj}/trn/23322E759967N_clipped_1m_1of2/labels_burned"), (80, 95)),
-            (Path(f"{data_dir}/{proj}/val/23322E759967N_clipped_1m_1of2/labels_burned"), (5, 20)),
-            (Path(f"{data_dir}/{proj}/tst/23322E759967N_clipped_1m_2of2/labels_burned"), (170, 190)),
-        ]
-        for labels_burned_dir, lbls_nb in out_labels:
-            # exact number may vary because of random sort between "trn" and "val"
-            assert lbls_nb[0] <= len(list(labels_burned_dir.iterdir())) <= lbls_nb[1]
+        trn_labels = list(Path(f"{data_dir}/{proj}/trn/").glob("*/labels_burned/*.tif"))
+        val_labels = list(Path(f"{data_dir}/{proj}/val/").glob("*/labels_burned/*.tif"))
+        tst_labels = list(Path(f"{data_dir}/{proj}/tst/").glob("*/labels_burned/*.tif"))
+        assert len(trn_labels) > 0
+        assert len(val_labels) > 0
+        assert len(tst_labels) > 0
+
+        patch_size = cfg.tiling.patch_size
+        for label_list in [trn_labels, val_labels, tst_labels]:
+            num_tifs_to_check = min(5, len(label_list))
+            for tif_file in label_list[:num_tifs_to_check]:
+                with rasterio.open(tif_file) as src:
+                    width, height = src.width, src.height
+                    assert width == patch_size and height == patch_size
         shutil.rmtree(Path(data_dir) / proj)
 
     def test_tiling_inference(self):