[Not for land] Integrate float8nocompile, an experimental feature for high performance #778

Open · wants to merge 2 commits into main
Conversation

@danielvegamyhre commented Jan 7, 2025

Summary

This PR contains a proof-of-concept integration of an experimental feature I've been working on in torchao: float8nocompile (official name TBD, naming things is hard!).

It is an implementation of float8 conversion with tensorwise dynamic scaling that uses handwritten Triton kernels to achieve high performance, rather than requiring torch.compile.
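For intuition, here is a rough eager-mode sketch of the tensorwise dynamic scaling math that the handwritten kernels implement. This is illustrative only; the names and numerical details are not the torchao implementation.

```python
import torch

# Illustrative eager-mode sketch of tensorwise dynamic scaling (not the Triton
# kernels from this PR): compute one scale per tensor from its absolute max,
# then cast to float8.
FP8_E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn
EPS = 1e-12

def to_float8_tensorwise(x: torch.Tensor):
    amax = x.abs().max().to(torch.float32)
    scale = FP8_E4M3_MAX / torch.clamp(amax, min=EPS)
    x_scaled = (x.to(torch.float32) * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_scaled.to(torch.float8_e4m3fn), scale  # keep scale to dequantize later
```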

Benchmarking training performance

Model: Llama3 8B on 4 H100s, batch size 1, sequence length 8192.

| Configuration | TFLOPS | Tokens/sec | Peak memory usage |
|---|---|---|---|
| bfloat16, eager mode | 290.5682 | 5017.2 | 64.80 GB |
| float8nocompile, eager mode | 324.4117 | 5601.6 | 64.86 GB |
| bfloat16, torch.compile | 326.3106 | 5634.6 | 62.85 GB |
| float8, torch.compile, without float8 FSDP all-gather | 386.5803 | 6674.8 | 62.91 GB |

Benchmarking single linear layer forward+backward performance

Tests used a single linear layer of size (4096, 4096) with varying input sizes; a rough sketch of this kind of microbenchmark follows the results table below.

Performance benchmarks show the float8nocompile implementation (using the ATOMIC_MAX kernel algorithm) beating torch.compile by 1.72-4.45%, depending on the input tensor size.

| input_shape | kernel_algo | high_precision_dtype | eager_time | compiled_time | float8nocompile |
|---|---|---|---|---|---|
| (16, 4096) | KernelAlgorithm.ATOMIC_MAX | torch.bfloat16 | 649.218 | 394.725 | 386.469 |
| (256, 4096) | KernelAlgorithm.ATOMIC_MAX | torch.bfloat16 | 685.783 | 420.743 | 408.137 |
| (4096, 4096) | KernelAlgorithm.ATOMIC_MAX | torch.bfloat16 | 1829.13 | 1053.64 | 977.858 |
| (65536, 4096) | KernelAlgorithm.ATOMIC_MAX | torch.bfloat16 | 21554.2 | 12369.7 | 10813.3 |
| (16, 4096) | KernelAlgorithm.REDUCTION | torch.bfloat16 | 650.026 | 394.951 | 696.221 |
| (256, 4096) | KernelAlgorithm.REDUCTION | torch.bfloat16 | 684.865 | 421.144 | 729.459 |
| (4096, 4096) | KernelAlgorithm.REDUCTION | torch.bfloat16 | 1826.42 | 1050.85 | 1596.12 |
| (65536, 4096) | KernelAlgorithm.REDUCTION | torch.bfloat16 | 21584.7 | 12347.2 | 17290 |
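For reference, a microbenchmark of roughly this shape can be put together with torch.utils.benchmark; the sketch below is illustrative only and is not the actual benchmark script used for the numbers above.

```python
import torch
from torch.utils.benchmark import Timer

# Illustrative sketch only: time forward + backward through a single
# (4096, 4096) linear layer for a given input shape.
def bench_linear_fwd_bwd(input_shape=(4096, 4096), device="cuda"):
    layer = torch.nn.Linear(4096, 4096, bias=False, device=device, dtype=torch.bfloat16)
    x = torch.randn(*input_shape, device=device, dtype=torch.bfloat16, requires_grad=True)

    def fwd_bwd():
        layer(x).sum().backward()

    timer = Timer(stmt="fwd_bwd()", globals={"fwd_bwd": fwd_bwd})
    print(timer.timeit(100))  # reports timing stats per forward + backward
```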

TL;DR: running in eager mode but using the Triton kernels for fp8 conversion, we currently achieve:

  • Better performance than bf16 in eager mode
  • Similar performance to bf16 + torch.compile
  • Lower performance than fp8 + torch.compile

I think this makes sense, since fp8 + torch.compile compiles the entire model, not just the fp8 conversion, so I would expect its performance to be better.

Note: this PR depends on this stack of PRs being merged into torchao, and on those changes being included in a release (which the user then installs).
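For illustration, the torchtitan side of the integration could be gated behind a config flag roughly as sketched below. The flag and function names here are placeholders, not the final torchao or torchtitan API (the official name is still TBD, as noted above).

```python
import torch.nn as nn

# Illustrative sketch only: how torchtitan might gate the feature behind a
# config flag. `convert_fn` stands in for whatever conversion entry point the
# torchao PR stack ultimately exposes (official name TBD).
def maybe_apply_float8nocompile(model: nn.Module, job_config, convert_fn) -> nn.Module:
    # Hypothetical flag name; the real torchtitan config option may differ.
    if getattr(job_config.float8, "nocompile", False):
        model = convert_fn(model)  # swaps nn.Linear modules for float8 variants
    return model
```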

@facebook-github-bot added the CLA Signed label Jan 7, 2025
@danielvegamyhre (Author) commented:
@vkuzo here is the PoC of the torchtitan + float8nocompile integration, and the training performance benchmarking results

@vkuzo (Contributor) commented Jan 7, 2025

I think a good way to go here is to mark this PR "not for land" for now

> Benchmarking mean MFU during training run

Here is the data I think is important:

  1. in addition to MFU, also report tokens per second and peak memory usage as top-line metrics
  2. report the metrics for the following experiments:
    2a. bfloat16, eager mode
    2b. float8nocompile, eager mode
    2c. bfloat16, torch.compile
    2d. float8, torch.compile, without float8 FSDP all-gather

@danielvegamyhre changed the title from "[PoC] Integrate float8nocompile, an experimental feature for high performance" to "[Not for land] Integrate float8nocompile, an experimental feature for high performance" Jan 7, 2025
@awgu (Contributor) commented Jan 7, 2025

I think we should take care to mention/keep in mind that the MFU is with respect to peak bf16 TFLOPS. Direct comparison of TFLOPS might make more sense when comparing bf16 vs. fp8 runs since the fp8 runs are not actually doing every computation in fp8 either (e.g. SDPA or final linear).

@danielvegamyhre (Author) commented:

> I think we should take care to mention/keep in mind that the MFU is with respect to peak bf16 TFLOPS. Direct comparison of TFLOPS might make more sense when comparing bf16 vs. fp8 runs since the fp8 runs are not actually doing every computation in fp8 either (e.g. SDPA or final linear).

Makes sense. Is there an existing way to log TFLOPS during training with torchtitan? Searching around, I don't see one.

@awgu (Contributor) commented Jan 8, 2025

@danielvegamyhre I think the TFLOPS is the same as MFU without the peak TFLOPS denominator (and extra factor of 100):

`mfu = 100 * num_flop_per_token * tps / gpu_peak_flops`

In other words, it is just num_flop_per_token * tps. So you can convert your MFU numbers back to TFLOPS by multiplying by gpu_peak_flops / 100.
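For example, assuming an H100 SXM dense bf16 peak of ~989 TFLOPS (an assumed value; substitute the spec for your hardware), the conversion is just:

```python
# Convert MFU (%) back to achieved TFLOPS, per the formula above.
H100_PEAK_BF16_TFLOPS = 989  # assumed dense bf16 peak for H100 SXM

def mfu_to_tflops(mfu_percent: float, gpu_peak_tflops: float = H100_PEAK_BF16_TFLOPS) -> float:
    return mfu_percent * gpu_peak_tflops / 100

# e.g. an MFU of ~33% corresponds to roughly 326 TFLOPS, which lines up with
# the bfloat16 + torch.compile row in the table above.
print(mfu_to_tflops(33.0))  # ~326
```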

@danielvegamyhre (Author) commented:

> @danielvegamyhre I think the TFLOPS is the same as MFU without the peak TFLOPS denominator (and extra factor of 100):
>
> `mfu = 100 * num_flop_per_token * tps / gpu_peak_flops`
>
> In other words, it is just num_flop_per_token * tps. So you can convert your MFU numbers back to TFLOPS by multiplying by gpu_peak_flops / 100.

Ah, of course, thanks! I updated the PR description to include TFLOPS instead of MFU.

@vkuzo (Contributor) commented Jan 8, 2025

Based on the results shared so far, I think it would be interesting to add one additional experiment branch: production float8 + torch.compile on just the torch.nn.Linear layers. I have some old unlanded code with an example of how to apply this here: #661. If we structure the handwritten kernels right for the chosen AC strategy, we should be able to match the level of performance of that setup.
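A rough sketch of the per-module compile part of that experiment (not the code from #661): after the Linear layers have been converted to float8, wrap only those modules in torch.compile and leave the rest of the model in eager mode.

```python
import torch
import torch.nn as nn

# Illustrative only: compile just the Linear modules (in the actual experiment
# these would already be float8 Linear modules), leaving attention, norms, etc.
# in eager mode.
def compile_linears_only(module: nn.Module) -> nn.Module:
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, torch.compile(child))
        else:
            compile_linears_only(child)
    return module
```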
