From 0645977e74c307bfe4e60dda26f264a59dc3ca3f Mon Sep 17 00:00:00 2001
From: Kye
Date: Mon, 5 Feb 2024 19:17:24 -0800
Subject: [PATCH] [BUGFIX][AttentionBasedInflationBlock]

---
 example.py       | 23 ++++++++++++++++++-----
 lumiere/model.py | 47 +++++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 61 insertions(+), 9 deletions(-)

diff --git a/example.py b/example.py
index 9920b71..3195a75 100644
--- a/example.py
+++ b/example.py
@@ -1,11 +1,24 @@
 import torch
 from lumiere.model import ConvolutionBasedInflationBlock
 
-# Example usage:
-# scale_factor must be a divisor of T, H, and W for the example to work correctly
+
+# B, T, H, W, C
+x = torch.randn(1, 2, 112, 112, 3)
+
+# Create a ConvolutionBasedInflationBlock
 block = ConvolutionBasedInflationBlock(
-    3, 64, (3, 3), (2, 2), (1, 1), scale_factor=2
+    in_channels=3,
+    out_channels=64,
+    kernel_size=(3, 3),
+    stride=1,
+    padding=1,
+    scale_factor=2,
 )
-x = torch.randn(1, 4, 224, 224, 3)
+
+
+# Pass the input tensor through the block
 out = block(x)
-print(out.shape)  # Expected shape: [1, 2, 112, 112, 64]
+
+
+# Print the output shape
+print(out.shape)
diff --git a/lumiere/model.py b/lumiere/model.py
index ad3ca99..6867e17 100644
--- a/lumiere/model.py
+++ b/lumiere/model.py
@@ -110,6 +110,33 @@ def forward(self, x: Tensor):
 
 
 class AttentionBasedInflationBlock(nn.Module):
+    """
+    Attention-based inflation block module.
+
+    Args:
+        dim (int): The input dimension.
+        heads (int): The number of attention heads.
+        dropout (float, optional): The dropout rate. Defaults to 0.1.
+
+    Attributes:
+        dim (int): The input dimension.
+        heads (int): The number of attention heads.
+        dropout (float): The dropout rate.
+        attn (SpatialLinearAttention): The spatial linear attention module.
+        proj (nn.Linear): The linear projection layer.
+        norm (nn.LayerNorm): The layer normalization module.
+
+    Example:
+        >>> import torch
+        >>> from lumiere.model import AttentionBasedInflationBlock
+        >>> x = torch.randn(1, 4, 224, 224, 512)
+        >>> model = AttentionBasedInflationBlock(dim=512, heads=4, dropout=0.1)
+        >>> out = model(x)
+        >>> print(out.shape)
+        torch.Size([1, 4, 224, 224, 512])
+
+    """
+
     def __init__(
         self,
         dim: int,
@@ -124,7 +151,7 @@ def __init__(
         self.dropout = dropout
 
         # Spatial linear attention for videos of size:
-        #batch_size, channels, frames, height, width).
+        # batch_size, channels, frames, height, width.
         self.attn = SpatialLinearAttention(
             dim,
             heads,
@@ -138,6 +165,16 @@ def __init__(
         self.norm = nn.LayerNorm(dim)
 
     def forward(self, x: Tensor):
+        """
+        Forward pass of the AttentionBasedInflationBlock.
+
+        Args:
+            x (Tensor): The input tensor.
+
+        Returns:
+            Tensor: The output tensor.
+
+        """
         skip = x
         b, t, h, w, d = x.shape
 
@@ -147,8 +184,10 @@ def forward(self, x: Tensor):
         # Apply spatial linear attention
         x = self.attn(x)
 
+        # Reshape back to the original shape
+        x = rearrange(x, "b d t h w -> b t h w d")
+
         # Linear projection
-        # x = self.proj(x)
-        # x = self.norm(x)
+        x = self.proj(x)
 
-        return x #+ skip
+        return x + skip
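
Why the projection goes through self.proj: an nn.Linear constructed inside
forward() is re-initialized with random weights on every call, and because it
is never assigned to the module, its weights are invisible to the optimizer
and can never be trained. Below is a minimal sketch of the difference,
assuming only stock PyTorch; BrokenProj and FixedProj are hypothetical
stand-ins for illustration, not lumiere modules.

import torch
import torch.nn as nn


class BrokenProj(nn.Module):
    def forward(self, x):
        # Builds a new Linear on every call: fresh random weights each time,
        # and the layer is never registered, so nothing here is trainable.
        d = x.shape[-1]
        return nn.Linear(d, d)(x)


class FixedProj(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # registered once in __init__

    def forward(self, x):
        return self.proj(x)


x = torch.randn(2, 8)
print(len(list(BrokenProj().parameters())))  # 0 -- the optimizer sees nothing
print(len(list(FixedProj(8).parameters())))  # 2 -- weight and bias

broken = BrokenProj()
print(torch.equal(broken(x), broken(x)))  # False: new random weights per call
fixed = FixedProj(8)
print(torch.equal(fixed(x), fixed(x)))    # True: one stable projection

The rearrange added before the projection matters for the same residual: skip
is still in (b, t, h, w, d) layout, while the attention output is
(b, d, t, h, w), so the output must be restored to channels-last layout before
x + skip is shape-valid.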