From 1f5d98ce1724b4d4e1787a9a17f7cce5d1146d2a Mon Sep 17 00:00:00 2001 From: Zachary-66 <1032903371@qq.com> Date: Tue, 1 Nov 2022 21:20:23 +0800 Subject: [PATCH] add flow1d --- configs/_base_/models/flow1d.py | 51 ++++ .../flow1d_8xb2_100k_flyingchairs-368x496.py | 5 + mmflow/datasets/transforms/transforms.py | 2 +- mmflow/models/decoders/__init__.py | 3 +- mmflow/models/decoders/flow1d_decoder.py | 258 ++++++++++++++++++ mmflow/models/flow_estimators/__init__.py | 3 +- mmflow/models/flow_estimators/flow1d.py | 21 ++ mmflow/models/utils/__init__.py | 7 +- mmflow/models/utils/attention1d.py | 119 ++++++++ mmflow/models/utils/corr_lookup.py | 91 ++++++ mmflow/models/utils/correlation1d.py | 44 +++ .../test_decoders/test_flow1d_decoder.py | 77 ++++++ tests/test_models/test_flow_estimators.py | 5 +- .../test_utils/test_corr_lookup.py | 25 ++ 14 files changed, 704 insertions(+), 7 deletions(-) create mode 100644 configs/_base_/models/flow1d.py create mode 100644 configs/flow1d/flow1d_8xb2_100k_flyingchairs-368x496.py create mode 100644 mmflow/models/decoders/flow1d_decoder.py create mode 100644 mmflow/models/flow_estimators/flow1d.py create mode 100644 mmflow/models/utils/attention1d.py create mode 100644 mmflow/models/utils/correlation1d.py create mode 100644 tests/test_models/test_decoders/test_flow1d_decoder.py diff --git a/configs/_base_/models/flow1d.py b/configs/_base_/models/flow1d.py new file mode 100644 index 00000000..395cf649 --- /dev/null +++ b/configs/_base_/models/flow1d.py @@ -0,0 +1,51 @@ +model = dict( + type='Flow1D', + data_preprocessor=dict( + type='FlowDataPreprocessor', + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + bgr_to_rgb=True), + radius=32, + cxt_channels=128, + h_channels=128, + encoder=dict( + type='RAFTEncoder', + in_channels=3, + out_channels=256, + net_type='Basic', + norm_cfg=dict(type='IN'), + init_cfg=[ + dict( + type='Kaiming', + layer=['Conv2d'], + mode='fan_out', + nonlinearity='relu'), + dict(type='Constant', layer=['InstanceNorm2d'], val=1, bias=0) + ]), + cxt_encoder=dict( + type='RAFTEncoder', + in_channels=3, + out_channels=256, + net_type='Basic', + norm_cfg=dict(type='SyncBN'), + init_cfg=[ + dict( + type='Kaiming', + layer=['Conv2d'], + mode='fan_out', + nonlinearity='relu'), + dict(type='Constant', layer=['SyncBatchNorm2d'], val=1, bias=0) + ]), + decoder=dict( + type='Flow1DDecoder', + net_type='Basic', + radius=32, + iters=24, + corr_op_cfg=dict(type='CorrLookupFlow1D'), + gru_type='SeqConv', + flow_loss=dict(type='SequenceLoss'), + act_cfg=dict(type='ReLU')), + freeze_bn=False, + train_cfg=dict(), + test_cfg=dict(), +) diff --git a/configs/flow1d/flow1d_8xb2_100k_flyingchairs-368x496.py b/configs/flow1d/flow1d_8xb2_100k_flyingchairs-368x496.py new file mode 100644 index 00000000..7369a706 --- /dev/null +++ b/configs/flow1d/flow1d_8xb2_100k_flyingchairs-368x496.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/flow1d.py', + '../_base_/datasets/flyingchairs_raft_368x496.py', + '../_base_/schedules/raft_100k.py', '../_base_/default_runtime.py' +] diff --git a/mmflow/datasets/transforms/transforms.py b/mmflow/datasets/transforms/transforms.py index ec9c036f..dede8c4b 100644 --- a/mmflow/datasets/transforms/transforms.py +++ b/mmflow/datasets/transforms/transforms.py @@ -583,7 +583,7 @@ def _pad_img(self, results: dict) -> None: elif self.position == 'right': self._pad = [[pad_h // 2, pad_h - pad_h // 2], [pad_w, 0]] elif self.position == 'top': - self._pad = [[0, pad_h, pad_w // 2], [pad_w - pad_w // 2]] + self._pad = [[0, pad_h], 
+                         [pad_w // 2, pad_w - pad_w // 2]]
         elif self.position == 'down':
             self._pad = [[pad_h, 0],
                          [pad_w // 2, pad_w - pad_w // 2]]
         if len(img1.shape) > 2:
diff --git a/mmflow/models/decoders/__init__.py b/mmflow/models/decoders/__init__.py
index 7c210fd9..5fb99bb7 100644
--- a/mmflow/models/decoders/__init__.py
+++ b/mmflow/models/decoders/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .context_net import ContextNet
+from .flow1d_decoder import Flow1DDecoder
 from .flownet_decoder import FlowNetCDecoder, FlowNetSDecoder
 from .gma_decoder import GMADecoder
 from .irr_refine import FlowRefine, OccRefine, OccShuffleUpsample
@@ -13,5 +14,5 @@
     'FlowNetCDecoder', 'FlowNetSDecoder', 'PWCNetDecoder',
     'MaskFlowNetSDecoder', 'NetE', 'ContextNet', 'RAFTDecoder', 'FlowRefine',
     'OccRefine', 'OccShuffleUpsample', 'IRRPWCDecoder', 'MaskFlowNetDecoder',
-    'GMADecoder'
+    'GMADecoder', 'Flow1DDecoder'
 ]
diff --git a/mmflow/models/decoders/flow1d_decoder.py b/mmflow/models/decoders/flow1d_decoder.py
new file mode 100644
index 00000000..ce8f7317
--- /dev/null
+++ b/mmflow/models/decoders/flow1d_decoder.py
@@ -0,0 +1,258 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmflow.registry import MODELS
+from ..utils.attention1d import Attention1D
+from ..utils.correlation1d import Correlation1D
+from .raft_decoder import MotionEncoder, RAFTDecoder, XHead
+
+
+class MotionEncoderFlow1D(MotionEncoder):
+    """The motion encoder module for Flow1D.
+
+    An encoder which consists of several convolution layers and outputs
+    features as the GRU's input.
+
+    Args:
+        radius (int): Radius used when calculating the correlation tensor.
+            Default: 32.
+        net_type (str): Type of the net. Choices: ['Basic', 'Small'].
+            Default: 'Basic'.
+    """
+
+    def __init__(self,
+                 radius: int = 32,
+                 net_type: str = 'Basic',
+                 **kwargs) -> None:
+        super().__init__(radius=radius, net_type=net_type, **kwargs)
+        corr_channels = self._corr_channels.get(net_type) if isinstance(
+            self._corr_channels[net_type],
+            (tuple, list)) else [self._corr_channels[net_type]]
+        corr_kernel = self._corr_kernel.get(net_type) if isinstance(
+            self._corr_kernel.get(net_type),
+            (tuple, list)) else [self._corr_kernel.get(net_type)]
+        corr_padding = self._corr_padding.get(net_type) if isinstance(
+            self._corr_padding.get(net_type),
+            (tuple, list)) else [self._corr_padding.get(net_type)]
+
+        corr_inch = 2 * (2 * radius + 1)
+        corr_net = self._make_encoder(corr_inch, corr_channels, corr_kernel,
+                                      corr_padding, **kwargs)
+        self.corr_net = nn.Sequential(*corr_net)
+
+
+class PositionEmbeddingSine(nn.Module):
+    """Sinusoidal position embedding: the standard version used in the
+    `Attention Is All You Need` paper, generalized to work on images.
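+
+    For a position ``pos`` along one axis and channel pair index ``i``, the
+    embedding is ``sin(pos / temperature**(2 * i / num_pos_feats))`` for even
+    channels and the matching ``cos`` for odd channels (following the paper);
+    the y and x encodings are computed separately and concatenated.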
+
+    https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
+    """
+
+    def __init__(self,
+                 num_pos_feats=64,
+                 temperature=10000,
+                 normalize=True,
+                 scale=None):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError('normalize should be True if scale is passed')
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x):
+        # x: [B, C, H, W]; every position is valid, so an all-ones mask
+        # replaces the padding mask of the original DETR implementation
+        b, c, h, w = x.size()
+        mask = torch.ones((b, h, w), device=x.device)  # [B, H, W]
+        y_embed = mask.cumsum(1, dtype=torch.float32)
+        x_embed = mask.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(
+            self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+            dim=4).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+            dim=4).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
+
+
+@MODELS.register_module()
+class Flow1DDecoder(RAFTDecoder):
+    """The decoder of Flow1D Net.
+
+    The decoder of Flow1D Net, which outputs a list of upsampled flow
+    estimations.
+
+    Args:
+        net_type (str): Type of the net. Choices: ['Basic', 'Small'].
+        radius (int): Radius used when calculating the correlation tensor.
+        iters (int): Total number of iterative flow updates.
+        corr_op_cfg (dict): Config dict of the correlation lookup operator.
+            Default: dict(type='CorrLookupFlow1D', align_corners=True).
+        gru_type (str): Type of the GRU module. Choices: ['Conv', 'SeqConv'].
+            Default: 'SeqConv'.
+        feat_channels (Sequence(int)): Feature channels of the prediction
+            module. Default: 256.
+        mask_channels (int): Output channels of the mask prediction layer.
+            Default: 64.
+        conv_cfg (dict, optional): Config dict of convolution layers in the
+            motion encoder. Default: None.
+        norm_cfg (dict, optional): Config dict of the norm layer in the
+            motion encoder. Default: None.
+        act_cfg (dict, optional): Config dict of the activation layer in the
+            motion encoder. Default: None.
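+
+    Example:
+        An illustrative sketch mirroring the unit test added in this patch
+        (the 1/8-resolution feature shapes are assumptions for the example,
+        not requirements):
+
+        >>> import torch
+        >>> decoder = Flow1DDecoder(
+        ...     net_type='Basic', radius=32, iters=12,
+        ...     flow_loss=dict(type='SequenceLoss'))
+        >>> feat1 = torch.randn(1, 256, 8, 8)
+        >>> feat2 = torch.randn(1, 256, 8, 8)
+        >>> h_feat = torch.randn(1, 128, 8, 8)
+        >>> cxt_feat = torch.randn(1, 128, 8, 8)
+        >>> flow = torch.zeros(1, 2, 8, 8)
+        >>> preds = decoder(feat1, feat2, flow, h_feat, cxt_feat)
+        >>> len(preds), preds[-1].shape
+        (12, torch.Size([1, 2, 64, 64]))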
+    """
+    _h_channels = {'Basic': 128, 'Small': 96}
+    _cxt_channels = {'Basic': 128, 'Small': 64}
+
+    def __init__(self,
+                 net_type: str,
+                 radius: int,
+                 corr_op_cfg: dict = dict(
+                     type='CorrLookupFlow1D', align_corners=True),
+                 feat_channels: Union[int, Sequence[int]] = 256,
+                 mask_channels: int = 64,
+                 conv_cfg: Optional[dict] = None,
+                 norm_cfg: Optional[dict] = None,
+                 act_cfg: Optional[dict] = None,
+                 **kwargs) -> None:
+        super().__init__(
+            net_type=net_type,
+            num_levels=4,
+            radius=radius,
+            corr_op_cfg=corr_op_cfg,
+            feat_channels=feat_channels,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            **kwargs)
+        self.attn_block_x = Attention1D(
+            in_channels=feat_channels,
+            y_attention=False,
+            double_cross_attn=True)
+        self.attn_block_y = Attention1D(
+            in_channels=feat_channels,
+            y_attention=True,
+            double_cross_attn=True)
+        self.corr_block = Correlation1D()
+        self.feat_channels = feat_channels if isinstance(
+            feat_channels, (tuple, list)) else [feat_channels]
+
+        self.encoder = MotionEncoderFlow1D(
+            radius=radius,
+            net_type=net_type,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+        feat_channels = self.feat_channels
+        self.mask_channels = mask_channels * 9
+        if net_type == 'Basic':
+            self.mask_pred = XHead(
+                self.h_channels, feat_channels, self.mask_channels, x='mask')
+
+    def _upsample(self,
+                  flow: torch.Tensor,
+                  mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex
+        combination.
+
+        Args:
+            flow (Tensor): The optical flow with the shape [N, 2, H/8, W/8].
+            mask (Tensor, optional): The learnable mask with shape
+                [N, grid_size x scale x scale, H/8, W/8].
+
+        Returns:
+            Tensor: The output optical flow with the shape [N, 2, H, W].
+        """
+        scale = 8
+        grid_size = 9
+        grid_side = int(math.sqrt(grid_size))
+        N, _, H, W = flow.shape
+        if mask is None:
+            new_size = (scale * H, scale * W)
+            return scale * F.interpolate(
+                flow, size=new_size, mode='bilinear', align_corners=True)
+        # reshape the mask to [N, 1, 9, 8, 8, H, W] and normalize the nine
+        # grid weights for each upsampled position
+        mask = mask.view(N, 1, grid_size, scale, scale, H, W)
+        mask = torch.softmax(mask, dim=2)
+
+        # extract the 3x3 local grid around each position;
+        # padding = grid_side // 2
+        upflow = F.unfold(scale * flow, [grid_side, grid_side], padding=1)
+        # upflow with shape [N, 2, 9, 1, 1, H, W]
+        upflow = upflow.view(N, 2, grid_size, 1, 1, H, W)
+
+        # take a weighted combination over the 3x3 neighborhood grid;
+        # upflow with shape [N, 2, 8, 8, H, W]
+        upflow = torch.sum(mask * upflow, dim=2)
+        upflow = upflow.permute(0, 1, 4, 2, 5, 3)
+        return upflow.reshape(N, 2, scale * H, scale * W)
+
+    def forward(self, feat1: torch.Tensor, feat2: torch.Tensor,
+                flow: torch.Tensor, h_feat: torch.Tensor,
+                cxt_feat: torch.Tensor) -> Sequence[torch.Tensor]:
+        """Forward function for Flow1D.
+
+        Args:
+            feat1 (Tensor): The feature from the first input image.
+            feat2 (Tensor): The feature from the second input image.
+            flow (Tensor): The initial flow used for the warm start.
+            h_feat (Tensor): The hidden state for the GRU cell.
+            cxt_feat (Tensor): The contextual feature from the first image.
+
+        Returns:
+            Sequence[Tensor]: The list of predicted optical flow.
+        """
+        pos_encoding = PositionEmbeddingSine(self.feat_channels[0] // 2)
+        position = pos_encoding(feat1)
+
+        # 1D cross attention along each axis; the x correlation uses
+        # features aggregated along y, and vice versa
+        feat2_x, _ = self.attn_block_x(feat1, feat2, position, None)
+        feat2_y, _ = self.attn_block_y(feat1, feat2, position, None)
+
+        correlation_x = self.corr_block(feat1, feat2_y, x_correlation=True)
+        correlation_y = self.corr_block(feat1, feat2_x, x_correlation=False)
+
+        correlation1d = [correlation_x, correlation_y]
+        upflow_preds = []
+        delta_flow = torch.zeros_like(flow)
+        for _ in range(self.iters):
+            flow = flow.detach()
+            corr = self.corr_lookup(correlation1d, flow)
+            motion_feat = self.encoder(corr, flow)
+            x = torch.cat([cxt_feat, motion_feat], dim=1)
+            h_feat = self.gru(h_feat, x)
+
+            delta_flow = self.flow_pred(h_feat)
+            flow = flow + delta_flow
+
+            if hasattr(self, 'mask_pred'):
+                mask = .25 * self.mask_pred(h_feat)
+            else:
+                mask = None
+
+            upflow = self._upsample(flow, mask)
+            upflow_preds.append(upflow)
+
+        return upflow_preds
diff --git a/mmflow/models/flow_estimators/__init__.py b/mmflow/models/flow_estimators/__init__.py
index 7020cff2..7e8052d2 100644
--- a/mmflow/models/flow_estimators/__init__.py
+++ b/mmflow/models/flow_estimators/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .flow1d import Flow1D
 from .flownet import FlowNetC, FlowNetS
 from .flownet2 import FlowNet2, FlowNetCSS
 from .irrpwc import IRRPWC
@@ -9,5 +10,5 @@
 
 __all__ = [
     'FlowNetC', 'FlowNetS', 'LiteFlowNet', 'PWCNet', 'MaskFlowNetS', 'RAFT',
-    'IRRPWC', 'FlowNet2', 'FlowNetCSS', 'MaskFlowNet'
+    'IRRPWC', 'FlowNet2', 'FlowNetCSS', 'MaskFlowNet', 'Flow1D'
 ]
diff --git a/mmflow/models/flow_estimators/flow1d.py b/mmflow/models/flow_estimators/flow1d.py
new file mode 100644
index 00000000..8ccb21fd
--- /dev/null
+++ b/mmflow/models/flow_estimators/flow1d.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmflow.registry import MODELS
+from .raft import RAFT
+
+
+@MODELS.register_module()
+class Flow1D(RAFT):
+    """Flow1D model.
+
+    Args:
+        radius (int): Radius used when looking up the 1D correlation.
+        cxt_channels (int): Number of channels of the context feature.
+        h_channels (int): Number of channels of the hidden feature in the
+            recurrent unit.
+        cxt_encoder (dict): Config dict for building the context encoder.
+        freeze_bn (bool, optional): Whether to freeze batchnorm layers or
+            not. Default: False.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(num_levels=4, **kwargs)
diff --git a/mmflow/models/utils/__init__.py b/mmflow/models/utils/__init__.py
index 77a15216..d4cc2d6a 100644
--- a/mmflow/models/utils/__init__.py
+++ b/mmflow/models/utils/__init__.py
@@ -1,5 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .corr_lookup import CorrLookup
+from .attention1d import Attention1D
+from .corr_lookup import CorrLookup, CorrLookupFlow1D
+from .correlation1d import Correlation1D
 from .correlation_block import CorrBlock
 from .densenet import BasicDenseBlock, DenseLayer
 from .estimators_link import BasicLink, LinkOutput
@@ -11,5 +13,6 @@
 __all__ = [
     'ResLayer', 'BasicBlock', 'Bottleneck', 'BasicLink', 'LinkOutput',
     'DenseLayer', 'BasicDenseBlock', 'CorrBlock', 'occlusion_estimation',
-    'Warp', 'CorrLookup', 'unpack_flow_data_samples'
+    'Warp', 'CorrLookup', 'unpack_flow_data_samples', 'Attention1D',
+    'Correlation1D', 'CorrLookupFlow1D'
 ]
diff --git a/mmflow/models/utils/attention1d.py b/mmflow/models/utils/attention1d.py
new file mode 100644
index 00000000..7ac37aef
--- /dev/null
+++ b/mmflow/models/utils/attention1d.py
@@ -0,0 +1,119 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from mmengine.model import BaseModule
+from torch import Tensor, nn
+
+
+class AttentionLayer(BaseModule):
+    """Attention layer restricted to a single spatial axis: computes
+    attention along the y direction or along the x direction.
+
+    Args:
+        in_channels (int): Number of input channels.
+        y_attention (bool): Whether to compute attention along the y axis.
+            Default: False.
+    """
+
+    def __init__(self, in_channels: int, y_attention: bool = False) -> None:
+        super().__init__()
+        self.y_attention = y_attention
+
+        self.query_conv = nn.Conv2d(in_channels, in_channels, 1)
+        self.key_conv = nn.Conv2d(in_channels, in_channels, 1)
+
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, feature1: Tensor, feature2: Tensor, position: Tensor,
+                value: Tensor) -> Tuple[Tensor, Tensor]:
+        """Forward function for AttentionLayer.
+
+        Query: feature1 + position
+        Key: feature2 + position
+        Value: feature2
+
+        Args:
+            feature1 (Tensor): The input feature1.
+            feature2 (Tensor): The input feature2.
+            position (Tensor): The position encoding.
+            value (Tensor): The attention value; defaults to feature2 when
+                None.
+        Returns:
+            Tuple[Tensor, Tensor]: The output of the attention layer and the
+            attention weights (scores).
+        """
+        b, c, h, w = feature1.size()
+
+        query = feature1 + position if position is not None else feature1
+        query = self.query_conv(query)
+
+        key = feature2 + position if position is not None else feature2
+        key = self.key_conv(key)
+
+        value = feature2 if value is None else value
+        scale_factor = c**0.5
+
+        if self.y_attention:
+            # attend along the H direction; feature shape is [B, W, H, C]
+            query = query.permute(0, 3, 2, 1)
+            key = key.permute(0, 3, 1, 2)
+            value = value.permute(0, 3, 2, 1)
+        else:  # x attention
+            # attend along the W direction; feature shape is [B, H, W, C]
+            query = query.permute(0, 2, 3, 1)
+            key = key.permute(0, 2, 1, 3)
+            value = value.permute(0, 2, 3, 1)
+
+        # the shape of the scores is [B, W, H, H] or [B, H, W, W]
+        scores = torch.matmul(query, key) / scale_factor
+        scores = torch.softmax(scores, dim=-1)
+
+        out = torch.matmul(scores, value)
+
+        # the shape of the output is [B, C, H, W]
+        if self.y_attention:
+            out = out.permute(0, 3, 2, 1).contiguous()
+        else:
+            out = out.permute(0, 3, 1, 2).contiguous()
+
+        return out, scores
+
+
+class Attention1D(BaseModule):
+    """Cross attention along the x or y axis, without multi-head or dropout
+    support for faster speed. It optionally computes self attention along
+    the other axis first, and then cross attention along the target axis.
+
+    Args:
+        in_channels (int): Number of input channels.
+        y_attention (bool): Whether to compute attention along the y axis.
+            Default: False.
+        double_cross_attn (bool): Whether to compute self attention along
+            the other axis first. Default: True.
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 y_attention: bool = False,
+                 double_cross_attn: bool = True) -> None:
+        super().__init__()
+        self.y_attention = y_attention
+        self.double_cross_attn = double_cross_attn
+        if double_cross_attn:
+            self.self_attn = AttentionLayer(in_channels, not y_attention)
+        self.cross_attn = AttentionLayer(in_channels, y_attention)
+
+    def forward(self, feature1: Tensor, feature2: Tensor, position: Tensor,
+                value: Tensor) -> Tuple[Tensor, Tensor]:
+        """Forward function for Attention1D.
+
+        Args:
+            feature1 (Tensor): The input feature1.
+            feature2 (Tensor): The input feature2.
+            position (Tensor): The position encoding.
+            value (Tensor): The attention value.
+        Returns:
+            Tuple[Tensor, Tensor]: The output of the attention layer and the
+            attention weights (scores).
+        """
+        if self.double_cross_attn:
+            # self attention along the other axis first
+            feature1 = self.self_attn(feature1, feature1, position, value)[0]
+        return self.cross_attn(feature1, feature2, position, value)
diff --git a/mmflow/models/utils/corr_lookup.py b/mmflow/models/utils/corr_lookup.py
index 4cbcc8b0..afe5bfc5 100644
--- a/mmflow/models/utils/corr_lookup.py
+++ b/mmflow/models/utils/corr_lookup.py
@@ -138,3 +138,94 @@ def forward(self, corr_pyramid: Sequence[Tensor], flow: Tensor) -> Tensor:
 
         out = torch.cat(out_corr_pyramid, dim=-1)
         return out.permute(0, 3, 1, 2).contiguous().float()
+
+
+@MODELS.register_module()
+class CorrLookupFlow1D(nn.Module):
+    """Correlation lookup operator for Flow1D.
+
+    This operator is used in `Flow1D`_.
+
+    Args:
+        radius (int): The radius of the local neighborhood of the pixels.
+            Default to 32.
+        mode (str): Interpolation mode to calculate output values 'bilinear'
+            | 'nearest' | 'bicubic'. Default: 'bilinear'. Note:
+            mode='bicubic' supports only 4-D input.
+        padding_mode (str): Padding mode for outside grid values 'zeros' |
+            'border' | 'reflection'. Default: 'zeros'.
+        align_corners (bool): If set to True, the extrema (-1 and 1) are
+            considered as referring to the center points of the input's
+            corner pixels. If set to False, they are instead considered as
+            referring to the corner points of the input's corner pixels,
+            making the sampling more resolution agnostic. Default to True.
+    """
+
+    def __init__(self,
+                 radius: int = 32,
+                 mode: str = 'bilinear',
+                 padding_mode: str = 'zeros',
+                 align_corners: bool = True) -> None:
+        super().__init__()
+        self.r = radius
+        self.mode = mode
+        self.padding_mode = padding_mode
+        self.align_corners = align_corners
+
+    def forward(self, corr: Sequence[Tensor], flow: Tensor) -> Tensor:
+        """Forward function of the correlation lookup for Flow1D.
+
+        Args:
+            corr (Sequence[Tensor]): Correlation on the x and y directions.
+            flow (Tensor): Current estimated optical flow.
+        Returns:
+            Tensor: The lookup cost volumes on the x- and y-direction
+            correlations, concatenated together.
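+
+        Example:
+            An illustrative sketch with tiny, arbitrary shapes (radius=2 is
+            chosen only to keep the cost volume small):
+
+            >>> import torch
+            >>> lookup = CorrLookupFlow1D(radius=2)
+            >>> corr_x = torch.randn(1, 4, 6, 6)  # [B, H, W, W]
+            >>> corr_y = torch.randn(1, 6, 4, 4)  # [B, W, H, H]
+            >>> flow = torch.zeros(1, 2, 4, 6)
+            >>> lookup([corr_x, corr_y], flow).shape
+            torch.Size([1, 10, 4, 6])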
+        """
+        B, _, H, W = flow.shape
+        # reshape corr_x from [B, H, W, W] to [B*H*W, 1, 1, W]
+        corr_x = corr[0].view(-1, 1, 1, W)
+        # reshape corr_y from [B, W, H, H] to [B*H*W, 1, H, 1]
+        corr_y = corr[1].permute(0, 2, 1, 3).contiguous().view(-1, 1, H, 1)
+
+        # reshape flow to [B, H, W, 2]
+        flow = flow.permute(0, 2, 3, 1)
+
+        dx = torch.linspace(
+            -self.r, self.r, 2 * self.r + 1, device=flow.device)
+        dy = torch.linspace(
+            -self.r, self.r, 2 * self.r + 1, device=flow.device)
+
+        delta_x = torch.stack((dx, torch.zeros_like(dx)), dim=-1)
+        delta_y = torch.stack((torch.zeros_like(dy), dy), dim=-1)
+        # reshape delta_y to [1, 2r+1, 1, 2]
+        delta_y = delta_y.unsqueeze(1).unsqueeze(0)
+
+        xx = torch.arange(0, W, device=flow.device)
+        yy = torch.arange(0, H, device=flow.device)
+        coords = coords_grid(B, xx, yy).permute(0, 2, 3, 1) + flow
+
+        coords_x = coords[:, :, :, 0]
+        coords_y = coords[:, :, :, 1]
+
+        coords_x = torch.stack((coords_x, torch.zeros_like(coords_x)), dim=-1)
+        coords_y = torch.stack((torch.zeros_like(coords_y), coords_y), dim=-1)
+
+        centroid_x = coords_x.view(B * H * W, 1, 1, 2)
+        centroid_y = coords_y.view(B * H * W, 1, 1, 2)
+
+        coords_x = centroid_x + delta_x
+        coords_y = centroid_y + delta_y
+
+        corr_x = bilinear_sample(corr_x, coords_x, self.mode,
+                                 self.padding_mode, self.align_corners)
+        corr_y = bilinear_sample(corr_y, coords_y, self.mode,
+                                 self.padding_mode, self.align_corners)
+
+        # shape is [B, 2r+1, H, W]
+        corr_x = corr_x.view(B, H, W, -1)
+        corr_x = corr_x.permute(0, 3, 1, 2).contiguous()
+        corr_y = corr_y.view(B, H, W, -1)
+        corr_y = corr_y.permute(0, 3, 1, 2).contiguous()
+
+        correlation = torch.cat((corr_x, corr_y), dim=1)
+
+        return correlation
diff --git a/mmflow/models/utils/correlation1d.py b/mmflow/models/utils/correlation1d.py
new file mode 100644
index 00000000..849f8ee7
--- /dev/null
+++ b/mmflow/models/utils/correlation1d.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+from mmengine.model import BaseModule
+from torch import Tensor
+
+
+class Correlation1D(BaseModule):
+    """Correlation1D Module.
+
+    The neck of Flow1D, which computes the correlation tensors of the input
+    features as 3D cost volumes along the x and y directions.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self,
+                feat1: Tensor,
+                feat2: Tensor,
+                x_correlation: bool = False) -> Tensor:
+        """Forward function for Correlation1D.
+
+        Args:
+            feat1 (Tensor): The feature from the first input image.
+            feat2 (Tensor): The 1D cross attention feat2 on the x or y
+                direction.
+            x_correlation (bool): Whether to compute the correlation along
+                the x direction. Default: False.
+        Returns:
+            Tensor: The correlation along the x or y direction.
+        """
+        b, c, h, w = feat1.shape
+        scale_factor = c**0.5
+
+        if x_correlation:
+            # x correlation: corr shape is [B, H, W, W]
+            feat1 = feat1.permute(0, 2, 3, 1)
+            feat2 = feat2.permute(0, 2, 1, 3)
+        else:
+            # y correlation: corr shape is [B, W, H, H]
+            feat1 = feat1.permute(0, 3, 2, 1)
+            feat2 = feat2.permute(0, 3, 1, 2)
+
+        corr = torch.matmul(feat1, feat2) / scale_factor
+        return corr
diff --git a/tests/test_models/test_decoders/test_flow1d_decoder.py b/tests/test_models/test_decoders/test_flow1d_decoder.py
new file mode 100644
index 00000000..393382d4
--- /dev/null
+++ b/tests/test_models/test_decoders/test_flow1d_decoder.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+from mmengine.structures import PixelData
+from mmengine.utils import is_list_of
+
+from mmflow.models.decoders.flow1d_decoder import (Flow1DDecoder,
+                                                   MotionEncoderFlow1D)
+from mmflow.structures import FlowDataSample
+
+
+@pytest.mark.parametrize('net_type', ['Basic', 'Small'])
+def test_motion_encoder_flow1d(net_type):
+
+    # test invalid net_type
+    with pytest.raises(AssertionError):
+        MotionEncoderFlow1D(net_type='invalid value')
+
+    module = MotionEncoderFlow1D(
+        net_type=net_type, conv_cfg=None, norm_cfg=None, act_cfg=None)
+    radius = 32
+
+    input_corr = torch.randn((1, 2 * (2 * radius + 1), 56, 56))
+    input_flow = torch.randn((1, 2, 56, 56))
+
+    corr_feat = module.corr_net(input_corr)
+    flow_feat = module.flow_net(input_flow)
+    out_feat = module.out_net(torch.cat([corr_feat, flow_feat], dim=1))
+
+    if net_type == 'Basic':
+        assert corr_feat.shape == torch.Size((1, 192, 56, 56))
+        assert flow_feat.shape == torch.Size((1, 64, 56, 56))
+        assert out_feat.shape == torch.Size((1, 126, 56, 56))
+    elif net_type == 'Small':
+        assert corr_feat.shape == torch.Size((1, 96, 56, 56))
+        assert flow_feat.shape == torch.Size((1, 32, 56, 56))
+        assert out_feat.shape == torch.Size((1, 80, 56, 56))
+
+
+def test_flow1d_decoder():
+    model = Flow1DDecoder(
+        net_type='Basic',
+        radius=32,
+        iters=12,
+        flow_loss=dict(type='SequenceLoss'))
+    mask = torch.ones((1, 64 * 9, 10, 10))
+    flow = torch.randn((1, 2, 10, 10))
+    assert model._upsample(flow, mask).shape == torch.Size((1, 2, 80, 80))
+
+    feat1 = torch.randn(1, 256, 8, 8)
+    feat2 = torch.randn(1, 256, 8, 8)
+    h_feat = torch.randn(1, 128, 8, 8)
+    cxt_feat = torch.randn(1, 128, 8, 8)
+    flow = torch.zeros((1, 2, 8, 8))
+
+    h = 64
+    w = 64
+    metainfo = dict(img_shape=(h, w, 3), ori_shape=(h, w))
+    data_sample = FlowDataSample(metainfo=metainfo)
+    data_sample.gt_flow_fw = PixelData(**dict(data=torch.randn(2, h, w)))
+    data_samples = [data_sample]
+
+    # test forward function
+    out = model(feat1, feat2, flow, h_feat, cxt_feat)
+    assert isinstance(out, list)
+    assert out[0].shape == torch.Size((1, 2, 64, 64))
+
+    # test loss forward
+    loss = model.loss(
+        feat1, feat2, flow, h_feat, cxt_feat, data_samples=data_samples)
+    assert float(loss['loss_flow']) > 0.
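+
+    # the decoder returns one upsampled flow prediction per iteration,
+    # so the forward output above has iters=12 entries
+    assert len(out) == 12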
+ + # test predict forward + out = model.predict( + feat1, feat2, flow, h_feat, cxt_feat, data_samples=data_samples) + assert out[0].pred_flow_fw.shape == (64, 64) + assert isinstance(out, list) and is_list_of(out, FlowDataSample) diff --git a/tests/test_models/test_flow_estimators.py b/tests/test_models/test_flow_estimators.py index b9fb5918..95bf0526 100644 --- a/tests/test_models/test_flow_estimators.py +++ b/tests/test_models/test_flow_estimators.py @@ -72,6 +72,7 @@ def test_flow_estimator(cfg_file): @pytest.mark.parametrize('cfg_file', [ '../../configs/_base_/models/raft.py', + '../../configs/_base_/models/flow1d.py', '../../configs/_base_/models/flownets.py', '../../configs/_base_/models/flownet2/flownet2sd.py', '../../configs/_base_/models/gma/gma.py', @@ -85,7 +86,7 @@ def test_flow_estimator_without_cuda(cfg_file): cfg_file = osp.join(osp.dirname(__file__), cfg_file) cfg = Config.fromfile(cfg_file) - if cfg.model.type == 'RAFT': + if cfg.model.type == 'RAFT' or cfg.model.type == 'Flow1D': # Replace SyncBN with BN to inference on CPU cfg.model.cxt_encoder.norm_cfg = dict(type='BN', requires_grad=True) @@ -95,7 +96,7 @@ def test_flow_estimator_without_cuda(cfg_file): # test tensor out out = estimator(inputs, data_samples, mode='tensor') - if cfg.model.type == 'RAFT': + if cfg.model.type == 'RAFT' or cfg.model.type == 'Flow1D': assert is_list_of(out, Tensor) else: assert isinstance(out, dict) diff --git a/tests/test_models/test_utils/test_corr_lookup.py b/tests/test_models/test_utils/test_corr_lookup.py index a73d892b..05c92c78 100644 --- a/tests/test_models/test_utils/test_corr_lookup.py +++ b/tests/test_models/test_utils/test_corr_lookup.py @@ -4,6 +4,7 @@ from mmflow.models import build_components from mmflow.models.decoders.raft_decoder import CorrelationPyramid from mmflow.models.utils.corr_lookup import bilinear_sample, coords_grid +from mmflow.models.utils.correlation1d import Correlation1D def test_coords_grid(): @@ -56,3 +57,27 @@ def test_corr_lookup(): corr_lpt = corr_lookup_op(corr_pyramid, torch.randn(1, 2, H, W)) assert corr_lpt.shape == torch.Size((1, 81 * 4, H, W)) + + +def test_corr_lookup_flow1d(): + corr_block = Correlation1D() + feat1 = torch.arange(0, 24) + feat1 = feat1.view(1, 2, 3, 4) + feat2 = feat1 + 1 + flow = torch.ones_like(feat1) + b, _, h, w = feat1.size() + radius = 32 + + correlation_x = corr_block(feat1, feat2, True) + correlation_y = corr_block(feat1, feat2, False) + correlation = [correlation_x, correlation_y] + corr_lookup_cfg = dict( + type='CorrLookupFlow1D', + radius=radius, + mode='bilinear', + padding_mode='zeros', + align_corners=True) + corr_lookup_op = build_components(corr_lookup_cfg) + + corr_xy = corr_lookup_op(correlation, flow) + assert corr_xy.size() == (b, 2 * (2 * radius + 1), h, w)
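
A minimal usage sketch of the new estimator (illustrative, not part of the
patch): it mirrors test_flow_estimator_without_cuda above; the input shape
and the tensor-mode call are assumptions based on that test code.

    import torch
    from mmengine.config import Config

    from mmflow.registry import MODELS

    cfg = Config.fromfile('configs/_base_/models/flow1d.py')
    # replace SyncBN with BN so the context encoder also runs on CPU,
    # as the updated estimator test does for RAFT and Flow1D
    cfg.model.cxt_encoder.norm_cfg = dict(type='BN', requires_grad=True)
    estimator = MODELS.build(cfg.model)

    # two RGB frames concatenated along channels; H and W divisible by 8
    imgs = torch.randn(1, 6, 64, 64)
    flows = estimator(imgs, None, mode='tensor')  # list of flow predictions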