diff --git a/attn_block_example.py b/attn_block_example.py
index a70b1c0..49478a1 100644
--- a/attn_block_example.py
+++ b/attn_block_example.py
@@ -1,16 +1,14 @@
-import torch 
+import torch
 from lumiere.model import AttentionBasedInflationBlock
 
 # B, T, H, W, D
 x = torch.randn(1, 4, 224, 224, 512)
 
 # Model
-model = AttentionBasedInflationBlock(
-    dim=512, heads=4, dropout=0.1
-)
+model = AttentionBasedInflationBlock(dim=512, heads=4, dropout=0.1)
 
 # Forward pass
 out = model(x)
 
 # print
-print(out.shape) # Expected shape: [1, 4, 224, 224, 3]
\ No newline at end of file
+print(out.shape) # Expected shape: [1, 4, 224, 224, 3]
diff --git a/lumiere/model.py b/lumiere/model.py
index 6867e17..01984b5 100644
--- a/lumiere/model.py
+++ b/lumiere/model.py
@@ -1,8 +1,8 @@
-from einops import rearrange, reduce
+from einops import rearrange
 from torch import nn, Tensor
 from zeta.nn.attention import SpatialLinearAttention
 
-from einops import rearrange, reduce
+from einops import rearrange
 from torch import nn, Tensor
 
 
@@ -125,9 +125,9 @@ class AttentionBasedInflationBlock(nn.Module):
         attn (SpatialLinearAttention): The spatial linear attention module.
         proj (nn.Linear): The linear projection layer.
         norm (nn.LayerNorm): The layer normalization module.
-    
+
     Example:
-    >>> import torch 
+    >>> import torch
     >>> from lumiere.model import AttentionBasedInflationBlock
     >>> x = torch.randn(1, 4, 224, 224, 512)
     >>> model = AttentionBasedInflationBlock(dim=512, heads=4, dropout=0.1)
@@ -150,7 +150,7 @@ def __init__(
         self.heads = heads
         self.dropout = dropout
 
-        # Spatial linear attention for videos of size: 
+        # Spatial linear attention for videos of size:
         # batch_size, channels, frames, height, width.
         self.attn = SpatialLinearAttention(
             dim,
@@ -177,16 +177,16 @@ def forward(self, x: Tensor):
         """
         skip = x
         b, t, h, w, d = x.shape
-        
+
         # Reshape to match the spatial linear attention module
        x = rearrange(x, "b t h w d -> b d t h w")
-        
+
         # Apply spatial linear attention
         x = self.attn(x)
 
         # Reshape back to the original shape
         x = rearrange(x, "b d t h w -> b t h w d")
-        
+
         # Linear projection
         x = nn.Linear(d, d)(x)
 
diff --git a/pyproject.toml b/pyproject.toml
index 42815b4..681a783 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "Lumiere"
-version = "0.0.2"
+version = "0.0.3"
 description = "Paper - Pytorch"
 license = "MIT"
 authors = ["Kye Gomez "]
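
Note on the forward-pass hunk in lumiere/model.py: the line x = nn.Linear(d, d)(x) builds a fresh, randomly initialized projection on every call, even though the class docstring lists a persistent proj (nn.Linear) attribute, and with dim=512 the output's last dimension stays 512 rather than the 3 suggested by the example's comment. Below is a minimal, self-contained sketch of the block's structure with the projection registered once in __init__; the attention module is replaced by an identity placeholder and the residual/norm ordering is an assumption, not the repository's confirmed implementation.

import torch
from torch import nn, Tensor
from einops import rearrange


class InflationBlockSketch(nn.Module):
    # Hypothetical sketch of an attention-based inflation block with a
    # persistent projection layer instead of nn.Linear(d, d) built per call.
    def __init__(self, dim: int, heads: int = 4, dropout: float = 0.1):
        super().__init__()
        # Stand-in for SpatialLinearAttention; nn.Identity keeps the sketch runnable.
        self.attn = nn.Identity()
        self.proj = nn.Linear(dim, dim)  # registered once, reused on every forward pass
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        skip = x
        # (b, t, h, w, d) -> (b, d, t, h, w) for the attention module, then back.
        x = rearrange(x, "b t h w d -> b d t h w")
        x = self.attn(x)
        x = rearrange(x, "b d t h w -> b t h w d")
        x = self.norm(self.proj(x))
        return x + skip


x = torch.randn(1, 4, 32, 32, 512)
print(InflationBlockSketch(dim=512)(x).shape)  # torch.Size([1, 4, 32, 32, 512])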