Skip to content

Commit

Permalink
knock off a small todo
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 6, 2023
1 parent 45033ad commit d13d7f1
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [x] build out CLI tool for training, resuming training off config file
- [x] allow for temporal interpolation at specific stages
- [x] make sure temporal interpolation works with inpainting
- [x] make sure one can customize all interpolation modes (some researchers are finding better results with trilinear)

- [ ] reread <a href="https://arxiv.org/abs/2205.15868">cogvideo</a> and figure out how frame rate conditioning could be used
- [ ] bring in attention expertise for self attention layers in unet3d
Expand All @@ -724,7 +725,6 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [ ] cleanup self conditioning to be extracted at imagen instantiation
- [ ] make sure eventual dreambooth works with imagen-video
- [ ] add framerate conditioning for video diffusion
- [ ] make sure one can customize all interpolation modes (some researchers are finding better results with trilinear)
- [ ] imagen-video: allow for conditioning on preceding (and possibly future) frames of videos; ignoring time should not be allowed in that scenario

## Citations
Expand Down
3 changes: 3 additions & 0 deletions imagen_pytorch/elucidated_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
channels = 3,
cond_drop_prob = 0.1,
random_crop_sizes = None,
resize_mode = 'nearest',
temporal_downsample_factor = 1,
lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
Expand Down Expand Up @@ -165,7 +166,9 @@ def __init__(
self.is_video = is_video

self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))

self.resize_to = resize_video_to if is_video else resize_image_to
self.resize_to = partial(self.resize_to, mode = resize_mode)

# unet image sizes

Expand Down
12 changes: 10 additions & 2 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,7 @@ def __init__(
final_conv_kernel_size = 3,
cosine_sim_attn = False,
self_cond = False,
resize_mode = 'nearest',
combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
pixel_shuffle_upsample = True, # may address checkerboard artifacts
):
Expand Down Expand Up @@ -1426,6 +1427,10 @@ def __init__(

zero_init_(self.final_conv)

# resize mode

self.resize_mode = resize_mode

# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
Expand Down Expand Up @@ -1541,7 +1546,7 @@ def forward(

if exists(cond_images):
assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
cond_images = resize_image_to(cond_images, x.shape[-1])
cond_images = resize_image_to(cond_images, x.shape[-1], mode = self.resize_mode)
x = torch.cat((cond_images, x), dim = 1)

# initial convolution
Expand Down Expand Up @@ -1794,7 +1799,8 @@ def __init__(
dynamic_thresholding = True,
dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper
only_train_unet_number = None,
temporal_downsample_factor = 1
temporal_downsample_factor = 1,
resize_mode = 'nearest'
):
super().__init__()

Expand Down Expand Up @@ -1902,7 +1908,9 @@ def __init__(
self.is_video = is_video

self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))

self.resize_to = resize_video_to if is_video else resize_image_to
self.resize_to = partial(self.resize_to, mode = self.resize_mode)

# temporal interpolation

Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.20.2'
__version__ = '1.20.3'

0 comments on commit d13d7f1

Please sign in to comment.