
Commit

Fwd and inv work for all sizes. LDE still spills registers.
mcarilli committed Nov 16, 2023
1 parent 49142e4 commit 4849b70
Showing 6 changed files with 1,147 additions and 1,158 deletions.
3 changes: 2 additions & 1 deletion boojum-cuda/benches/ntt.rs
@@ -47,7 +47,8 @@ fn case(
group.warm_up_time(Duration::from_millis(WARM_UP_TIME_MS));
group.measurement_time(Duration::from_millis(MEASUREMENT_TIME_MS));
group.sampling_mode(SamplingMode::Flat);
for inverse in [false, true] {
// for inverse in [false, true] {
for inverse in [false] {
for bitrev_inputs in [false, true] {
for log_count in log_n_range.clone() {
let count: u32 = 1 << log_count;
2 changes: 1 addition & 1 deletion boojum-cuda/native/CMakeLists.txt
@@ -45,6 +45,6 @@ set_target_properties(boojum-cuda-native PROPERTIES CUDA_SEPARABLE_COMPILATION O
set_target_properties(boojum-cuda-native PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_compile_options(boojum-cuda-native PRIVATE --expt-relaxed-constexpr)
target_compile_options(boojum-cuda-native PRIVATE --ptxas-options=-v)
#target_compile_options(boojum-cuda-native PRIVATE -lineinfo)
target_compile_options(boojum-cuda-native PRIVATE -lineinfo)
#target_compile_options(boojum-cuda-native PRIVATE --keep)
install(TARGETS boojum-cuda-native DESTINATION .)
77 changes: 71 additions & 6 deletions boojum-cuda/native/ntt.cu
@@ -6,12 +6,6 @@ using namespace goldilocks;

namespace ntt {

#define PAD(X) (((X) >> 4) * 17 + ((X)&15))
static constexpr unsigned PADDED_WARP_SCRATCH_SIZE = (256 / 16) * 17 + 1;
// for debugging:
// #define PAD(X) (X)
// static constexpr unsigned PADDED_WARP_SCRATCH_SIZE = 256;

__device__ __forceinline__ void exchg_dit(base_field &a, base_field &b, const base_field &twiddle) {
b = base_field::mul(b, twiddle);
const auto a_tmp = a;
@@ -40,6 +34,77 @@ __device__ __forceinline__ base_field get_twiddle(const bool inverse, const unsi
return base_field::mul(fine, coarse);
}

DEVICE_FORCEINLINE void shfl_xor_bf(base_field *vals, const unsigned i, const unsigned lane_id, const unsigned lane_mask) {
// Some threads need to post vals[2 * i], others need to post vals[2 * i + 1].
// We use a temporary to avoid calling shfls divergently, which is unsafe on pre-Volta.
base_field tmp{};
if (lane_id & lane_mask)
tmp = vals[2 * i];
else
tmp = vals[2 * i + 1];
tmp[0] = __shfl_xor_sync(0xffffffff, tmp[0], lane_mask);
tmp[1] = __shfl_xor_sync(0xffffffff, tmp[1], lane_mask);
if (lane_id & lane_mask)
vals[2 * i] = tmp;
else
vals[2 * i + 1] = tmp;
}
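
// For example, with lane_mask == 1: odd lanes post vals[2 * i] while even lanes post vals[2 * i + 1],
// so after the two shuffles each odd lane's vals[2 * i] has been swapped with its even partner's
// vals[2 * i + 1]. Both limbs of the base_field travel, and every lane executes both
// __shfl_xor_sync calls unconditionally, which keeps the warp converged around the shuffles.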

template <unsigned VALS_PER_WARP, unsigned LOG_VALS_PER_THREAD>
DEVICE_FORCEINLINE void load_initial_twiddles_warp(base_field *twiddle_cache, const unsigned lane_id, const unsigned gmem_offset,
const bool inverse) {
// cooperatively loads all the twiddles this warp needs for intrawarp stages
base_field *twiddles_this_stage = twiddle_cache;
unsigned num_twiddles_this_stage = VALS_PER_WARP >> 1;
unsigned exchg_region_offset = gmem_offset >> 1;
#pragma unroll
for (unsigned stage = 0; stage < LOG_VALS_PER_THREAD; stage++) {
#pragma unroll
for (unsigned i = lane_id; i < num_twiddles_this_stage; i += 32) {
twiddles_this_stage[i] = get_twiddle(inverse, i + exchg_region_offset);
}
twiddles_this_stage += num_twiddles_this_stage;
num_twiddles_this_stage >>= 1;
exchg_region_offset >>= 1;
}

// loads final 31 twiddles with minimal divergence. pain.
const unsigned lz = __clz(lane_id);
const unsigned stage_offset = 5 - (32 - lz);
const unsigned mask = (1 << (32 - lz)) - 1;
if (lane_id > 0) {
exchg_region_offset >>= stage_offset;
twiddles_this_stage[lane_id^31] = get_twiddle(inverse, (lane_id^mask) + exchg_region_offset);
}

__syncwarp();
}
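
// A concrete picture of the tail above: the five remaining intrawarp stages need 16 + 8 + 4 + 2 + 1 = 31
// twiddles, and lanes 1..31 each fetch exactly one. lane_id ^ 31 maps lanes 16..31 to cache slots 0..15
// (the 16-region stage), lanes 8..15 to slots 16..23, lanes 4..7 to slots 24..27, lanes 2..3 to slots
// 28..29, and lane 1 to slot 30, while (lane_id ^ mask) + exchg_region_offset selects the matching
// twiddle region after the offset has been shifted down by stage_offset. Lane 0 idles.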

template <unsigned LOG_VALS_PER_THREAD>
DEVICE_FORCEINLINE void load_noninitial_twiddles_warp(base_field *twiddle_cache, const unsigned lane_id, const unsigned warp_id,
const unsigned block_exchg_region_offset, const bool inverse) {
// cooperatively loads all the twiddles this warp needs for intrawarp stages
static_assert(LOG_VALS_PER_THREAD <= 4);
constexpr unsigned NUM_INTRAWARP_STAGES = LOG_VALS_PER_THREAD + 1;

// tile size 16: num twiddles = vals per warp / 2 / 16 == vals per thread
unsigned num_twiddles_first_stage = 1 << LOG_VALS_PER_THREAD;
unsigned exchg_region_offset = block_exchg_region_offset + warp_id * num_twiddles_first_stage;

// loads 2 * num_twiddles_first_stage - 1 twiddles with minimal divergence. pain.
if (lane_id > 0 && lane_id < 2 * num_twiddles_first_stage) {
const unsigned lz = __clz(lane_id);
const unsigned stage_offset = NUM_INTRAWARP_STAGES - (32 - lz);
const unsigned mask = (1 << (32 - lz)) - 1;
exchg_region_offset >>= stage_offset;
twiddle_cache[lane_id^(2 * num_twiddles_first_stage - 1)] = get_twiddle(inverse, (lane_id^mask) + exchg_region_offset);
}

__syncwarp();
}
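
// Same single-pass pattern as the tail of load_initial_twiddles_warp, just smaller: e.g. with
// LOG_VALS_PER_THREAD == 3 the warp needs 8 + 4 + 2 + 1 = 15 twiddles, so lanes 8..15 fill cache
// slots 0..7, lanes 4..7 slots 8..11, lanes 2..3 slots 12..13, lane 1 slot 14, and lanes 0 and
// 16..31 stay idle.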

static __device__ constexpr unsigned NTTS_PER_BLOCK = 8;

#include "ntt_b2n.cuh"
#include "ntt_n2b.cuh"

