Update on "3d with fp8 in test runner"
[ghstack-poisoned]
H-Huang committed Sep 5, 2024
2 parents 5cbf901 + e7f2db7 commit 93f1d91
Showing 3 changed files with 12 additions and 2 deletions.
1 change: 0 additions & 1 deletion .github/workflows/integration_test_4gpu.yaml
@@ -38,6 +38,5 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
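
For local reproduction, the run step above can be mirrored with a small driver script. This is a hedged sketch: the commands, the cu121 nightly index, and the 4-GPU test_runner invocation are taken from the workflow lines above, while the run() helper and the use of subprocess are illustrative only.

import os
import subprocess
import sys
from pathlib import Path

def run(cmd, extra_env=None):
    # Echo and execute one command, failing fast on a non-zero exit code.
    print("+", " ".join(cmd))
    subprocess.run(cmd, check=True, env={**os.environ, **(extra_env or {})})

# Nightly torch wheel for CUDA 12.1, as pinned in the workflow.
run([sys.executable, "-m", "pip", "install", "--force-reinstall", "--pre", "torch",
     "--index-url", "https://download.pytorch.org/whl/nightly/cu121"])
# torchao from source; USE_CPP=0 skips building its C++ extensions.
run([sys.executable, "-m", "pip", "install", "git+https://github.com/pytorch/ao.git"],
    extra_env={"USE_CPP": "0"})

# Collect test artifacts and run the 4-GPU integration tests.
Path("artifacts-to-be-uploaded").mkdir(exist_ok=True)
run([sys.executable, "./test_runner.py", "artifacts-to-be-uploaded", "--ngpu", "4"])
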
1 change: 0 additions & 1 deletion torchtitan/models/norms.py
@@ -284,7 +284,6 @@ def backward(ctx, dy):

    M, N = dy.shape
    dx = torch.empty_like(x)
    dw = torch.empty_like(weight)

    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
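
For context on the norms.py hunk: the backward pass allocates a per-SM fp32 buffer _dw of shape (sm_count, N) so each streaming multiprocessor can accumulate its own partial weight gradient, and the partials are presumably reduced into the final dw afterwards. A minimal sketch of that allocate-then-reduce pattern follows; the shapes and the final sum(0) reduction are illustrative, not copied from the kernel.

import torch

# Hypothetical sizes standing in for the real backward inputs.
M, N = 64, 128        # matches dy.shape in the real backward
sm_count = 16         # stands in for multi_processor_count of the device

dy = torch.randn(M, N)
x = torch.randn(M, N)
weight = torch.randn(N)

dx = torch.empty_like(x)  # per-element input gradient, filled by the kernel
# One fp32 partial weight-gradient row per SM; the kernel accumulates into these.
_dw = torch.zeros((sm_count, N), dtype=torch.float32)

# ... the Triton kernel would populate dx and _dw here ...

# Reduce the per-SM partials into the final weight gradient.
dw = _dw.sum(dim=0).to(weight.dtype)
assert dw.shape == weight.shape
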
12 changes: 12 additions & 0 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -41,6 +41,18 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
    if n_microbatches is None:
        n_microbatches = job_config.experimental.pipeline_parallel_degree

    # Validation that the stages are compatible with the schedule
    if isinstance(schedule_class, PipelineScheduleSingle):
        if len(stages) != 1:
            raise ValueError(
                f"PipelineScheduleSingle requires exactly one stage, got {len(stages)}"
            )
    elif isinstance(schedule_class, PipelineScheduleMulti):
        if len(stages) < 2:
            raise ValueError(
                f"PipelineScheduleMulti requires at least two stages, got {len(stages)}"
            )

    return schedule_class(
        stages if looped_schedule else stages[0],
        n_microbatches=n_microbatches,
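
The hunk above adds a sanity check that the stages built on a rank match the schedule family: single-stage schedules expect exactly one stage per rank, while multi-stage (looped) schedules expect at least two. Below is a standalone sketch of the same check, using stand-in classes so it runs without torch.distributed; note that with class objects rather than instances, issubclass is the appropriate test.

# Stand-ins for torch.distributed.pipelining's schedule base classes.
class PipelineScheduleSingle: ...
class PipelineScheduleMulti: ...

class Schedule1F1B(PipelineScheduleSingle): ...            # one stage per rank
class ScheduleInterleaved1F1B(PipelineScheduleMulti): ...  # two or more stages per rank


def validate_stages(schedule_class, stages):
    # Mirrors the validation added to build_pipeline_schedule.
    if issubclass(schedule_class, PipelineScheduleSingle):
        if len(stages) != 1:
            raise ValueError(
                f"PipelineScheduleSingle requires exactly one stage, got {len(stages)}"
            )
    elif issubclass(schedule_class, PipelineScheduleMulti):
        if len(stages) < 2:
            raise ValueError(
                f"PipelineScheduleMulti requires at least two stages, got {len(stages)}"
            )


validate_stages(Schedule1F1B, stages=["stage0"])               # passes
validate_stages(ScheduleInterleaved1F1B, stages=["s0", "s1"])  # passes
# validate_stages(Schedule1F1B, stages=["s0", "s1"])           # would raise ValueError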
