[not for land yet] hack max and abs out of ops eligible for AC
Summary:

For now, this is not for land; it's just saving work and starting a
discussion.

We need to calculate max(abs(tensor)) for each float8 gemm input
when using per-tensor scaling. I realized that today this does not
interact efficiently with AC, because max(abs(tensor)) is usually
recomputed in the backward pass. Since the output has only one
element, it is more efficient to save it and never recompute it.
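
To make the tradeoff concrete, below is a minimal sketch (not the
actual float8 library code) of what a per-tensor scale calculation
looks like; the names per_tensor_scale_sketch and E4M3_MAX are
illustrative. The point is that the saved output is a single element,
so keeping it around is essentially free compared to re-running the
reduction.

    import torch

    E4M3_MAX = 448.0  # largest representable magnitude in float8 e4m3

    def per_tensor_scale_sketch(t: torch.Tensor) -> torch.Tensor:
        # Dispatches aten.abs and aten.max; the result has numel() == 1.
        amax = torch.max(torch.abs(t))
        return E4M3_MAX / torch.clamp(amax, min=1e-12)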

For now, just hack these ops into the do-not-recompute list to get a
perf measurement. This seems to save ~1% on LLaMa 3B on 8 H100 GPUs.
I verified in the before/after traces that the redundant triton
kernels that calculate max(abs(activation)) and max(abs(weight)) are
gone with this hack.

I'm heading to PTC, but we should get a measurement on a larger model
and figure out a better way to land this.
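
One possible direction, sketched below under assumptions: expose this
through torch.utils.checkpoint's selective activation checkpointing
policy (available in recent PyTorch) instead of hardcoding the ops.
The names _ops_to_save, _policy, and context_fn are illustrative, not
an agreed-upon API.

    from functools import partial

    import torch
    from torch.utils.checkpoint import (
        CheckpointPolicy,
        create_selective_checkpoint_contexts,
    )

    # Illustrative, user-extensible set of ops whose outputs should be saved.
    _ops_to_save = {
        torch.ops.aten.abs.default,
        torch.ops.aten.max.default,
    }

    def _policy(ctx, op, *args, **kwargs):
        # Save the tiny max/abs outputs instead of recomputing them in the
        # backward pass; everything else keeps the default recompute behavior.
        if op in _ops_to_save:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    context_fn = partial(create_selective_checkpoint_contexts, _policy)
    # usage (sketch):
    # torch.utils.checkpoint.checkpoint(
    #     block, x, use_reentrant=False, context_fn=context_fn)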

Test Plan:

https://gist.github.com/vkuzo/375230e30e1cb599ad31a87e0be25d75

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Sep 17, 2024
1 parent d2a4904 commit 3e33e82
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -206,6 +206,11 @@ def apply_tp(
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
# Not for land in the current state, need to align on best way to expose this
# for various AC options. For now just hack it in here to get a clean
# measurement.
torch.ops.aten.abs.default,
torch.ops.aten.max.default,
}

