From 88848311539ee782f1c9484c1ede14d7f0b72049 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Sun, 17 Mar 2024 12:17:28 +0800
Subject: [PATCH] updated doc (#77)

---
 README.md                    |  8 ++++----
 docs/commands.md             | 17 ++++++++++++-----
 docs/structure.md            |  2 ++
 opensora/utils/ckpt_utils.py | 25 ++++---------------------
 4 files changed, 22 insertions(+), 30 deletions(-)

diff --git a/README.md b/README.md
index 876da1b5b..a34ae3063 100644
--- a/README.md
+++ b/README.md
@@ -120,16 +120,16 @@ Our model's weight is partially initialized from [PixArt-α](https://github.com/
 To run inference with our provided weights, first download [T5](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main) weights into `pretrained_models/t5_ckpts/t5-v1_1-xxl`. Then run the following commands to generate samples. See [here](docs/structure.md#inference-config-demos) to customize the configuration.
 
 ```bash
-# Sample 16x256x256 (~2s)
+# Sample 16x256x256 (may take less than 1 min)
 torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path ./path/to/your/ckpt.pth
 
-# Sample 16x512x512 (~2s)
+# Sample 16x512x512 (may take less than 1 min)
 torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x512x512.py
 
-# Sample 64x512x512 (~5s)
+# Sample 64x512x512 (may take 1 min or more)
 torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/64x512x512.py
 
-# Sample 64x512x512 with sequence parallelism(~5s)
+# Sample 64x512x512 with sequence parallelism (may take 1 min or more)
 # sequence parallelism is enabled automatically when nproc_per_node is larger than 1
 torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/64x512x512.py
 ```
diff --git a/docs/commands.md b/docs/commands.md
index 3e4dece21..28ee285de 100644
--- a/docs/commands.md
+++ b/docs/commands.md
@@ -26,11 +26,13 @@ Download T5 into `./pretrained_models` and run the following command.
 
 ```bash
 # 256x256
-python scripts/inference.py configs/pixart/inference/1x256x256.py --ckpt-path PixArt-XL-2-256x256.pth
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/pixart/inference/1x256x256.py --ckpt-path PixArt-XL-2-256x256.pth
+
 # 512x512
-python scripts/inference.py configs/pixart/inference/1x512x512.py --ckpt-path PixArt-XL-2-512x512.pth
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/pixart/inference/1x512x512.py --ckpt-path PixArt-XL-2-512x512.pth
+
 # 1024 multi-scale
-python scripts/inference.py configs/pixart/inference/1x1024MS.py --ckpt-path PixArt-XL-2-1024MS.pth
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/pixart/inference/1x1024MS.py --ckpt-path PixArt-XL-2-1024MS.pth
 ```
 
 ### Inference with checkpoints saved during training
@@ -39,9 +41,14 @@ During training, an experiment logging folder is created in `outputs` directory.
 
 ```bash
 # inference with ema model
-python scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000/ema.pt
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000/ema.pt
+
 # inference with model
-python scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000
+
+# inference with sequence parallelism
+# sequence parallelism is enabled automatically when nproc_per_node is larger than 1
+torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000
 ```
 
 The second command will automatically generate a `model_ckpt.pt` file in the checkpoint folder.
diff --git a/docs/structure.md b/docs/structure.md
index b0bba8fd2..0fc087ee4 100644
--- a/docs/structure.md
+++ b/docs/structure.md
@@ -88,6 +88,8 @@ model = dict(
     type="STDiT-XL/2",  # Select model type (STDiT-XL/2, DiT-XL/2, etc.)
     space_scale=1.0,  # (Optional) Space positional encoding scale (new height / old height)
     time_scale=2 / 3,  # (Optional) Time positional encoding scale (new frame_interval / old frame_interval)
+    enable_flashattn=True,  # (Optional) Speed up training and inference with flash attention
+    enable_layernorm_kernel=True,  # (Optional) Speed up training and inference with fused kernel
     from_pretrained="PRETRAINED_MODEL",  # (Optional) Load from pretrained model
     no_temporal_pos_emb=True,  # (Optional) Disable temporal positional encoding (for image)
 )
diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py
index 3543f01fc..27adfba19 100644
--- a/opensora/utils/ckpt_utils.py
+++ b/opensora/utils/ckpt_utils.py
@@ -10,7 +10,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
@@ -78,26 +78,8 @@ def download_model(model_name):
 
 
 def load_from_sharded_state_dict(model, ckpt_path):
-    # TODO: harded-coded for colossal loading
-    os.environ["RANK"] = "0"
-    os.environ["LOCAL_RANK"] = "0"
-    os.environ["WORLD_SIZE"] = "1"
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    colossalai.launch_from_torch({})
-    plugin = LowLevelZeroPlugin(
-        stage=2,
-        precision="fp32",
-        initial_scale=2**16,
-    )
-    booster = Booster(plugin=plugin)
-    model, _, _, _, _ = booster.boost(model=model)
-    booster.load_model(model, os.path.join(ckpt_path, "model"))
-
-    save_path = os.path.join(ckpt_path, "model_ckpt.pt")
-    torch.save(model.module.state_dict(), save_path)
-    print(f"Model checkpoint saved to {save_path}")
-
+    ckpt_io = GeneralCheckpointIO()
+    ckpt_io.load_model(model, os.path.join(ckpt_path, "model"))
 
 def model_sharding(model: torch.nn.Module):
     global_rank = dist.get_rank()
@@ -229,5 +211,6 @@ def load_checkpoint(model, ckpt_path, save_as_pt=True):
         if save_as_pt:
             save_path = os.path.join(ckpt_path, "model_ckpt.pt")
             torch.save(model.state_dict(), save_path)
+            print(f"Model checkpoint saved to {save_path}")
     else:
         raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
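
Note on the `ckpt_utils.py` change: the old `load_from_sharded_state_dict` faked a single-process distributed environment (`RANK`, `WORLD_SIZE`, `MASTER_PORT`, ...) and spun up a `Booster` with `LowLevelZeroPlugin` just to read a sharded checkpoint; the rewritten helper delegates to ColossalAI's `GeneralCheckpointIO`, and the consolidated `model_ckpt.pt` is now written by `load_checkpoint` instead. Below is a minimal sketch of the resulting loading flow, assuming ColossalAI is installed — the toy `nn.Linear` model and the checkpoint path are illustrative stand-ins, not part of the patch:

```python
import os

import torch
import torch.nn as nn
from colossalai.checkpoint_io import GeneralCheckpointIO


def load_from_sharded_state_dict(model: nn.Module, ckpt_path: str) -> None:
    # As in the patched helper: read the sharded weights under
    # `<ckpt_path>/model` directly, with no Booster or process-group setup.
    ckpt_io = GeneralCheckpointIO()
    ckpt_io.load_model(model, os.path.join(ckpt_path, "model"))


# Toy model and example path (taken from docs/commands.md), for illustration only.
model = nn.Linear(8, 8)
ckpt_path = "outputs/001-STDiT-XL-2/epoch12-global_step2000"

load_from_sharded_state_dict(model, ckpt_path)

# Mirrors load_checkpoint(..., save_as_pt=True): consolidate the sharded
# checkpoint into a single .pt file so later runs can pass it via --ckpt-path.
save_path = os.path.join(ckpt_path, "model_ckpt.pt")
torch.save(model.state_dict(), save_path)
print(f"Model checkpoint saved to {save_path}")
```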