Commit

[BUGF][AttentionBasedInflationBlock]
Kye committed Feb 6, 2024
1 parent 353b1c0 commit 0645977
Showing 2 changed files with 61 additions and 9 deletions.
23 changes: 18 additions & 5 deletions example.py
@@ -1,11 +1,24 @@
 import torch
 from lumiere.model import ConvolutionBasedInflationBlock
 
+# Example usage:
+# scale_factor must be a divisor of T, H, and W for the example to work correctly
+
+# B, T, H, W, C
+x = torch.randn(1, 2, 112, 112, 3)
+
 # Create a ConvolutionBasedInflationBlock
 block = ConvolutionBasedInflationBlock(
-    3, 64, (3, 3), (2, 2), (1, 1), scale_factor=2
+    in_channels=3,
+    out_channels=64,
+    kernel_size=(3, 3),
+    stride=1,
+    padding=1,
+    scale_factor=2,
 )
-x = torch.randn(1, 4, 224, 224, 3)
 
+
 # Pass the input tensor through the block
 out = block(x)
-print(out.shape)  # Expected shape: [1, 2, 112, 112, 64]
+
+
+# Print the output shape
+print(out.shape)
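
The shape check below is a hypothetical addition, not part of the commit. It assumes, as the removed inline comment indicated, that scale_factor=2 halves T, H, and W while the channel dimension becomes out_channels, so the new input (1, 2, 112, 112, 3) would map to (1, 1, 56, 56, 64).

# Hypothetical sanity check (assumption: scale_factor=2 halves T, H, W
# and channels go from in_channels=3 to out_channels=64)
assert out.shape == (1, 1, 56, 56, 64), f"unexpected shape: {out.shape}"
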
47 changes: 43 additions & 4 deletions lumiere/model.py
@@ -110,6 +110,33 @@ def forward(self, x: Tensor):
 
 
 class AttentionBasedInflationBlock(nn.Module):
+    """
+    Attention-based inflation block module.
+
+    Args:
+        dim (int): The input dimension.
+        heads (int): The number of attention heads.
+        dropout (float, optional): The dropout rate. Defaults to 0.1.
+
+    Attributes:
+        dim (int): The input dimension.
+        heads (int): The number of attention heads.
+        dropout (float): The dropout rate.
+        attn (SpatialLinearAttention): The spatial linear attention module.
+        proj (nn.Linear): The linear projection layer.
+        norm (nn.LayerNorm): The layer normalization module.
+
+    Example:
+        >>> import torch
+        >>> from lumiere.model import AttentionBasedInflationBlock
+        >>> x = torch.randn(1, 4, 224, 224, 512)
+        >>> model = AttentionBasedInflationBlock(dim=512, heads=4, dropout=0.1)
+        >>> out = model(x)
+        >>> print(out.shape)
+        torch.Size([1, 4, 224, 224, 512])
+    """
 
     def __init__(
         self,
         dim: int,
@@ -124,7 +151,7 @@ def __init__(
         self.dropout = dropout
 
         # Spatial linear attention for videos of size:
-        #batch_size, channels, frames, height, width).
+        # batch_size, channels, frames, height, width.
         self.attn = SpatialLinearAttention(
             dim,
             heads,
@@ -138,6 +165,16 @@ def __init__(
         self.norm = nn.LayerNorm(dim)
 
     def forward(self, x: Tensor):
+        """
+        Forward pass of the AttentionBasedInflationBlock.
+
+        Args:
+            x (Tensor): The input tensor.
+
+        Returns:
+            Tensor: The output tensor.
+        """
         skip = x
         b, t, h, w, d = x.shape
 
@@ -147,8 +184,10 @@ def forward(self, x: Tensor):
         # Apply spatial linear attention
         x = self.attn(x)
 
+        # Reshape back to the original shape
         x = rearrange(x, "b d t h w -> b t h w d")
 
-        # x = self.proj(x)
-        # x = self.norm(x)
-        return x #+ skip
+        # Linear projection
+        x = nn.Linear(d, d)(x)
+
+        return x + skip
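
As a usage note, here is a minimal sketch of exercising the block after this fix, mirroring the docstring's Example above; the shapes and arguments come from that docstring, and the expected output is an assumption based on the restored residual connection preserving the input shape.

import torch
from lumiere.model import AttentionBasedInflationBlock

# (batch, frames, height, width, dim), as in the docstring example
x = torch.randn(1, 4, 224, 224, 512)
block = AttentionBasedInflationBlock(dim=512, heads=4, dropout=0.1)

out = block(x)
print(out.shape)  # expected: torch.Size([1, 4, 224, 224, 512])
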
