Part of the reason to do this is to go through and ensure that the synchronization scope of all ops is handled consistently -- for example, the lack of a `__syncwarp()` on TMA async waits is not ideal.
This is also likely a good chance to go through and unify all of the async mechanisms -- we are currently roughly exposing all of PTX's baggage in how that works, and it could probably be hidden better.
My basic assumption for this todo is that we need to add:
- FP8 register tile support
- FP8 shared tile support
- MMA instructions
- Loads and stores
We don't need to provide support for:
- Maps, reductions, etc., since 8-bit instructions don't even seem to be supported.
- H100 features (WGMMA, TMA, etc.). All of these can be done, but I think they should come after the first set.
This list is probably not quite comprehensive -- I'm sure there's something I've missed when thinking it through. But I think it's pretty close.
In `common/base_types.cuh`:
- Add wrappers for the `__nv_fp8_e5m2` and `__nv_fp8_e4m3` types, analogous to `kittens::bf16`.
- Register the FP8 types in the `packing` and `convertor` structs, if you want to support conversions.
Note: I was thinking about how to handle the fact that the relevant instruction is `mma.sync.aligned.m16n8k32` (k is 32 for the 8-bit shape, not 16 as for the 16-bit types). Option (1) is to template the width of the `rt_base` tile based on the type. Option (2) is to leave it as-is and instead handle it at the `rt` level in `ops/warp/register/mma.cuh`. The latter may require some ungodly `reinterpret_cast`s when the time comes, but I suspect it will be less painful than (1). But I leave it up to you either way.
In `types/register/rt_base.cuh`:
- Register packed fp8 types as allowable, so that they don't throw static asserts.
- If using option (1), add templating to make the base tile sometimes width 32. It unfortunately seems HMMA wants that, after all.
- Add pretty wrappers for the FP8 types.
In `types/register/rt.cuh`:
- Add pretty wrappers.
- If using option (1), you may want to modify things to handle different base tile widths.
In `types/shared/st.cuh`:
- Add wrappers for whichever fp8 types you want to support. I recommend using unpacked types here, just for consistency.
In types/shared/st_layout.cuh
:
- Add an additional (type) template args for the shared_indexer struct, and specialize it for fp8. Since we're ignoring the H100, you only really need to specialize the
naive
layout and thexor_swizzle
.
In `ops/warp/register/mma.cuh`:
- Add an HMMA 16832 wrapper for FP8. See here for the relevant instructions.
- Template `dot` and `mma` to run on these tiles.
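For reference, the inline PTX should look roughly like this -- a sketch only, which I haven't compiled; the operand counts and constraints should be double-checked against the PTX ISA docs and the existing 16816 wrappers:

```cuda
#include <cstdint>

// Sketch of one m16n8k32 fp8 MMA: per thread, A is 4 .b32 registers,
// B is 2, and C/D are 4 f32 registers (verify against the PTX ISA).
__device__ static inline void hmma16832(float (&d)[4], const uint32_t (&a)[4],
                                        const uint32_t (&b)[2], const float (&c)[4]) {
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
```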
In `ops/warp/register/conversions.cuh`:
- Add a transpose by unpacking, shuffling, and repacking. This will be horrendously slow compared to the builtin `movmatrix` for 16-bit types, but it is still worth making sure everything needed can be done.
In `ops/warp/memory/global_to_register.cuh`, `ops/warp/memory/shared_to_register.cuh`, and `ops/warp/memory/global_to_shared.cuh`:
- Add additional templating / functions.
First, the testing infrastructure is written assuming bf16 in global memory, since that's been enough to test everything so far. But here it really is not. So you'll want to start by modifying `testing_commons.cuh`, and in particular adding a type template to the `initialize` function, so that we can check fp8 loads and stores, too. The `validate` function will also need to be modified, but it looks more straightforward to me.
How hard you want to go on tests is up to you, but a few that would be good to add fp8 tests for and run include:
- `tests/warp/register/tile/mma.cu`
- `tests/warp/memory/tile/*.cu`
- `tests/warp/shared/conversions.cuh`

Just to make sure everything is working as it should!
One relatively easy thing to do after all of the above would be to modify the `4090_ker.cu` and its `harness.impl` in `examples/attn_fwd` to be templated to include FP8, and see how much faster we can get attention to run on the 4090 in FP8!
The two things that would need to be done in addition are:
- TMA -- this is just adding templates to `src/warp/memory/tile/tma.cuh`.
- WGMMA -- templating `src/group/wgmma`. One thing to remember: the transpose flag does NOT work on QGMMA :/ -- this is the main reason we didn't put it in before.