-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Separate inference 720p video with 24G VRAM #597
Open
narrowsnap
wants to merge
5
commits into
hpcaitech:main
Choose a base branch
from
narrowsnap:feature/720p_for_16g
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#!/bin/bash | ||
set_default_params() { | ||
num_frames=${1:-"4s"} | ||
resolution=${2:-"720p"} | ||
aspect_ratio=${3:-"9:16"} | ||
prompt=${4:-"Create a video featuring Will Smith enjoying a plate of spaghetti."} | ||
} | ||
|
||
set_default_params "$@" | ||
|
||
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 --master_port=23456 scripts/separate_inference/inference_text_encoder.py configs/opensora-v1-2/inference/sample.py --aes 7 --num-frames "$num_frames" --resolution "$resolution" --aspect-ratio "$aspect_ratio" --prompt "$prompt" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will by default use 2 GPUs, can you make this configurable via bash argument as well? |
||
if echo "$prompt" | grep -q "reference_path"; then | ||
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 --master_port=23456 scripts/separate_inference/inference_vae_encoder.py configs/opensora-v1-2/inference/sample.py --aes 7 --num-frames "$num_frames" --resolution "$resolution" --aspect-ratio "$aspect_ratio" --prompt "$prompt" | ||
fi | ||
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 --master_port=23456 scripts/separate_inference/inference_stdit.py configs/opensora-v1-2/inference/sample.py --aes 7 --num-frames "$num_frames" --resolution "$resolution" --aspect-ratio "$aspect_ratio" --prompt "$prompt" | ||
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 --master_port=23456 scripts/separate_inference/inference_vae_decoder.py configs/opensora-v1-2/inference/sample.py --aes 7 --num-frames "$num_frames" --resolution "$resolution" --aspect-ratio "$aspect_ratio" --prompt "$prompt" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,316 @@ | ||
import os | ||
from datetime import date, datetime, timedelta | ||
from pprint import pformat | ||
|
||
import colossalai | ||
import torch | ||
import torch.distributed as dist | ||
from colossalai.cluster import DistCoordinator | ||
from mmengine.runner import set_random_seed | ||
from tqdm import tqdm | ||
|
||
from opensora.acceleration.parallel_states import set_sequence_parallel_group | ||
from opensora.datasets.aspect import get_image_size, get_num_frames | ||
from opensora.models.text_encoder.t5 import text_preprocessing | ||
from opensora.registry import MODELS, SCHEDULERS, build_module | ||
from opensora.utils.config_utils import parse_configs | ||
from opensora.utils.inference_utils import ( | ||
append_score_to_prompts, | ||
apply_mask_strategy, | ||
extract_json_from_prompts, | ||
extract_prompts_loop, | ||
get_save_path_name, | ||
load_prompts, | ||
merge_prompt, | ||
prepare_multi_resolution_info, | ||
refine_prompts_by_openai, | ||
split_prompt, | ||
) | ||
from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype | ||
|
||
|
||
def main(): | ||
torch.set_grad_enabled(False) | ||
# ====================================================== | ||
# configs & runtime variables | ||
# ====================================================== | ||
# == parse configs == | ||
cfg = parse_configs(training=False) | ||
|
||
# == device and dtype == | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
cfg_dtype = cfg.get("dtype", "fp32") | ||
assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" | ||
dtype = to_torch_dtype(cfg.get("dtype", "bf16")) | ||
torch.backends.cuda.matmul.allow_tf32 = True | ||
torch.backends.cudnn.allow_tf32 = True | ||
|
||
# == init distributed env == | ||
if is_distributed(): | ||
colossalai.launch_from_torch({}) | ||
coordinator = DistCoordinator() | ||
enable_sequence_parallelism = coordinator.world_size > 1 | ||
if enable_sequence_parallelism: | ||
set_sequence_parallel_group(dist.group.WORLD) | ||
else: | ||
coordinator = None | ||
enable_sequence_parallelism = False | ||
set_random_seed(seed=cfg.get("seed", 1024)) | ||
|
||
# == init logger == | ||
logger = create_logger() | ||
logger.info("Inference configuration:\n %s", pformat(cfg.to_dict())) | ||
verbose = cfg.get("verbose", 1) | ||
progress_wrap = tqdm if verbose == 1 else (lambda x: x) | ||
|
||
# ====================================================== | ||
# build model & load weights | ||
# ====================================================== | ||
logger.info("Building models...") | ||
|
||
# == prepare video size == | ||
image_size = cfg.get("image_size", None) | ||
if image_size is None: | ||
resolution = cfg.get("resolution", None) | ||
aspect_ratio = cfg.get("aspect_ratio", None) | ||
assert ( | ||
resolution is not None and aspect_ratio is not None | ||
), "resolution and aspect_ratio must be provided if image_size is not provided" | ||
image_size = get_image_size(resolution, aspect_ratio) | ||
num_frames = get_num_frames(cfg.num_frames) | ||
|
||
# == build diffusion model == | ||
(num_frames, *image_size) | ||
|
||
# == get vae temporal size == | ||
micro_frame_size = cfg.get("micro_frame_size", 17) | ||
time_padding = 0 if micro_frame_size % 4 == 0 else 4 - micro_frame_size % 4 | ||
lsize = (micro_frame_size + time_padding) // 4 | ||
frame_size = lsize * (num_frames // micro_frame_size) | ||
remain_temporal_size = num_frames % micro_frame_size | ||
if remain_temporal_size > 0: | ||
time_padding = 0 if remain_temporal_size % 4 == 0 else 4 - remain_temporal_size % 4 | ||
remain_size = (remain_temporal_size + time_padding) // 4 | ||
frame_size += remain_size | ||
|
||
latent_size = (frame_size, image_size[0] // 8, image_size[1] // 8) | ||
model = ( | ||
build_module( | ||
cfg.model, | ||
MODELS, | ||
input_size=latent_size, | ||
in_channels=4, | ||
caption_channels=4096, | ||
model_max_length=300, | ||
enable_sequence_parallelism=enable_sequence_parallelism, | ||
) | ||
.to(device, dtype) | ||
.eval() | ||
) | ||
# == build scheduler == | ||
scheduler = build_module(cfg.scheduler, SCHEDULERS) | ||
|
||
# ====================================================== | ||
# inference | ||
# ====================================================== | ||
# == load prompts == | ||
prompts = cfg.get("prompt", None) | ||
start_idx = cfg.get("start_index", 0) | ||
if prompts is None: | ||
if cfg.get("prompt_path", None) is not None: | ||
prompts = load_prompts(cfg.prompt_path, start_idx, cfg.get("end_index", None)) | ||
else: | ||
prompts = [cfg.get("prompt_generator", "")] * 1_000_000 # endless loop | ||
|
||
# == prepare reference == | ||
reference_path = cfg.get("reference_path", [""] * len(prompts)) | ||
mask_strategy = cfg.get("mask_strategy", [""] * len(prompts)) | ||
assert len(reference_path) == len(prompts), "Length of reference must be the same as prompts" | ||
assert len(mask_strategy) == len(prompts), "Length of mask_strategy must be the same as prompts" | ||
|
||
# == prepare arguments == | ||
fps = cfg.fps | ||
cfg.get("save_fps", fps // cfg.get("frame_interval", 1)) | ||
multi_resolution = cfg.get("multi_resolution", None) | ||
batch_size = cfg.get("batch_size", 1) | ||
num_sample = cfg.get("num_sample", 1) | ||
loop = cfg.get("loop", 1) | ||
cfg.get("condition_frame_length", 5) | ||
cfg.get("condition_frame_edit", 0.0) | ||
align = cfg.get("align", None) | ||
|
||
save_dir = cfg.save_dir | ||
os.makedirs(save_dir, exist_ok=True) | ||
sample_name = cfg.get("sample_name", None) | ||
prompt_as_path = cfg.get("prompt_as_path", False) | ||
|
||
# == prepare saved dir == | ||
cur_date = datetime.now().strftime("%Y-%m-%d") | ||
if not os.path.exists(os.path.join(save_dir, cur_date)): | ||
yesterday = date.today() - timedelta(days=1) | ||
cur_date = yesterday.strftime("%Y-%m-%d") | ||
latest_idx = sorted([int(x) for x in os.listdir(os.path.join(save_dir, cur_date))])[-1] | ||
saved_idx = str(latest_idx).zfill(5) | ||
|
||
# == Iter over all samples == | ||
for i in progress_wrap(range(0, len(prompts), batch_size)): | ||
# == prepare batch prompts == | ||
batch_prompts = prompts[i : i + batch_size] | ||
ms = mask_strategy[i : i + batch_size] | ||
refs = reference_path[i : i + batch_size] | ||
|
||
# == get json from prompts == | ||
batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms) | ||
original_batch_prompts = batch_prompts | ||
|
||
# == get reference for condition == | ||
# refs = collect_references_batch(refs, vae, image_size) | ||
refs_x = [] # refs_x: [batch, ref_num, C, T, H, W] | ||
for reference_path in refs: | ||
if reference_path == "": | ||
refs_x.append([]) | ||
continue | ||
ref_path = reference_path.split(";") | ||
ref = [] | ||
for r_path in ref_path: | ||
ref.append("placehold") | ||
refs_x.append(ref) | ||
refs = refs_x | ||
|
||
# == multi-resolution info == | ||
model_args = prepare_multi_resolution_info( | ||
multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype | ||
) | ||
|
||
# == Iter over number of sampling for one prompt == | ||
for k in range(num_sample): | ||
# == prepare save paths == | ||
save_paths = [ | ||
get_save_path_name( | ||
save_dir, | ||
sample_name=sample_name, | ||
sample_idx=start_idx + idx, | ||
prompt=original_batch_prompts[idx], | ||
prompt_as_path=prompt_as_path, | ||
num_sample=num_sample, | ||
k=k, | ||
) | ||
for idx in range(len(batch_prompts)) | ||
] | ||
|
||
# NOTE: Skip if the sample already exists | ||
# This is useful for resuming sampling VBench | ||
if prompt_as_path and all_exists(save_paths): | ||
continue | ||
|
||
# == process prompts step by step == | ||
# 0. split prompt | ||
# each element in the list is [prompt_segment_list, loop_idx_list] | ||
batched_prompt_segment_list = [] | ||
batched_loop_idx_list = [] | ||
for prompt in batch_prompts: | ||
prompt_segment_list, loop_idx_list = split_prompt(prompt) | ||
batched_prompt_segment_list.append(prompt_segment_list) | ||
batched_loop_idx_list.append(loop_idx_list) | ||
|
||
# 1. refine prompt by openai | ||
if cfg.get("llm_refine", False): | ||
# only call openai API when | ||
# 1. seq parallel is not enabled | ||
# 2. seq parallel is enabled and the process is rank 0 | ||
if not enable_sequence_parallelism or (enable_sequence_parallelism and is_main_process()): | ||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): | ||
batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list) | ||
|
||
# sync the prompt if using seq parallel | ||
if enable_sequence_parallelism: | ||
coordinator.block_all() | ||
prompt_segment_length = [ | ||
len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list | ||
] | ||
|
||
# flatten the prompt segment list | ||
batched_prompt_segment_list = [ | ||
prompt_segment | ||
for prompt_segment_list in batched_prompt_segment_list | ||
for prompt_segment in prompt_segment_list | ||
] | ||
|
||
# create a list of size equal to world size | ||
broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size | ||
dist.broadcast_object_list(broadcast_obj_list, 0) | ||
|
||
# recover the prompt list | ||
batched_prompt_segment_list = [] | ||
segment_start_idx = 0 | ||
all_prompts = broadcast_obj_list[0] | ||
for num_segment in prompt_segment_length: | ||
batched_prompt_segment_list.append( | ||
all_prompts[segment_start_idx : segment_start_idx + num_segment] | ||
) | ||
segment_start_idx += num_segment | ||
|
||
# 2. append score | ||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): | ||
batched_prompt_segment_list[idx] = append_score_to_prompts( | ||
prompt_segment_list, | ||
aes=cfg.get("aes", None), | ||
flow=cfg.get("flow", None), | ||
camera_motion=cfg.get("camera_motion", None), | ||
) | ||
|
||
# 3. clean prompt with T5 | ||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): | ||
batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list] | ||
|
||
# 4. merge to obtain the final prompt | ||
batch_prompts = [] | ||
for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list): | ||
batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list)) | ||
|
||
# == Iter over loop generation == | ||
for loop_i in range(loop): | ||
# == get prompt for loop i == | ||
batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i) | ||
|
||
# == add condition frames for loop == | ||
if os.path.exists(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_ref.pt")): | ||
ref = torch.load(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_ref.pt")) | ||
ref = ref.to(dtype) | ||
ref = ref.to(device) | ||
refs[i][loop_i] = ref | ||
ms[i] = open(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_ms")).readlines()[0].strip() | ||
|
||
# == get text embedding == | ||
caption_embs = torch.load(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_prompt.pt")) | ||
caption_emb_masks = torch.load( | ||
os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_prompt_masks.pt") | ||
) | ||
caption_embs = caption_embs.to(device, torch.float32) | ||
caption_emb_masks = caption_emb_masks.to(device, torch.int64) | ||
|
||
# == sampling == | ||
torch.manual_seed(1024) | ||
z = torch.randn(len(batch_prompts), 4, *latent_size, device=device, dtype=dtype) | ||
masks = apply_mask_strategy(z, refs, ms, loop_i, align=align) | ||
samples = scheduler.sample( | ||
model, | ||
text_encoder=None, | ||
z=z, | ||
prompts=batch_prompts_loop, | ||
device=device, | ||
additional_args=model_args, | ||
progress=verbose >= 2, | ||
mask=masks, | ||
caption_embs=caption_embs, | ||
caption_emb_masks=caption_emb_masks, | ||
) | ||
if is_main_process(): | ||
torch.save(samples.cpu(), os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_latents.pt")) | ||
start_idx += len(batch_prompts) | ||
logger.info("Inference STDiT finished.") | ||
logger.info("Saved %s samples to %s/%s/%s", start_idx, save_dir, cur_date, saved_idx) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am aware of your motivation, but can you add more doc to tell other users why, when and how to run inference separately so that they can feel more guided?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.