Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamicrafter: mint op adaptation #765

Open
wants to merge 5 commits into
base: v0.3.0-dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions examples/dynamicrafter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,27 @@ python convert_weight.py \
```

### Run inference
Set the CLIP checkpoint path in the inference YAML config files as follows:
```yaml
...
cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
params:
arch: "ViT-H-14"
freeze: true
layer: "penultimate"
pretrained: path/to/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_ms_model.ckpt

img_cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
arch: "ViT-H-14"
freeze: true
version: path/to/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_ms_model.ckpt
...
```

Launch inference:

```shell
sh scripts/run/run_infer.sh [RESOLUTION] [CKPT_PATH]
Expand All @@ -93,13 +114,15 @@ sh scripts/run/run_infer.sh [RESUOUTION] [CKPT_PATH]

### Performance

Experiments are tested on ascend 910* with mindspore 2.3.1 graph mode.
We evaluate the inference performance of image-to-video generation by measuring the average sampling time per step and the total sampling time of a video.

Experiments are tested on Ascend 910* with [MindSpore 2.4.0 1119](https://repo.mindspore.cn/mindspore/mindspore/version/202411/20241119/master_20241119010040_b355f51f1710bb48d01d675c61b8305f14a9dcd4_newest/unified/aarch64/) in graph mode.

| model name | cards | batch size | resolution | scheduler | steps | precision | jit level | graph compile |s/step | s/video |
|:-------------:|:------------: |:------------: |:------------:|:------------:|:------------:|:------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
| dynamicrafter | 1 | 1 | 16x576x1024 | DDIM | 50 | fp16 | O1 | 1~2 mins | 1.42 | 71 |
| dynamicrafter | 1 | 1 | 16x576x1024 | DDIM | 50 | fp16 | O1 | 1~2 mins | 1.48 | 74 |
| dynamicrafter | 1 | 1 | 16x320x512 | DDIM | 50 | fp16 | O1 |1~2 mins | 0.42 | 21 |
| dynamicrafter | 1 | 1 | 16x256x256 | DDIM | 50 | fp16 | O1 |1~2 mins | 0.26 | 13 |
| dynamicrafter | 1 | 1 | 16x256x256 | DDIM | 50 | fp16 | O1 |1~2 mins | 0.24 | 12 |


## References
Expand Down
2 changes: 2 additions & 0 deletions examples/dynamicrafter/configs/inference_1024_v1.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ model:
arch: "ViT-H-14"
freeze: true
layer: "penultimate"
pretrained: path/to/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_ms_model.ckpt

img_cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
arch: "ViT-H-14"
freeze: true
version: path/to/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_ms_model.ckpt

image_proj_stage_config:
target: lvdm.modules.encoders.resampler.Resampler
Expand Down
2 changes: 2 additions & 0 deletions examples/dynamicrafter/configs/inference_256_v1.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ model:
arch: "ViT-H-14"
freeze: true
layer: "penultimate"
pretrained: path/to/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_ms_model.ckpt

img_cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
arch: "ViT-H-14"
freeze: true
version: path/to/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_ms_model.ckpt

image_proj_stage_config:
target: lvdm.modules.encoders.resampler.Resampler
Expand Down
2 changes: 2 additions & 0 deletions examples/dynamicrafter/configs/inference_512_v1.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ model:
arch: "ViT-H-14"
freeze: true
layer: "penultimate"
pretrained: path/to/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_ms_model.ckpt

img_cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
arch: "ViT-H-14"
freeze: true
version: path/to/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_ms_model.ckpt

image_proj_stage_config:
target: lvdm.modules.encoders.resampler.Resampler
Expand Down
49 changes: 24 additions & 25 deletions examples/dynamicrafter/lvdm/models/ddpm3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import mindspore as ms
from mindspore import Parameter, Tensor
from mindspore import dtype as mstype
from mindspore import nn, ops
from mindspore import mint, nn, ops

from mindone.utils.config import instantiate_from_config
from mindone.utils.misc import default, exists, extract_into_tensor
Expand Down Expand Up @@ -89,7 +89,6 @@ def __init__(
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.isnan = ops.IsNan()
self.register_schedule(
given_betas=given_betas,
beta_schedule=beta_schedule,
Expand All @@ -106,8 +105,8 @@ def __init__(
self.logvar = Tensor(np.full(shape=(self.num_timesteps,), fill_value=logvar_init).astype(np.float32))
if self.learn_logvar:
self.logvar = Parameter(self.logvar, requires_grad=True)
self.mse_mean = nn.MSELoss(reduction="mean")
self.mse_none = nn.MSELoss(reduction="none")
self.mse_mean = mint.nn.MSELoss(reduction="mean")
self.mse_none = mint.nn.MSELoss(reduction="none")

def register_schedule(
self,
Expand Down Expand Up @@ -320,8 +319,8 @@ def decode_core(self, z, **kwargs):
if self.encoder_type == "2d" and z.dim() == 5:
b, _, t, _, _ = z.shape
# z = rearrange(z, 'b c t h w -> (b t) c h w')
z = ops.transpose(z, (0, 2, 1, 3, 4)) # (b c t h w) -> (b t c h w)
z = ops.reshape(z, (-1, z.shape[2], z.shape[3], z.shape[4])) # (b t c h w) -> ((b t) c h w)
z = mint.permute(z, (0, 2, 1, 3, 4)) # (b c t h w) -> (b t c h w)
z = mint.reshape(z, (-1, z.shape[2], z.shape[3], z.shape[4])) # (b t c h w) -> ((b t) c h w)
reshape_back = True
else:
reshape_back = False
Expand All @@ -335,11 +334,11 @@ def decode_core(self, z, **kwargs):
frame_z = 1.0 / self.scale_factor * z[index : index + 1, :, :, :]
frame_result = self.first_stage_model.decode(frame_z)
results.append(frame_result)
results = ops.cat(results, axis=0)
results = mint.cat(results, dim=0)

if reshape_back:
results = ops.reshape(results, (b, t, *results.shape[1:])) # ((b t) c h w) -> (b t c h w)
results = ops.transpose(results, (0, 2, 1, 3, 4)) # (b t c h w) -> (b c t h w)
results = mint.reshape(results, (b, t, *results.shape[1:])) # ((b t) c h w) -> (b t c h w)
results = mint.permute(results, (0, 2, 1, 3, 4)) # (b t c h w) -> (b c t h w)
return results

def decode_first_stage(self, z, **kwargs):
Expand All @@ -349,8 +348,8 @@ def encode_first_stage(self, x):
if self.encoder_type == "2d" and x.dim() == 5:
b, _, t, _, _ = x.shape
# x = rearrange(x, 'b c t h w -> (b t) c h w')
x = ops.transpose(x, (0, 2, 1, 3, 4)) # (b t c h w)
x = ops.reshape(x, (-1, *x.shape[2:])) # ((b t) c h w)
x = mint.permute(x, (0, 2, 1, 3, 4)) # (b t c h w)
x = mint.reshape(x, (-1, *x.shape[2:])) # ((b t) c h w)
reshape_back = True
else:
reshape_back = False
Expand All @@ -365,11 +364,11 @@ def encode_first_stage(self, x):
self.scale_factor * self.first_stage_model.encode(x[index : index + 1, :, :, :])
)
results.append(frame_result)
results = ops.cat(results, axis=0)
results = mint.cat(results, dim=0)

if reshape_back:
results = ops.reshape(results, (b, t, *results.shape[1:])) # (b t c h w)
results = ops.transpose(results, (0, 2, 1, 3, 4)) # (b c t h w)
results = mint.reshape(results, (b, t, *results.shape[1:])) # (b t c h w)
results = mint.permute(results, (0, 2, 1, 3, 4)) # (b c t h w)

return results

Expand Down Expand Up @@ -402,7 +401,7 @@ def get_latents_2d(self, x):
B, C, H, W = x.shape
if C != 3:
# b h w c -> b c h w
x = ops.transpose(x, (0, 3, 1, 2))
x = mint.permute(x, (0, 3, 1, 2))

z = ops.stop_gradient(self.scale_factor * self.first_stage_model.encode(x))

Expand All @@ -413,13 +412,13 @@ def get_latents(self, x):
B, F, C, H, W = x.shape
if C != 3:
raise ValueError("Expect input shape (b f 3 h w), but get {}".format(x.shape))
x = ops.reshape(x, (-1, C, H, W))
x = mint.reshape(x, (-1, C, H, W))

z = ops.stop_gradient(self.scale_factor * self.first_stage_model.encode(x))

# (b*f c h w) -> (b f c h w) -> (b c f h w )
z = ops.reshape(z, (B, F, z.shape[1], z.shape[2], z.shape[3]))
z = ops.transpose(z, (0, 2, 1, 3, 4))
z = mint.reshape(z, (B, F, z.shape[1], z.shape[2], z.shape[3]))
z = mint.permute(z, (0, 2, 1, 3, 4))

return z

Expand Down Expand Up @@ -469,7 +468,7 @@ def construct(self, x: ms.Tensor, text_tokens: ms.Tensor, control=None, **kwargs
t = self.uniform_int(
(x.shape[0],), Tensor(0, dtype=mstype.int32), Tensor(self.num_timesteps, dtype=mstype.int32)
)
noise = ops.randn_like(z)
noise = mint.randn_like(z)
noisy_latents, snr = self.add_noise(z, noise, t)

# 3. get condition embeddings
Expand All @@ -493,11 +492,11 @@ def construct(self, x: ms.Tensor, text_tokens: ms.Tensor, control=None, **kwargs
loss_sample = self.reduce_loss(loss_element)

if self.snr_gamma is not None:
snr_gamma = ops.ones_like(snr) * self.snr_gamma
snr_gamma = mint.ones_like(snr) * self.snr_gamma
# TODO: for v-pred, .../ (snr+1)
# TODO: for beta zero rescale, consider snr=0
# min{snr, gamma} / snr
loss_weight = ops.stack((snr, snr_gamma), axis=0).min(axis=0) / snr
loss_weight = mint.stack((snr, snr_gamma), dim=0).min(axis=0) / snr
loss = (loss_weight * loss_sample).mean()
else:
loss = loss_sample.mean()
Expand All @@ -506,7 +505,7 @@ def construct(self, x: ms.Tensor, text_tokens: ms.Tensor, control=None, **kwargs
"""
# can be used to place more weights to high-score samples
logvar_t = self.logvar[t]
loss = loss_simple / ops.exp(logvar_t) + logvar_t
loss = loss_simple / mint.exp(logvar_t) + logvar_t
loss = self.l_simple_weight * loss.mean()
"""

Expand Down Expand Up @@ -542,7 +541,7 @@ def get_latents(self, x):
z = ops.stop_gradient(self.scale_factor * x)

# (b f c h w) -> (b c f h w )
z = ops.transpose(z, (0, 2, 1, 3, 4))
z = mint.permute(z, (0, 2, 1, 3, 4))
return z

def get_condition_embeddings(self, text_tokens, control=None):
Expand Down Expand Up @@ -596,13 +595,13 @@ def construct(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, **kwargs)
if self.conditioning_key is None:
out = self.diffusion_model(x, t, **kwargs)
elif self.conditioning_key == "concat":
x_concat = ops.concat((x, c_concat), axis=1)
x_concat = mint.cat((x, c_concat), dim=1)
out = self.diffusion_model(x_concat, t, **kwargs)
elif self.conditioning_key == "crossattn": # t2v task
context = c_crossattn
out = self.diffusion_model(x, t, context=context, **kwargs)
elif self.conditioning_key == "hybrid":
x_concat = ops.concat((x, c_concat), axis=1)
x_concat = mint.cat((x, c_concat), dim=1)
context = c_crossattn
out = self.diffusion_model(x_concat, t, context=context, **kwargs)
elif self.conditioning_key == "crossattn-adm":
Expand Down
27 changes: 13 additions & 14 deletions examples/dynamicrafter/lvdm/models/samplers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tqdm import tqdm

import mindspore as ms
import mindspore.ops as ops
from mindspore import mint

from mindone.utils.misc import extract_into_tensor

Expand All @@ -14,7 +14,6 @@ def __init__(self, model, schedule="linear", **kwargs):
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.split = ops.Split(0, 2)

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True):
self.ddim_timesteps = make_ddim_timesteps(
Expand All @@ -28,18 +27,18 @@ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0,

if self.model.use_dynamic_rescale:
self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
self.ddim_scale_arr_prev = ops.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])
self.ddim_scale_arr_prev = mint.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])

self.betas = self.model.betas
self.alphas_cumprod = self.model.alphas_cumprod
self.alphas_cumprod_prev = self.model.alphas_cumprod_prev

# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = ops.sqrt(alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = ops.sqrt(1.0 - alphas_cumprod)
self.log_one_minus_alphas_cumprod = ops.log(1.0 - alphas_cumprod)
self.sqrt_recip_alphas_cumprod = ops.sqrt(1.0 / alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = ops.sqrt(1.0 / alphas_cumprod - 1)
self.sqrt_alphas_cumprod = mint.sqrt(alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = mint.sqrt(1.0 - alphas_cumprod)
self.log_one_minus_alphas_cumprod = mint.log(1.0 - alphas_cumprod)
self.sqrt_recip_alphas_cumprod = mint.sqrt(1.0 / alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = mint.sqrt(1.0 / alphas_cumprod - 1)

# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
Expand All @@ -49,8 +48,8 @@ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0,
self.ddim_sigmas = ddim_sigmas
self.ddim_alphas = ddim_alphas
self.ddim_alphas_prev = ddim_alphas_prev
self.ddim_sqrt_one_minus_alphas = ops.sqrt(1.0 - ddim_alphas)
sigmas_for_original_sampling_steps = ddim_eta * ops.sqrt(
self.ddim_sqrt_one_minus_alphas = mint.sqrt(1.0 - ddim_alphas)
sigmas_for_original_sampling_steps = ddim_eta * mint.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
Expand Down Expand Up @@ -309,8 +308,8 @@ def p_sample_ddim(
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

if self.model.use_dynamic_rescale:
scale_t = ops.full(size, self.ddim_scale_arr[index], dtype=self.ddim_scale_arr[index].dtype)
prev_scale_t = ops.full(size, self.ddim_scale_arr_prev[index], dtype=self.ddim_scale_arr_prev[index].dtype)
scale_t = mint.full(size, self.ddim_scale_arr[index], dtype=self.ddim_scale_arr[index].dtype)
prev_scale_t = mint.full(size, self.ddim_scale_arr_prev[index], dtype=self.ddim_scale_arr_prev[index].dtype)
rescale = prev_scale_t / scale_t
pred_x0 *= rescale

Expand All @@ -321,7 +320,7 @@ def p_sample_ddim(
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, repeat_noise) * temperature
if noise_dropout > 0.0:
noise, _ = ops.dropout(noise, p=noise_dropout)
noise, _ = mint.nn.Dropout(p=noise_dropout)(noise)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise

return x_prev, pred_x0
Expand Down Expand Up @@ -392,7 +391,7 @@ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = ops.sqrt(self.ddim_alphas)
sqrt_alphas_cumprod = mint.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

if noise is None:
Expand Down
Loading