[not for land yet] hack max and abs out of ops eligible for AC
Summary:

For now, this is not for land; I am just saving work and starting a discussion.

We need to calculate `max(abs(tensor))` for each float8 gemm input when using per-tensor scaling. I realized that today this does not work efficiently with AC, because `max(abs(tensor))` is usually recomputed. Since the output size is 1, it is more efficient to save it and never recompute it.

For now, just hack these ops into the do-not-recompute list to get a perf measurement. This seems to save ~1% on LLaMa 3B on 8 H100 GPUs. I verified in the before/after traces that the redundant triton kernels calculating `max(abs(activation))` and `max(abs(weight))` are gone with this hack.

Heading to PTC, but we should get a measurement on a larger model and figure out a better way to land this.

Test Plan: https://gist.github.com/vkuzo/375230e30e1cb599ad31a87e0be25d75

Reviewers:

Subscribers:

Tasks:

Tags:
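The shape of the hack can be sketched as a simple "save list" policy: ops whose outputs are on the list are saved during the forward pass instead of being recomputed during backward. The names and helper below are illustrative stand-ins, not the actual patch; the real mechanism would live in PyTorch's selective activation checkpointing policy, which this sketch only models.

```python
# Hypothetical sketch of a selective-AC "do-not-recompute" policy.
# The op names mimic ATen overload names; the policy function and
# SAVE_LIST are illustrative, not the real PyTorch API.

# max(abs(tensor)) reduces the whole tensor to a single scalar, so
# saving its output costs almost no memory, while recomputing it
# costs a full reduction over the tensor. Adding max and abs here
# is the "hack" described above.
SAVE_LIST = {
    "aten.max.default",  # added by this change (output size 1)
    "aten.abs.default",  # added by this change
    "aten.mm.default",   # matmuls are typically already saved
}

def policy(op_name: str) -> str:
    """Decide whether an op's output is saved in forward or
    recomputed in backward under activation checkpointing."""
    return "save" if op_name in SAVE_LIST else "recompute"
```

With this policy, the amax reductions for the activation and the weight run once in forward and their scalar results are reused in backward, which is consistent with the redundant triton kernels disappearing from the traces.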