Skip to content

Commit

Permalink
make sure temporal interpolation works with inpainting
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 25, 2023
1 parent 333d656 commit 066efc9
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,7 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [x] incorporate all learnings from make-a-video (https://makeavideo.studio/)
- [x] build out CLI tool for training, resuming training off config file
- [x] allow for temporal interpolation at specific stages
- [x] make sure temporal interpolation works with inpainting

- [ ] reread <a href="https://arxiv.org/abs/2205.15868">cogvideo</a> and figure out how frame rate conditioning could be used
- [ ] bring in attention expertise for self attention layers in unet3d
Expand All @@ -718,7 +719,6 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [ ] add textual inversion
- [ ] cleanup self conditioning to be extracted at imagen instantiation
- [ ] make sure eventual dreambooth works with imagen-video
- [ ] make sure temporal interpolation works with inpainting

## Citations

Expand Down
16 changes: 10 additions & 6 deletions imagen_pytorch/elucidated_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
pad_tuple_to_length,
resize_image_to,
calc_all_frame_dims,
safe_get_tuple_index,
right_pad_dims_to,
module_device,
normalize_neg_one_to_one,
Expand Down Expand Up @@ -397,6 +398,12 @@ def one_unet_sample(
sigma_max = None,
**kwargs
):
# video

is_video = len(shape) == 5
frames = shape[-3] if is_video else None
resize_kwargs = dict(target_frames = frames) if exists(frames) else dict()

# get specific sampling hyperparameters for unet

hp = self.hparams[unet_number - 1]
Expand Down Expand Up @@ -438,7 +445,7 @@ def one_unet_sample(

if has_inpainting:
inpaint_images = self.normalize_img(inpaint_images)
inpaint_images = self.resize_to(inpaint_images, shape[-1])
inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs)
inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool()

# unet kwargs
Expand Down Expand Up @@ -600,10 +607,7 @@ def sample(

# determine the frame dimensions, if needed

if self.is_video:
all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
else:
all_frame_dims = (tuple(),) * num_unets
all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)

# initializing with an image or video

Expand Down Expand Up @@ -744,7 +748,7 @@ def forward(
batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5)

frames = images.shape[2] if is_video else None
all_frame_dims = tuple(el[0] for el in calc_all_frame_dims(self.temporal_downsample_factor, frames))
all_frame_dims = tuple(safe_get_tuple_index(el, 0) for el in calc_all_frame_dims(self.temporal_downsample_factor, frames))
ignore_time = kwargs.get('ignore_time', False)

target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
Expand Down
23 changes: 17 additions & 6 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ def calc_all_frame_dims(
downsample_factors: List[int],
frames
):
if not exists(frames):
return (tuple(),) * len(downsample_factors)

all_frame_dims = []

for divisor in downsample_factors:
Expand All @@ -170,6 +173,11 @@ def calc_all_frame_dims(

return all_frame_dims

def safe_get_tuple_index(tup, index, default = None):
    """Return ``tup[index]`` if the index is in range, else ``default``.

    Unlike plain indexing this never raises ``IndexError``: out-of-range
    indices — including negative ones beyond ``-len(tup)`` — fall back to
    ``default``. Used to read the per-unet frame dimension from possibly
    empty frame-dim tuples (image unets yield ``tuple()``).
    """
    # normalize the bound check so negative indices are also guarded,
    # mirroring Python's negative-index semantics for in-range values
    if -len(tup) <= index < len(tup):
        return tup[index]
    return default

# image normalization functions
# ddpms expect images to be in the range of -1 to 1

Expand Down Expand Up @@ -2108,6 +2116,12 @@ def p_sample_loop(
batch = shape[0]
img = torch.randn(shape, device = device)

# video

is_video = len(shape) == 5
frames = shape[-3] if is_video else None
resize_kwargs = dict(target_frames = frames) if exists(frames) else dict()

# for initialization with an image or video

if exists(init_images):
Expand All @@ -2124,7 +2138,7 @@ def p_sample_loop(

if has_inpainting:
inpaint_images = self.normalize_img(inpaint_images)
inpaint_images = self.resize_to(inpaint_images, shape[-1])
inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs)
inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool()

# time
Expand Down Expand Up @@ -2260,10 +2274,7 @@ def sample(

assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'

if self.is_video:
all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
else:
all_frame_dims = (tuple(),) * num_unets
all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)

frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()

Expand Down Expand Up @@ -2529,7 +2540,7 @@ def forward(
assert h >= target_image_size and w >= target_image_size

frames = images.shape[2] if is_video else None
all_frame_dims = tuple(el[0] for el in calc_all_frame_dims(self.temporal_downsample_factor, frames))
all_frame_dims = tuple(safe_get_tuple_index(el, 0) for el in calc_all_frame_dims(self.temporal_downsample_factor, frames))
ignore_time = kwargs.get('ignore_time', False)

target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.20.0'
__version__ = '1.20.1'

0 comments on commit 066efc9

Please sign in to comment.