
Commit

[CLEANUP]
Kye committed Feb 6, 2024
1 parent 84ac0b9 commit dce9a30
Showing 3 changed files with 12 additions and 14 deletions.
8 changes: 3 additions & 5 deletions attn_block_example.py
@@ -1,16 +1,14 @@
-import torch
+import torch
from lumiere.model import AttentionBasedInflationBlock

# B, T, H, W, D
x = torch.randn(1, 4, 224, 224, 512)

# Model
-model = AttentionBasedInflationBlock(
-    dim=512, heads=4, dropout=0.1
-)
+model = AttentionBasedInflationBlock(dim=512, heads=4, dropout=0.1)

# Forward pass
out = model(x)

# print
-print(out.shape)  # Expected shape: [1, 4, 224, 224, 3]
+print(out.shape)  # Expected shape: [1, 4, 224, 224, 3]
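
Since the projection in lumiere/model.py below maps the feature dimension back onto itself (nn.Linear(d, d)), the block should preserve the trailing 512-wide feature dimension of the input above. A minimal sanity check, a sketch assuming the imports in the example resolve and that SpatialLinearAttention keeps the channel dimension:

import torch
from lumiere.model import AttentionBasedInflationBlock

# Smaller spatial extent than the example above, just to keep the check cheap.
x = torch.randn(1, 4, 32, 32, 512)

model = AttentionBasedInflationBlock(dim=512, heads=4, dropout=0.1)
out = model(x)

# Expected to hold if the block is shape-preserving end to end.
assert out.shape == x.shape, out.shape
print(out.shape)  # torch.Size([1, 4, 32, 32, 512])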
16 changes: 8 additions & 8 deletions lumiere/model.py
@@ -1,8 +1,8 @@
-from einops import rearrange, reduce
+from einops import rearrange
from torch import nn, Tensor
from zeta.nn.attention import SpatialLinearAttention

-from einops import rearrange, reduce
+from einops import rearrange
from torch import nn, Tensor


@@ -125,9 +125,9 @@ class AttentionBasedInflationBlock(nn.Module):
attn (SpatialLinearAttention): The spatial linear attention module.
proj (nn.Linear): The linear projection layer.
norm (nn.LayerNorm): The layer normalization module.
Example:
->>> import torch
+>>> import torch
>>> from lumiere.model import AttentionBasedInflationBlock
>>> x = torch.randn(1, 4, 224, 224, 512)
>>> model = AttentionBasedInflationBlock(dim=512, heads=4, dropout=0.1)
@@ -150,7 +150,7 @@ def __init__(
self.heads = heads
self.dropout = dropout

-# Spatial linear attention for videos of size:
+# Spatial linear attention for videos of size:
# batch_size, channels, frames, height, width.
self.attn = SpatialLinearAttention(
dim,
@@ -177,16 +177,16 @@ def forward(self, x: Tensor):
"""
skip = x
b, t, h, w, d = x.shape

# Reshape to match the spatial linear attention module
x = rearrange(x, "b t h w d -> b d t h w")

# Apply spatial linear attention
x = self.attn(x)

# Reshape back to the original shape
x = rearrange(x, "b d t h w -> b t h w d")

# Linear projection
x = nn.Linear(d, d)(x)

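To make the data flow in forward() concrete without the zeta dependency, here is a minimal self-contained sketch of the same rearrange, attend, rearrange back, project pattern. InflationBlockSketch and its nn.Identity attention are hypothetical stand-ins for the repository's SpatialLinearAttention, the projection is built once in __init__ rather than inside forward(), and the closing norm-plus-residual step is an assumption based on the norm attribute and the skip = x line, since the tail of forward() is not shown in this hunk.

import torch
from einops import rearrange
from torch import nn, Tensor


class InflationBlockSketch(nn.Module):
    """Hypothetical stand-in mirroring the forward pass shown above."""

    def __init__(self, dim: int):
        super().__init__()
        # Stand-in for SpatialLinearAttention; assumed to keep the channel dim.
        self.attn = nn.Identity()
        # Created once here, unlike the nn.Linear(d, d) built inside forward() above.
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        skip = x
        # (batch, time, height, width, dim) -> (batch, dim, time, height, width)
        x = rearrange(x, "b t h w d -> b d t h w")
        x = self.attn(x)
        # Back to the original layout before the per-feature projection.
        x = rearrange(x, "b d t h w -> b t h w d")
        x = self.norm(self.proj(x))
        # Residual connection back to the input (assumed from skip = x).
        return x + skip


out = InflationBlockSketch(dim=64)(torch.randn(1, 4, 16, 16, 64))
print(out.shape)  # torch.Size([1, 4, 16, 16, 64])
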
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "Lumiere"
version = "0.0.2"
version = "0.0.3"
description = "Paper - Pytorch"
license = "MIT"
authors = ["Kye Gomez <[email protected]>"]
