use feature map size to denote which layer to add self attention in discriminator, insert one self attention layer at 16x16
lucidrains committed Dec 9, 2022
1 parent 818e5e4 commit 57d2a17
Showing 2 changed files with 16 additions and 10 deletions.
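
With this commit, where the discriminator inserts self attention is specified by feature map resolution (attn_res_layers) rather than by a per-layer flag tuple (attn_layers). A minimal usage sketch, not part of the diff, assuming Discriminator is importable from phenaki_pytorch.cvivit; the dim and image_size values are illustrative:

import torch
from phenaki_pytorch.cvivit import Discriminator

# hypothetical instantiation; keyword values are illustrative
discr = Discriminator(
    dim = 64,
    image_size = 256,
    channels = 3,
    attn_res_layers = (16,)  # one self attention layer where the feature map is 16x16
)

logits = discr(torch.randn(1, 3, 256, 256))  # per-image realness logits, shape (1,)
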
24 changes: 15 additions & 9 deletions phenaki_pytorch/cvivit.py
@@ -135,20 +135,20 @@ def forward(self, x):
         x = (x + res) * (1 / math.sqrt(2))
         return x


 class Discriminator(nn.Module):
     def __init__(
         self,
         *,
         dim,
         image_size,
         channels = 3,
-        attn_layers = None,
+        attn_res_layers = (16,),
         max_dim = 512
     ):
         super().__init__()
-        num_layers = int(math.log2(min(pair(image_size))) - 1)
-        attn_layers = cast_tuple(attn_layers, num_layers)
-        assert len(attn_layers) == num_layers
+        num_layers = int(math.log2(min(pair(image_size))) - 2)
+        attn_res_layers = cast_tuple(attn_res_layers, num_layers)

         blocks = []

@@ -159,33 +159,39 @@ def __init__(
         blocks = []
         attn_blocks = []

-        for ind, ((in_chan, out_chan), layer_has_attn) in enumerate(zip(layer_dims_in_out, attn_layers)):
+        image_resolution = image_size
+
+        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
             num_layer = ind + 1
             is_not_last = ind != (len(layer_dims_in_out) - 1)

             block = DiscriminatorBlock(in_chan, out_chan, downsample = is_not_last)
             blocks.append(block)

             attn_block = None
-            if layer_has_attn:
+            if image_resolution in attn_res_layers:
                 attn_block = Attention(dim = out_chan)

             attn_blocks.append(attn_block)

+            image_resolution //= 2
+
         self.blocks = nn.ModuleList(blocks)
         self.attn_blocks = nn.ModuleList(attn_blocks)

         dim_last = layer_dims[-1]
-        latent_dim = 2 * 2 * dim_last
+        latent_dim = 4 * 4 * dim_last

         self.to_logits = nn.Sequential(
             nn.Conv2d(dim_last, dim_last, 3, padding = 1),
             leaky_relu(),
             Rearrange('b ... -> b (...)'),
             nn.Linear(latent_dim, 1),
             Rearrange('b 1 -> b')
         )

     def forward(self, x):

         for block, attn_block in zip(self.blocks, self.attn_blocks):
             x = block(x)

@@ -225,7 +231,7 @@ def __init__(
         channels = 3,
         use_vgg_and_gan = True,
         vgg = None,
-        discr_attn_layers = None,
+        discr_attn_res_layers = (16,),
         use_hinge_loss = True,
         attn_dropout = 0.,
         ff_dropout = 0.
@@ -314,7 +320,7 @@ def __init__(
             image_size = self.image_size,
             dim = discr_base_dim,
             channels = channels,
-            attn_layers = discr_attn_layers
+            attn_res_layers = discr_attn_res_layers
         )

         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
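A standalone sketch of the selection rule in the constructor loop above (a hypothetical helper, not in the repository): image_resolution starts at the input size and is halved after every block, so a block gets attention exactly when its incoming feature map size appears in attn_res_layers.

def attn_placement(image_size, num_blocks, attn_res_layers = (16,)):
    # mirror the loop in Discriminator.__init__: halve the tracked resolution
    # after each block, flag blocks whose incoming resolution is requested
    placement = []
    image_resolution = image_size
    for _ in range(num_blocks):
        placement.append(image_resolution in attn_res_layers)
        image_resolution //= 2
    return placement

print(attn_placement(256, 7))  # [False, False, False, False, True, False, False]

A 256 pixel input passing through 7 blocks sees resolutions 256, 128, 64, 32, 16, 8, 4, so only the fifth block gets attention. The same halving explains the latent_dim change: num_layers drops by one, so there is one fewer downsampling block and the final feature map is 4x4 rather than 2x2, hence 4 * 4 * dim_last.
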
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'phenaki-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.58',
+  version = '0.0.59',
   license='MIT',
   description = 'Phenaki - Pytorch',
   author = 'Phil Wang',
