updated doc (hpcaitech#77)
FrankLeeeee authored Mar 17, 2024
1 parent 286904c commit 8884831
Showing 4 changed files with 22 additions and 30 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -120,16 +120,16 @@ Our model's weight is partially initialized from [PixArt-α](https://github.com/
To run inference with our provided weights, first download [T5](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main) weights into `pretrained_models/t5_ckpts/t5-v1_1-xxl`. Then run the following commands to generate samples. See [here](docs/structure.md#inference-config-demos) to customize the configuration.

```bash
-# Sample 16x256x256 (~2s)
+# Sample 16x256x256 (may take less than 1 min)
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path ./path/to/your/ckpt.pth

-# Sample 16x512x512 (~2s)
+# Sample 16x512x512 (may take less than 1 min)
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x512x512.py

-# Sample 64x512x512 (~5s)
+# Sample 64x512x512 (may take 1 min or more)
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/64x512x512.py

-# Sample 64x512x512 with sequence parallelism(~5s)
+# Sample 64x512x512 with sequence parallelism (may take 1 min or more)
# sequence parallelism is enabled automatically when nproc_per_node is larger than 1
torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/64x512x512.py
```
17 changes: 12 additions & 5 deletions docs/commands.md
@@ -26,11 +26,13 @@ Download T5 into `./pretrained_models` and run the following command.

```bash
# 256x256
-python scripts/inference.py configs/pixart/inference/1x256x256.py --ckpt-path PixArt-XL-2-256x256.pth
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/pixart/inference/1x256x256.py --ckpt-path PixArt-XL-2-256x256.pth

# 512x512
-python scripts/inference.py configs/pixart/inference/1x512x512.py --ckpt-path PixArt-XL-2-512x512.pth
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/pixart/inference/1x512x512.py --ckpt-path PixArt-XL-2-512x512.pth

# 1024 multi-scale
-python scripts/inference.py configs/pixart/inference/1x1024MS.py --ckpt-path PixArt-XL-2-1024MS.pth
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/pixart/inference/1x1024MS.py --ckpt-path PixArt-XL-2-1024MS.pth
```
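The `python` → `torchrun` switch here mirrors the ckpt_utils.py change below: the removed loading shim faked a one-process distributed environment by setting env vars by hand, and `torchrun --standalone --nproc_per_node 1` now provides them instead. A rough sketch of the variables involved, assuming standard `torch.distributed` conventions (the port is illustrative; `--standalone` picks its own rendezvous endpoint):

```python
# Approximately what `torchrun --standalone --nproc_per_node 1` exports for a
# single process -- the same variables the removed shim in ckpt_utils.py set by hand.
import os

os.environ["RANK"] = "0"        # global rank of this process
os.environ["LOCAL_RANK"] = "0"  # rank within this node
os.environ["WORLD_SIZE"] = "1"  # total number of processes
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"  # illustrative; torchrun chooses its own
```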

### Inference with checkpoints saved during training
@@ -39,9 +41,14 @@ During training, an experiment logging folder is created in `outputs` directory.

```bash
# inference with ema model
-python scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000/ema.pt
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000/ema.pt

# inference with model
-python scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000
+
+# inference with sequence parallelism
+# sequence parallelism is enabled automatically when nproc_per_node is larger than 1
+torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000
```

The second command will automatically generate a `model_ckpt.pt` file in the checkpoint folder.
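That `model_ckpt.pt` is a plain PyTorch state dict, written by `load_checkpoint(..., save_as_pt=True)` in the ckpt_utils.py diff below, so it reloads with vanilla `torch.load` and no `torchrun`. A toy round trip as a sketch, with a stand-in module in place of the real STDiT model:

```python
import torch
import torch.nn as nn

# model_ckpt.pt is produced by torch.save(model.state_dict(), ...), so loading
# it needs no distributed setup -- only a module with a matching architecture.
model = nn.Linear(4, 4)  # stand-in; use the real model built from a matching config
torch.save(model.state_dict(), "model_ckpt.pt")

state_dict = torch.load("model_ckpt.pt", map_location="cpu")
model.load_state_dict(state_dict)  # keys must match the saving architecture
```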
2 changes: 2 additions & 0 deletions docs/structure.md
@@ -88,6 +88,8 @@ model = dict(
type="STDiT-XL/2", # Select model type (STDiT-XL/2, DiT-XL/2, etc.)
space_scale=1.0, # (Optional) Space positional encoding scale (new height / old height)
time_scale=2 / 3, # (Optional) Time positional encoding scale (new frame_interval / old frame_interval)
enable_flashattn=True, # (Optional) Speed up training and inference with flash attention
enable_layernorm_kernel=True, # (Optional) Speed up training and inference with fused kernel
from_pretrained="PRETRAINED_MODEL", # (Optional) Load from pretrained model
no_temporal_pos_emb=True, # (Optional) Disable temporal positional encoding (for image)
)
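For context on how a config dict like this is typically consumed: `type` selects the model class and the remaining keys are passed through as constructor kwargs. A hypothetical sketch (the registry and stand-in class below are illustrative, not the repository's actual builder API):

```python
# Hypothetical sketch: resolve `type` through a registry, pass the rest as kwargs.
class STDiTStandIn:
    def __init__(self, **kwargs):
        self.kwargs = kwargs  # a real model would configure itself from these

MODEL_REGISTRY = {"STDiT-XL/2": STDiTStandIn}  # illustrative registry

def build_model(cfg: dict):
    kwargs = dict(cfg)  # copy so the original config stays untouched
    model_cls = MODEL_REGISTRY[kwargs.pop("type")]
    return model_cls(**kwargs)

model = build_model(dict(type="STDiT-XL/2", space_scale=1.0, time_scale=2 / 3))
```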
25 changes: 4 additions & 21 deletions opensora/utils/ckpt_utils.py
@@ -10,7 +10,7 @@
import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
@@ -78,26 +78,8 @@ def download_model(model_name):


def load_from_sharded_state_dict(model, ckpt_path):
-    # TODO: harded-coded for colossal loading
-    os.environ["RANK"] = "0"
-    os.environ["LOCAL_RANK"] = "0"
-    os.environ["WORLD_SIZE"] = "1"
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    colossalai.launch_from_torch({})
-    plugin = LowLevelZeroPlugin(
-        stage=2,
-        precision="fp32",
-        initial_scale=2**16,
-    )
-    booster = Booster(plugin=plugin)
-    model, _, _, _, _ = booster.boost(model=model)
-    booster.load_model(model, os.path.join(ckpt_path, "model"))
-
-    save_path = os.path.join(ckpt_path, "model_ckpt.pt")
-    torch.save(model.module.state_dict(), save_path)
-    print(f"Model checkpoint saved to {save_path}")
-
+    ckpt_io = GeneralCheckpointIO()
+    ckpt_io.load_model(model, os.path.join(ckpt_path, "model"))

def model_sharding(model: torch.nn.Module):
    global_rank = dist.get_rank()
@@ -229,5 +211,6 @@ def load_checkpoint(model, ckpt_path, save_as_pt=True):
        if save_as_pt:
            save_path = os.path.join(ckpt_path, "model_ckpt.pt")
            torch.save(model.state_dict(), save_path)
+            print(f"Model checkpoint saved to {save_path}")
    else:
        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
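Net effect of this file's changes: loading a sharded checkpoint no longer boots a `LowLevelZeroPlugin` booster with hand-set env vars; `GeneralCheckpointIO` reads the shards directly, and consolidation to `model_ckpt.pt` now lives in `load_checkpoint`. A sketch of the resulting load path, stitched together from the hunks above; the plain-`.pt` branch is not shown in this diff and is paraphrased here:

```python
import os

import torch
from colossalai.checkpoint_io import GeneralCheckpointIO

def load_from_sharded_state_dict(model, ckpt_path):
    # Read colossalai's sharded checkpoint directly -- no booster, no fake env vars.
    ckpt_io = GeneralCheckpointIO()
    ckpt_io.load_model(model, os.path.join(ckpt_path, "model"))

def load_checkpoint(model, ckpt_path, save_as_pt=True):
    if os.path.isdir(ckpt_path):  # sharded checkpoint folder from training
        load_from_sharded_state_dict(model, ckpt_path)
        if save_as_pt:  # consolidate into one file for plain torch.load later
            save_path = os.path.join(ckpt_path, "model_ckpt.pt")
            torch.save(model.state_dict(), save_path)
            print(f"Model checkpoint saved to {save_path}")
    elif ckpt_path.endswith((".pt", ".pth")):  # paraphrased: branch not in this hunk
        model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    else:
        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
```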
