diff --git a/boojum-cuda/benches/ntt.rs b/boojum-cuda/benches/ntt.rs index 5576ce5..07ee83a 100644 --- a/boojum-cuda/benches/ntt.rs +++ b/boojum-cuda/benches/ntt.rs @@ -47,7 +47,8 @@ fn case( group.warm_up_time(Duration::from_millis(WARM_UP_TIME_MS)); group.measurement_time(Duration::from_millis(MEASUREMENT_TIME_MS)); group.sampling_mode(SamplingMode::Flat); - for inverse in [false, true] { + // for inverse in [false, true] { + for inverse in [false] { for bitrev_inputs in [false, true] { for log_count in log_n_range.clone() { let count: u32 = 1 << log_count; diff --git a/boojum-cuda/native/CMakeLists.txt b/boojum-cuda/native/CMakeLists.txt index f587620..b6d7936 100644 --- a/boojum-cuda/native/CMakeLists.txt +++ b/boojum-cuda/native/CMakeLists.txt @@ -45,6 +45,6 @@ set_target_properties(boojum-cuda-native PROPERTIES CUDA_SEPARABLE_COMPILATION O set_target_properties(boojum-cuda-native PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_compile_options(boojum-cuda-native PRIVATE --expt-relaxed-constexpr) target_compile_options(boojum-cuda-native PRIVATE --ptxas-options=-v) -#target_compile_options(boojum-cuda-native PRIVATE -lineinfo) +target_compile_options(boojum-cuda-native PRIVATE -lineinfo) #target_compile_options(boojum-cuda-native PRIVATE --keep) install(TARGETS boojum-cuda-native DESTINATION .) diff --git a/boojum-cuda/native/ntt.cu b/boojum-cuda/native/ntt.cu index 15f4338..3ffa398 100644 --- a/boojum-cuda/native/ntt.cu +++ b/boojum-cuda/native/ntt.cu @@ -6,12 +6,6 @@ using namespace goldilocks; namespace ntt { -#define PAD(X) (((X) >> 4) * 17 + ((X)&15)) -static constexpr unsigned PADDED_WARP_SCRATCH_SIZE = (256 / 16) * 17 + 1; -// for debugging: -// #define PAD(X) (X) -// static constexpr unsigned PADDED_WARP_SCRATCH_SIZE = 256; - __device__ __forceinline__ void exchg_dit(base_field &a, base_field &b, const base_field &twiddle) { b = base_field::mul(b, twiddle); const auto a_tmp = a; @@ -40,6 +34,77 @@ __device__ __forceinline__ base_field get_twiddle(const bool inverse, const unsi return base_field::mul(fine, coarse); } +DEVICE_FORCEINLINE void shfl_xor_bf(base_field *vals, const unsigned i, const unsigned lane_id, const unsigned lane_mask) { + // Some threads need to post vals[2 * i], others need to post vals[2 * i + 1]. + // We use a temporary to avoid calling shfls divergently, which is unsafe on pre-Volta. + base_field tmp{}; + if (lane_id & lane_mask) + tmp = vals[2 * i]; + else + tmp = vals[2 * i + 1]; + tmp[0] = __shfl_xor_sync(0xffffffff, tmp[0], lane_mask); + tmp[1] = __shfl_xor_sync(0xffffffff, tmp[1], lane_mask); + if (lane_id & lane_mask) + vals[2 * i] = tmp; + else + vals[2 * i + 1] = tmp; +} + +template +DEVICE_FORCEINLINE void load_initial_twiddles_warp(base_field *twiddle_cache, const unsigned lane_id, const unsigned gmem_offset, + const bool inverse) { + // cooperatively loads all the twiddles this warp needs for intrawarp stages + base_field *twiddles_this_stage = twiddle_cache; + unsigned num_twiddles_this_stage = VALS_PER_WARP >> 1; + unsigned exchg_region_offset = gmem_offset >> 1; +#pragma unroll + for (unsigned stage = 0; stage < LOG_VALS_PER_THREAD; stage++) { +#pragma unroll + for (unsigned i = lane_id; i < num_twiddles_this_stage; i += 32) { + twiddles_this_stage[i] = get_twiddle(inverse, i + exchg_region_offset); + } + twiddles_this_stage += num_twiddles_this_stage; + num_twiddles_this_stage >>= 1; + exchg_region_offset >>= 1; + } + + // loads final 31 twiddles with minimal divergence. pain. 
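+  // The remaining 5 intrawarp stages need 16 + 8 + 4 + 2 + 1 = 31 twiddles, stored
+  // contiguously after the per-stage blocks loaded above. The highest set bit of lane_id
+  // selects which of those stages this lane serves (lanes 16..31 cover the 16-twiddle
+  // stage, lanes 8..15 the next, and so on), lane_id ^ 31 gives the lane a unique slot in
+  // the 31-element tail of the cache, and lane_id ^ mask is its twiddle index within that
+  // stage. Lane 0 sits out.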
+ const unsigned lz = __clz(lane_id); + const unsigned stage_offset = 5 - (32 - lz); + const unsigned mask = (1 << (32 - lz)) - 1; + if (lane_id > 0) { + exchg_region_offset >>= stage_offset; + twiddles_this_stage[lane_id^31] = get_twiddle(inverse, (lane_id^mask) + exchg_region_offset); + } + + __syncwarp(); +} + +template +DEVICE_FORCEINLINE void load_noninitial_twiddles_warp(base_field *twiddle_cache, const unsigned lane_id, const unsigned warp_id, + const unsigned block_exchg_region_offset, const bool inverse) { + // cooperatively loads all the twiddles this warp needs for intrawarp stages + static_assert(LOG_VALS_PER_THREAD <= 4); + constexpr unsigned NUM_INTRAWARP_STAGES = LOG_VALS_PER_THREAD + 1; + + // tile size 16: num twiddles = vals per warp / 2 / 16 == vals per thread + unsigned num_twiddles_first_stage = 1 << LOG_VALS_PER_THREAD; + unsigned exchg_region_offset = block_exchg_region_offset + warp_id * num_twiddles_first_stage; + + // loads 2 * num_twiddles_first_stage - 1 twiddles with minimal divergence. pain. + if (lane_id > 0 && lane_id < 2 * num_twiddles_first_stage) { + const unsigned lz = __clz(lane_id); + const unsigned stage_offset = NUM_INTRAWARP_STAGES - (32 - lz); + const unsigned mask = (1 << (32 - lz)) - 1; + exchg_region_offset >>= stage_offset; + twiddle_cache[lane_id^(2 * num_twiddles_first_stage - 1)] = get_twiddle(inverse, (lane_id^mask) + exchg_region_offset); + } + + __syncwarp(); +} + +static __device__ constexpr unsigned NTTS_PER_BLOCK = 8; + #include "ntt_b2n.cuh" #include "ntt_n2b.cuh" diff --git a/boojum-cuda/native/ntt_b2n.cuh b/boojum-cuda/native/ntt_b2n.cuh index 69c16f1..a0288bd 100644 --- a/boojum-cuda/native/ntt_b2n.cuh +++ b/boojum-cuda/native/ntt_b2n.cuh @@ -1,90 +1,78 @@ #pragma once // also, this file should only be compiled in one compile unit because it has __global__ definitions -DEVICE_FORCEINLINE void shfl_xor_bf(base_field *vals, const unsigned i, const unsigned lane_id, const unsigned lane_mask) { - // Some threads need to post vals[2 * i], others need to post vals[2 * i + 1]. - // We use a temporary to avoid calling shfls divergently, which is unsafe on pre-Volta. 
- base_field tmp{}; - if (lane_id & lane_mask) - tmp = vals[2 * i]; - else - tmp = vals[2 * i + 1]; - tmp[0] = __shfl_xor_sync(0xffffffff, tmp[0], lane_mask); - tmp[1] = __shfl_xor_sync(0xffffffff, tmp[1], lane_mask); - if (lane_id & lane_mask) - vals[2 * i] = tmp; - else - vals[2 * i + 1] = tmp; -} - template DEVICE_FORCEINLINE void b2n_initial_stages_warp(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, const unsigned coset_idx) { constexpr unsigned VALS_PER_THREAD = 1 << LOG_VALS_PER_THREAD; constexpr unsigned PAIRS_PER_THREAD = VALS_PER_THREAD >> 1; - constexpr unsigned WARPS_PER_BLOCK = 4; constexpr unsigned VALS_PER_WARP = 32 * VALS_PER_THREAD; - constexpr unsigned VALS_PER_BLOCK = 32 * VALS_PER_THREAD * WARPS_PER_BLOCK; + constexpr unsigned LOG_VALS_PER_BLOCK = 5 + LOG_VALS_PER_THREAD + 2; + constexpr unsigned VALS_PER_BLOCK = 1 << LOG_VALS_PER_BLOCK; __shared__ base_field smem[VALS_PER_BLOCK]; const unsigned lane_id{threadIdx.x & 31}; const unsigned warp_id{threadIdx.x >> 5}; - const unsigned ntt_idx = 0; // blockIdx.x / blocks_per_ntt; - const unsigned block_idx_in_ntt = blockIdx.x - ntt_idx * blocks_per_ntt; - const unsigned gmem_offset = ntt_idx * stride_between_input_arrays + VALS_PER_BLOCK * block_idx_in_ntt + VALS_PER_WARP * warp_id; - const base_field *gmem_in = gmem_inputs_matrix + gmem_offset; - base_field *gmem_out = gmem_outputs_matrix + gmem_offset; + const unsigned gmem_offset = VALS_PER_BLOCK * blockIdx.x + VALS_PER_WARP * warp_id; + const base_field *gmem_in = gmem_inputs_matrix + gmem_offset + NTTS_PER_BLOCK * stride_between_input_arrays * blockIdx.y; + base_field *gmem_out = gmem_outputs_matrix + gmem_offset + NTTS_PER_BLOCK * stride_between_output_arrays * blockIdx.y; auto twiddle_cache = smem + VALS_PER_WARP * warp_id; base_field vals[VALS_PER_THREAD]; -#pragma unroll - for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { - const auto in = memory::load_cs(reinterpret_cast(gmem_in + 64 * i + 2 * lane_id)); - vals[2 * i][0] = in.x; - vals[2 * i][1] = in.y; - vals[2 * i + 1][0] = in.z; - vals[2 * i + 1][1] = in.w; - } + load_initial_twiddles_warp(twiddle_cache, lane_id, gmem_offset, inverse); - // cooperatively loads all the twiddles this warp needs - base_field *twiddles_this_stage = twiddle_cache; - unsigned num_twiddles_this_stage = VALS_PER_WARP >> 1; - unsigned exchg_region_offset = gmem_offset >> 1; - for (unsigned stage = 0; stage < stages_this_launch; stage++) { + const unsigned bound = std::min(NTTS_PER_BLOCK, num_ntts - NTTS_PER_BLOCK * blockIdx.y); + for (unsigned ntt_idx = 0; ntt_idx < bound; + ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) { #pragma unroll - for (unsigned i = lane_id; i < num_twiddles_this_stage; i += 32) { - twiddles_this_stage[i] = get_twiddle(inverse, i + exchg_region_offset); + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const auto in = memory::load_cs(reinterpret_cast(gmem_in + 64 * i + 2 * lane_id)); + vals[2 * i][0] = in.x; + vals[2 * i][1] = in.y; + vals[2 * i + 1][0] = in.z; + vals[2 * i + 1][1] = in.w; } - twiddles_this_stage += num_twiddles_this_stage; - num_twiddles_this_stage >>= 1; - 
exchg_region_offset >>= 1; - } - __syncwarp(); - - unsigned lane_mask = 1; - twiddles_this_stage = twiddle_cache; - num_twiddles_this_stage = VALS_PER_WARP >> 1; - for (unsigned stage = 0; stage < 6; stage++) { +// if (log_extension_degree && !inverse) { +// if (coset_idx) { +// const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; +// const unsigned offset = coset_idx << shift; +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx = __brev(gmem_offset + 64 * (i >> 1) + 2 * lane_id + (i & 1)) >> (32 - log_n); +// auto power_of_w = get_power_of_w(idx * offset, false); +// vals[i] = base_field::mul(vals[i], power_of_w); +// } +// } +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx = __brev(gmem_offset + 64 * (i >> 1) + 2 * lane_id + (i & 1)) >> (32 - log_n); +// auto power_of_g = get_power_of_g(idx, false); +// vals[i] = base_field::mul(vals[i], power_of_g); +// } +// } + + unsigned lane_mask = 1; + base_field *twiddles_this_stage = twiddle_cache; + unsigned num_twiddles_this_stage = VALS_PER_WARP >> 1; + for (unsigned stage = 0; stage < 6; stage++) { #pragma unroll - for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { - const auto twiddle = twiddles_this_stage[(32 * i + lane_id) >> stage]; - exchg_dif(vals[2 * i], vals[2 * i + 1], twiddle); - if (stage < 5) - shfl_xor_bf(vals, i, lane_id, lane_mask); + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const auto twiddle = twiddles_this_stage[(32 * i + lane_id) >> stage]; + exchg_dif(vals[2 * i], vals[2 * i + 1], twiddle); + if (stage < 5) + shfl_xor_bf(vals, i, lane_id, lane_mask); + } + lane_mask <<= 1; + twiddles_this_stage += num_twiddles_this_stage; + num_twiddles_this_stage >>= 1; } - lane_mask <<= 1; - twiddles_this_stage += num_twiddles_this_stage; - num_twiddles_this_stage >>= 1; - } -#pragma unroll - for (unsigned i = 1, stage = 6; i < LOG_VALS_PER_THREAD; i++, stage++) { - if (stage < stages_this_launch) { + for (unsigned i = 1; i < LOG_VALS_PER_THREAD; i++) { #pragma unroll for (unsigned j = 0; j < PAIRS_PER_THREAD >> i; j++) { const unsigned exchg_tile_sz = 2 << i; @@ -98,659 +86,454 @@ void b2n_initial_stages_warp(const base_field *gmem_inputs_matrix, base_field *g twiddles_this_stage += num_twiddles_this_stage; num_twiddles_this_stage >>= 1; } - } #pragma unroll - for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { - // This output pattern (resulting from the above shfls) is nice, but not obvious. - // To see why it works, sketch the shfl stages on paper. - memory::store_cs(gmem_out + 64 * i + lane_id, vals[2 * i]); - memory::store_cs(gmem_out + 64 * i + lane_id + 32, vals[2 * i + 1]); + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + // This output pattern (resulting from the above shfls) is nice, but not obvious. + // To see why it works, sketch the shfl stages on paper. 
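+      // Concretely: after the shuffles, lane L holds elements L and L + 32 of each
+      // 64-element tile it owns, so both stores below are fully coalesced across the warp.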
+ memory::store_cs(gmem_out + 64 * i + lane_id, vals[2 * i]); + memory::store_cs(gmem_out + 64 * i + lane_id + 32, vals[2 * i + 1]); + } } } extern "C" __global__ -void b2n_initial_up_to_8_stages_warp(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, - const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, - const unsigned coset_idx) { +void b2n_initial_8_stages_warp(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { b2n_initial_stages_warp<3>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, - stages_this_launch, log_n, inverse, blocks_per_ntt, log_extension_degree, coset_idx); + stages_this_launch, log_n, inverse, num_ntts, log_extension_degree, coset_idx); } -template DEVICE_FORCEINLINE +extern "C" __global__ +void b2n_initial_7_stages_warp(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { + b2n_initial_stages_warp<2>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, + stages_this_launch, log_n, inverse, num_ntts, log_extension_degree, coset_idx); +} + +template DEVICE_FORCEINLINE void b2n_initial_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, const unsigned coset_idx) { constexpr unsigned VALS_PER_THREAD = 1 << LOG_VALS_PER_THREAD; constexpr unsigned PAIRS_PER_THREAD = VALS_PER_THREAD >> 1; constexpr unsigned VALS_PER_WARP = 32 * VALS_PER_THREAD; constexpr unsigned WARPS_PER_BLOCK = VALS_PER_WARP >> 4; - constexpr unsigned VALS_PER_BLOCK = 32 * VALS_PER_THREAD * WARPS_PER_BLOCK; + constexpr unsigned LOG_VALS_PER_BLOCK = 2 * (LOG_VALS_PER_THREAD + 5) - 4; + constexpr unsigned VALS_PER_BLOCK = 1 << LOG_VALS_PER_BLOCK; __shared__ base_field smem[VALS_PER_BLOCK]; const unsigned lane_id{threadIdx.x & 31}; const unsigned warp_id{threadIdx.x >> 5}; - const unsigned ntt_idx = 0; // blockIdx.x / blocks_per_ntt; - const unsigned block_idx_in_ntt = blockIdx.x - ntt_idx * blocks_per_ntt; - const unsigned gmem_offset = ntt_idx * stride_between_input_arrays + VALS_PER_BLOCK * block_idx_in_ntt + VALS_PER_WARP * warp_id; - const base_field *gmem_in = gmem_inputs_matrix + gmem_offset; - base_field *gmem_out = gmem_outputs_matrix + gmem_offset; + const unsigned gmem_block_offset = VALS_PER_BLOCK * blockIdx.x; + const unsigned gmem_offset = 
gmem_block_offset + VALS_PER_WARP * warp_id; + const base_field *gmem_in = gmem_inputs_matrix + gmem_offset + + NTTS_PER_BLOCK * stride_between_input_arrays * blockIdx.y; + // annoyingly scrambled, but should be coalesced overall + const unsigned gmem_out_thread_offset = 16 * warp_id + VALS_PER_WARP * (lane_id >> 4) + 2 * (lane_id & 7) + ((lane_id >> 3) & 1); + base_field *gmem_out = gmem_outputs_matrix + gmem_block_offset + gmem_out_thread_offset + + NTTS_PER_BLOCK * stride_between_output_arrays * blockIdx.y; auto twiddle_cache = smem + VALS_PER_WARP * warp_id; base_field vals[VALS_PER_THREAD]; -#pragma unroll - for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { - const auto pair = memory::load_cs(reinterpret_cast(gmem_in + 64 * i + 2 * lane_id)); - vals[2 * i][0] = pair.x; - vals[2 * i][1] = pair.y; - vals[2 * i + 1][0] = pair.z; - vals[2 * i + 1][1] = pair.w; - } + load_initial_twiddles_warp(twiddle_cache, lane_id, gmem_offset, inverse); - // cooperatively loads all the twiddles this warp needs - base_field *twiddles_this_stage = twiddle_cache; - unsigned num_twiddles_this_stage = VALS_PER_WARP >> 1; - unsigned exchg_region_offset = gmem_offset >> 1; - for (unsigned stage = 0; stage < 6 + LOG_VALS_PER_THREAD - 1; stage++) { + const unsigned bound = std::min(NTTS_PER_BLOCK, num_ntts - NTTS_PER_BLOCK * blockIdx.y); + for (unsigned ntt_idx = 0; ntt_idx < bound; + ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) { #pragma unroll - for (unsigned i = lane_id; i < num_twiddles_this_stage; i += 32) { - twiddles_this_stage[i] = get_twiddle(inverse, i + exchg_region_offset); + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const auto pair = memory::load_cs(reinterpret_cast(gmem_in + 64 * i + 2 * lane_id)); + vals[2 * i][0] = pair.x; + vals[2 * i][1] = pair.y; + vals[2 * i + 1][0] = pair.z; + vals[2 * i + 1][1] = pair.w; } - twiddles_this_stage += num_twiddles_this_stage; - num_twiddles_this_stage >>= 1; - exchg_region_offset >>= 1; - } - - __syncwarp(); - unsigned lane_mask = 1; - twiddles_this_stage = twiddle_cache; - num_twiddles_this_stage = VALS_PER_WARP >> 1; - for (unsigned stage = 0; stage < 6; stage++) { +// if (log_extension_degree && !inverse) { +// if (coset_idx) { +// const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; +// const unsigned offset = coset_idx << shift; +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx = __brev(gmem_offset + 64 * (i >> 1) + 2 * lane_id + (i & 1)) >> (32 - log_n); +// auto power_of_w = get_power_of_w(idx * offset, false); +// vals[i] = base_field::mul(vals[i], power_of_w); +// } +// } +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx = __brev(gmem_offset + 64 * (i >> 1) + 2 * lane_id + (i & 1)) >> (32 - log_n); +// auto power_of_g = get_power_of_g(idx, false); +// vals[i] = base_field::mul(vals[i], power_of_g); +// } +// } + + unsigned lane_mask = 1; + base_field *twiddles_this_stage = twiddle_cache; + unsigned num_twiddles_this_stage = VALS_PER_WARP >> 1; + for (unsigned stage = 0; stage < 6; stage++) { #pragma unroll - for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { - const auto twiddle = twiddles_this_stage[(32 * i + lane_id) >> stage]; - exchg_dif(vals[2 * i], vals[2 * i + 1], twiddle); - shfl_xor_bf(vals, i, lane_id, lane_mask); + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const auto twiddle = twiddles_this_stage[(32 * i + lane_id) >> stage]; + exchg_dif(vals[2 * i], vals[2 * i + 1], 
twiddle); + if (stage < 5) + shfl_xor_bf(vals, i, lane_id, lane_mask); + } + lane_mask <<= 1; + twiddles_this_stage += num_twiddles_this_stage; + num_twiddles_this_stage >>= 1; } - lane_mask <<= 1; - twiddles_this_stage += num_twiddles_this_stage; - num_twiddles_this_stage >>= 1; - } + for (unsigned i = 1; i < LOG_VALS_PER_THREAD; i++) { #pragma unroll - for (unsigned i = 1, stage = 6; i < LOG_VALS_PER_THREAD; i++, stage++) { -#pragma unroll - for (unsigned j = 0; j < PAIRS_PER_THREAD >> i; j++) { - const unsigned exchg_tile_sz = 2 << i; - const unsigned half_exchg_tile_sz = 1 << i; - const auto twiddle = twiddles_this_stage[j]; + for (unsigned j = 0; j < PAIRS_PER_THREAD >> i; j++) { + const unsigned exchg_tile_sz = 2 << i; + const unsigned half_exchg_tile_sz = 1 << i; + const auto twiddle = twiddles_this_stage[j]; #pragma unroll - for (unsigned k = 0; k < half_exchg_tile_sz; k++) - exchg_dif(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + for (unsigned k = 0; k < half_exchg_tile_sz; k++) + exchg_dif(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + } + twiddles_this_stage += num_twiddles_this_stage; + num_twiddles_this_stage >>= 1; } - twiddles_this_stage += num_twiddles_this_stage; - num_twiddles_this_stage >>= 1; - } - __syncwarp(); + __syncwarp(); + if (ntt_idx < num_ntts - 1) { #pragma unroll - for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { - // The output pattern (resulting from the above shfls) is nice, but not obvious. - // To see why it works, sketch the shfl stages on paper. - // TODO: Stash twiddles in registers while using smem to communicate data values - //if (ntt_idx != num_ntts - 1) { - twiddle_cache[64 * i + lane_id] = vals[2 * i]; - twiddle_cache[64 * i + lane_id + 32] = vals[2 * i + 1]; - } + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + // juggle twiddles in registers while we use smem to communicate values + const auto tmp0 = twiddle_cache[64 * i + lane_id]; + const auto tmp1 = twiddle_cache[64 * i + lane_id + 32]; + twiddle_cache[64 * i + lane_id] = vals[2 * i]; + twiddle_cache[64 * i + lane_id + 32] = vals[2 * i + 1]; + vals[2 * i] = tmp0; + vals[2 * i + 1] = tmp1; + } + } else { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + twiddle_cache[64 * i + lane_id] = vals[2 * i]; + twiddle_cache[64 * i + lane_id + 32] = vals[2 * i + 1]; + } + } - __syncthreads(); + __syncthreads(); - auto pair_addr = smem + 16 * warp_id + VALS_PER_WARP * (lane_id >> 3) + 2 * (threadIdx.x & 7); + auto pair_addr = smem + 16 * warp_id + VALS_PER_WARP * (lane_id >> 3) + 2 * (threadIdx.x & 7); + if (ntt_idx < num_ntts - 1) { + // juggle twiddles back into smem + // In theory, we could avoid the full-size stashing and extra syncthreads by + // "switching" each warp's twiddle region from contiguous to strided-chunks each iteration, + // but that's a lot of trouble. Let's try the simple approach first. 
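+      // At this point vals[] holds this warp's twiddles (swapped into registers above),
+      // while smem holds the data. Stash the twiddles in tmp, pull the transposed data
+      // back out of smem, then restore the twiddles so the cache survives into the next
+      // ntt_idx iteration.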
+ base_field tmp[VALS_PER_THREAD]; #pragma unroll - for (unsigned i = 0; i < PAIRS_PER_THREAD; i++, pair_addr += 4 * VALS_PER_WARP) { - // TODO: Juggle twiddles here as needed - const auto pair = *reinterpret_cast(pair_addr); - vals[2 * i][0] = pair.x; - vals[2 * i][1] = pair.y; - vals[2 * i + 1][0] = pair.z; - vals[2 * i + 1][1] = pair.w; - } + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + tmp[2 * i] = vals[2 * i]; + tmp[2 * i + 1] = vals[2 * i + 1]; + } - // if (ntt_idx != num_ntts - 1) - // __syncthreads(); +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++, pair_addr += 4 * VALS_PER_WARP) { + const auto pair = *reinterpret_cast(pair_addr); + vals[2 * i][0] = pair.x; + vals[2 * i][1] = pair.y; + vals[2 * i + 1][0] = pair.z; + vals[2 * i + 1][1] = pair.w; + } + + __syncthreads(); - lane_mask = 8; - exchg_region_offset = ((blockIdx.x * WARPS_PER_BLOCK) >> 1) + (lane_id & 16); - unsigned first_interwarp_stage = 6 + LOG_VALS_PER_THREAD - 1; - for (unsigned s = 0; s < first_interwarp_stage + LOG_INTERWARP_STAGES; s++) { #pragma unroll - for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { - // TODO: Handle these cooperatively - const auto twiddle = get_twiddle(inverse, (exchg_region_offset + 2 * i) >> s); - shfl_xor_bf(vals, i, lane_id, lane_mask); - exchg_dif(vals[2 * i], vals[2 * i + 1], twiddle); + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + twiddle_cache[64 * i + lane_id] = tmp[2 * i]; + twiddle_cache[64 * i + lane_id + 32] = tmp[2 * i + 1]; + } + + __syncwarp(); // maybe unnecessary due to shfls below + // __syncthreads(); + } else { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++, pair_addr += 4 * VALS_PER_WARP) { + const auto pair = *reinterpret_cast(pair_addr); + vals[2 * i][0] = pair.x; + vals[2 * i][1] = pair.y; + vals[2 * i + 1][0] = pair.z; + vals[2 * i + 1][1] = pair.w; + } } - lane_mask <<= 1; - } + const unsigned stages_so_far = 6 + LOG_VALS_PER_THREAD - 1; + lane_mask = 8; + unsigned exchg_region_offset = blockIdx.x * (WARPS_PER_BLOCK >> 1) + (lane_id >> 4); + for (unsigned s = 0; s < 2; s++) { + if (s + stages_so_far < stages_this_launch) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + // TODO: Handle these cooperatively? 
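+          // The post-transpose stages compute twiddles on the fly with get_twiddle (one per
+          // value pair) rather than reading the shared cache; the xor-shuffle pairs lanes
+          // whose IDs differ by lane_mask (8 for the first of these stages, then 16) before
+          // each butterfly.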
+ const auto twiddle = get_twiddle(inverse, exchg_region_offset + ((2 * i) >> s)); + shfl_xor_bf(vals, i, lane_id, lane_mask); + exchg_dif(vals[2 * i], vals[2 * i + 1], twiddle); + } + } else { #pragma unroll - for (unsigned i = 1, stage = 6; i < LOG_VALS_PER_THREAD; i++, stage++) { + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) + shfl_xor_bf(vals, i, lane_id, lane_mask); + } + lane_mask <<= 1; + exchg_region_offset >>= 1; + } + + exchg_region_offset = blockIdx.x * (PAIRS_PER_THREAD >> 1); + for (unsigned i = 1; i < LOG_VALS_PER_THREAD; i++) { + if (i + 2 + stages_so_far <= stages_this_launch) { #pragma unroll - for (unsigned j = 0; j < PAIRS_PER_THREAD >> i; j++) { - const unsigned exchg_tile_sz = 2 << i; - const unsigned half_exchg_tile_sz = 1 << i; - const auto twiddle = twiddles_this_stage[j]; + for (unsigned j = 0; j < PAIRS_PER_THREAD >> i; j++) { + const unsigned exchg_tile_sz = 2 << i; + const unsigned half_exchg_tile_sz = 1 << i; + const auto twiddle = get_twiddle(inverse, exchg_region_offset + (j >> (i - 1))); #pragma unroll - for (unsigned k = 0; k < half_exchg_tile_sz; k++) - exchg_dif(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + for (unsigned k = 0; k < half_exchg_tile_sz; k++) + exchg_dif(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + } + } + exchg_region_offset >>= 1; } - } #pragma unroll - for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { - // This output pattern (resulting from the above shfls) is nice, but not obvious. - // To see why it works, sketch the shfl stages on paper. - memory::store_cs(gmem_out + 64 * i + lane_id, vals[2 * i]); - memory::store_cs(gmem_out + 64 * i + lane_id + 32, vals[2 * i + 1]); + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + memory::store_cs(gmem_out + 4 * i * VALS_PER_WARP, vals[2 * i]); + memory::store_cs(gmem_out + (4 * i + 2) * VALS_PER_WARP, vals[2 * i + 1]); + } } } +// extern "C" __launch_bounds__(512, 2) __global__ extern "C" __global__ -void b2n_initial_up_to_12_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, +void b2n_initial_9_to_12_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, const unsigned coset_idx) { - b2n_initial_stages_block<3, 4>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, - stages_this_launch, log_n, inverse, blocks_per_ntt, log_extension_degree, coset_idx); -} - -extern "C" __global__ -void b2n_noninitial_up_to_8_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, - const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, - const unsigned coset_idx) { - // b2n_noninitial_stages_block<3>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, - // stages_this_launch, log_n, inverse, 
blocks_per_ntt, log_extension_degree, coset_idx); + b2n_initial_stages_block<3>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, + stages_this_launch, log_n, inverse, num_ntts, log_extension_degree, coset_idx); } -// I bet there are ways to write these macros more concisely, -// but the structure is fairly readable and easy to edit. -#define THREE_REGISTER_STAGES_B2N(SKIP_FIRST) \ - { \ - if (!(SKIP_FIRST)) { \ - /* first stage of this set-of-3 stages */ \ - const auto t3 = get_twiddle(inverse, thread_exchg_region); \ - const auto t4 = get_twiddle(inverse, thread_exchg_region + 1); \ - const auto t5 = get_twiddle(inverse, thread_exchg_region + 2); \ - const auto t6 = get_twiddle(inverse, thread_exchg_region + 3); \ - exchg_dif(reg_vals[0], reg_vals[1], t3); \ - exchg_dif(reg_vals[2], reg_vals[3], t4); \ - exchg_dif(reg_vals[4], reg_vals[5], t5); \ - exchg_dif(reg_vals[6], reg_vals[7], t6); \ - } \ - /* second stage of this set-of-3 stages */ \ - thread_exchg_region >>= 1; \ - const auto t1 = get_twiddle(inverse, thread_exchg_region); \ - const auto t2 = get_twiddle(inverse, thread_exchg_region + 1); \ - exchg_dif(reg_vals[0], reg_vals[2], t1); \ - exchg_dif(reg_vals[1], reg_vals[3], t1); \ - exchg_dif(reg_vals[4], reg_vals[6], t2); \ - exchg_dif(reg_vals[5], reg_vals[7], t2); \ - /* third stage of this set-of-3 stages */ \ - thread_exchg_region >>= 1; \ - const auto t0 = get_twiddle(inverse, thread_exchg_region); \ - exchg_dif(reg_vals[0], reg_vals[4], t0); \ - exchg_dif(reg_vals[1], reg_vals[5], t0); \ - exchg_dif(reg_vals[2], reg_vals[6], t0); \ - exchg_dif(reg_vals[3], reg_vals[7], t0); \ - } +template DEVICE_FORCEINLINE +void b2n_noninitial_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const bool skip_first_stage, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { + constexpr unsigned VALS_PER_THREAD = 1 << LOG_VALS_PER_THREAD; + constexpr unsigned PAIRS_PER_THREAD = VALS_PER_THREAD >> 1; + constexpr unsigned VALS_PER_WARP = 32 * VALS_PER_THREAD; + constexpr unsigned TILES_PER_WARP = VALS_PER_WARP >> 4; + constexpr unsigned WARPS_PER_BLOCK = VALS_PER_WARP >> 4; + constexpr unsigned LOG_VALS_PER_BLOCK = 2 * (LOG_VALS_PER_THREAD + 5) - 4; + constexpr unsigned VALS_PER_BLOCK = 1 << LOG_VALS_PER_BLOCK; + constexpr unsigned TILES_PER_BLOCK = VALS_PER_BLOCK >> 4; + constexpr unsigned EXCHG_REGIONS_PER_BLOCK = TILES_PER_BLOCK >> 1; + constexpr unsigned MAX_STAGES_THIS_LAUNCH = LOG_VALS_PER_BLOCK - 4; -#define TWO_REGISTER_STAGES_B2N(SKIP_SECOND) \ - { \ - unsigned tmp_thread_exchg_region = thread_exchg_region; \ - /* first stage of this set-of-2 stages */ \ - const auto t1 = get_twiddle(inverse, tmp_thread_exchg_region); \ - const auto t2 = get_twiddle(inverse, tmp_thread_exchg_region + 1); \ - exchg_dif(reg_vals[0], reg_vals[1], t1); \ - exchg_dif(reg_vals[2], reg_vals[3], t2); \ - if (!(SKIP_SECOND)) { \ - /* second stage of this set-of-2 stages */ \ - tmp_thread_exchg_region >>= 1; \ - const auto t0 = get_twiddle(inverse, tmp_thread_exchg_region); \ - exchg_dif(reg_vals[0], reg_vals[2], t0); \ - exchg_dif(reg_vals[1], reg_vals[3], t0); \ - } \ - } + __shared__ base_field smem[VALS_PER_BLOCK]; -#define ONE_EXTRA_STAGE_B2N \ - { \ - const unsigned intrablock_exchg_region = (warp_id >> 1); \ 
- const unsigned smem_logical_offset = (warp_id & 1) * 32 + lane_id; \ - const unsigned offset_padded = intrablock_exchg_region * 2 * PADDED_WARP_SCRATCH_SIZE; \ - for (int i = 0; i < 4; i++) { \ - const unsigned idx = offset_padded + PAD(smem_logical_offset + i * 64); \ - reg_vals[i] = smem[idx]; \ - reg_vals[i + 4] = smem[idx + PADDED_WARP_SCRATCH_SIZE]; \ - } \ - const auto t0 = get_twiddle(inverse, 8 * block_idx_in_ntt + intrablock_exchg_region); \ - exchg_dif(reg_vals[0], reg_vals[4], t0); \ - exchg_dif(reg_vals[1], reg_vals[5], t0); \ - exchg_dif(reg_vals[2], reg_vals[6], t0); \ - exchg_dif(reg_vals[3], reg_vals[7], t0); \ - const unsigned offset = intrablock_exchg_region * 512 + smem_logical_offset; \ - for (int i = 0; i < 4; i++) { \ - memory::store_cs(gmem_output + offset + i * 64, reg_vals[i]); \ - memory::store_cs(gmem_output + offset + i * 64 + 256, reg_vals[i + 4]); \ - } \ - return; \ - } + const unsigned lane_id{threadIdx.x & 31}; + const unsigned warp_id{threadIdx.x >> 5}; + const unsigned log_tile_stride = skip_first_stage ? start_stage - 1 : start_stage; + const unsigned tile_stride = 1 << log_tile_stride; + const unsigned log_blocks_per_region = log_tile_stride - 4; // tile size is always 16 + const unsigned block_bfly_region_size = TILES_PER_BLOCK * tile_stride; + const unsigned block_bfly_region = blockIdx.x >> log_blocks_per_region; + const unsigned block_exchg_region_offset = block_bfly_region * EXCHG_REGIONS_PER_BLOCK; + const unsigned block_bfly_region_start = block_bfly_region * block_bfly_region_size; + const unsigned block_start_in_bfly_region = 16 * (blockIdx.x & ((1 << log_blocks_per_region) - 1)); + const base_field *gmem_in = gmem_inputs_matrix + block_bfly_region_start + block_start_in_bfly_region + + NTTS_PER_BLOCK * stride_between_input_arrays * blockIdx.y; + // annoyingly scrambled, but should be coalesced overall + const unsigned gmem_out_thread_offset = tile_stride * warp_id + tile_stride * WARPS_PER_BLOCK * (lane_id >> 4) + + 2 * (lane_id & 7) + ((lane_id >> 3) & 1); + const unsigned gmem_out_offset = block_bfly_region_start + block_start_in_bfly_region + gmem_out_thread_offset; + base_field *gmem_out = gmem_outputs_matrix + gmem_out_offset + NTTS_PER_BLOCK * stride_between_output_arrays * blockIdx.y; -#define TWO_EXTRA_STAGES_B2N \ - { \ - const unsigned intrablock_exchg_region = (warp_id >> 2); \ - const unsigned smem_logical_offset = (warp_id & 3) * 32 + lane_id; \ - const unsigned offset_padded = intrablock_exchg_region * 4 * PADDED_WARP_SCRATCH_SIZE; \ - for (int i = 0; i < 2; i++) { \ - for (int j = 0; j < 2; j++) { \ - const unsigned idx = offset_padded + PAD(smem_logical_offset + j * 128) + i * PADDED_WARP_SCRATCH_SIZE; \ - reg_vals[2 * i + j] = smem[idx]; \ - reg_vals[2 * i + j + 4] = smem[idx + 2 * PADDED_WARP_SCRATCH_SIZE]; \ - } \ - } \ - unsigned global_exchg_region = 8 * block_idx_in_ntt + intrablock_exchg_region * 2; \ - const auto t1 = get_twiddle(inverse, global_exchg_region); \ - const auto t2 = get_twiddle(inverse, global_exchg_region + 1); \ - global_exchg_region >>= 1; \ - const auto t0 = get_twiddle(inverse, global_exchg_region); \ - exchg_dif(reg_vals[0], reg_vals[2], t1); \ - exchg_dif(reg_vals[1], reg_vals[3], t1); \ - exchg_dif(reg_vals[4], reg_vals[6], t2); \ - exchg_dif(reg_vals[5], reg_vals[7], t2); \ - exchg_dif(reg_vals[0], reg_vals[4], t0); \ - exchg_dif(reg_vals[1], reg_vals[5], t0); \ - exchg_dif(reg_vals[2], reg_vals[6], t0); \ - exchg_dif(reg_vals[3], reg_vals[7], t0); \ - const unsigned offset = 
intrablock_exchg_region * 1024 + smem_logical_offset; \ - for (int i = 0; i < 4; i++) { \ - memory::store_cs(gmem_output + offset + i * 128, reg_vals[i]); \ - memory::store_cs(gmem_output + offset + i * 128 + 512, reg_vals[i + 4]); \ - } \ - return; \ - } + auto twiddle_cache = smem + VALS_PER_WARP * warp_id; -#define THREE_OR_FOUR_EXTRA_STAGES_B2N(THREE_STAGES) \ - { \ - const unsigned intrablock_exchg_region = (warp_id >> 3); \ - const unsigned smem_logical_offset = (warp_id & 7) * 32 + lane_id; \ - const unsigned offset_padded = intrablock_exchg_region * 8 * PADDED_WARP_SCRATCH_SIZE + PAD(smem_logical_offset); \ - for (int i = 0; i < 4; i++) { \ - const unsigned idx = offset_padded + i * PADDED_WARP_SCRATCH_SIZE; \ - reg_vals[i] = smem[idx]; \ - reg_vals[i + 4] = smem[idx + 4 * PADDED_WARP_SCRATCH_SIZE]; \ - } \ - unsigned global_exchg_region = 8 * block_idx_in_ntt + intrablock_exchg_region * 4; \ - const auto t3 = get_twiddle(inverse, global_exchg_region); \ - const auto t4 = get_twiddle(inverse, global_exchg_region + 1); \ - const auto t5 = get_twiddle(inverse, global_exchg_region + 2); \ - const auto t6 = get_twiddle(inverse, global_exchg_region + 3); \ - global_exchg_region >>= 1; \ - const auto t1 = get_twiddle(inverse, global_exchg_region); \ - const auto t2 = get_twiddle(inverse, global_exchg_region + 1); \ - global_exchg_region >>= 1; \ - const auto t0 = get_twiddle(inverse, global_exchg_region); \ - exchg_dif(reg_vals[0], reg_vals[1], t3); \ - exchg_dif(reg_vals[2], reg_vals[3], t4); \ - exchg_dif(reg_vals[4], reg_vals[5], t5); \ - exchg_dif(reg_vals[6], reg_vals[7], t6); \ - exchg_dif(reg_vals[0], reg_vals[2], t1); \ - exchg_dif(reg_vals[1], reg_vals[3], t1); \ - exchg_dif(reg_vals[4], reg_vals[6], t2); \ - exchg_dif(reg_vals[5], reg_vals[7], t2); \ - exchg_dif(reg_vals[0], reg_vals[4], t0); \ - exchg_dif(reg_vals[1], reg_vals[5], t0); \ - exchg_dif(reg_vals[2], reg_vals[6], t0); \ - exchg_dif(reg_vals[3], reg_vals[7], t0); \ - if ((THREE_STAGES)) { \ - const unsigned offset = intrablock_exchg_region * 2048 + smem_logical_offset; \ - for (int i = 0; i < 4; i++) { \ - memory::store_cs(gmem_output + offset + i * 256, reg_vals[i]); \ - memory::store_cs(gmem_output + offset + i * 256 + 1024, reg_vals[i + 4]); \ - } \ - } else { \ - for (int i = 0; i < 4; i++) { \ - const unsigned idx = offset_padded + i * PADDED_WARP_SCRATCH_SIZE; \ - smem[idx] = reg_vals[i]; \ - smem[idx + 4 * PADDED_WARP_SCRATCH_SIZE] = reg_vals[i + 4]; \ - } \ - /* in theory, we could avoid full __syncthreads by splitting each warp into two half-warps of size 16, */ \ - /* assigning first-halves to first 2048 elems and second-halves to second 2048 elems, then */ \ - /* combining results from first and second halves with intrawarp syncs, but that doesn't seem worth the trouble */ \ - __syncthreads(); \ - const auto t0 = get_twiddle(inverse, block_idx_in_ntt); \ - int i = threadIdx.x; \ - int i_padded = (threadIdx.x >> 8) * PADDED_WARP_SCRATCH_SIZE + PAD(threadIdx.x & 255); \ - for (; i < 2048; i += 512, i_padded += 2 * PADDED_WARP_SCRATCH_SIZE) { \ - reg_vals[0] = smem[i_padded]; \ - reg_vals[1] = smem[i_padded + 8 * PADDED_WARP_SCRATCH_SIZE]; \ - exchg_dif(reg_vals[0], reg_vals[1], t0); \ - memory::store_cs(gmem_output + i, reg_vals[0]); \ - memory::store_cs(gmem_output + i + 2048, reg_vals[1]); \ - } \ - } \ - return; \ - } + base_field vals[VALS_PER_THREAD]; -extern "C" __launch_bounds__(512, 2) __global__ - void b2n_initial_7_or_8_stages(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, 
const unsigned stride_between_input_arrays, - const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, - const unsigned coset_idx) { - extern __shared__ base_field smem[]; // 4096 elems + load_noninitial_twiddles_warp(twiddle_cache, lane_id, warp_id, block_exchg_region_offset, inverse); - const unsigned tile_stride{16}; - const unsigned lane_in_tile = threadIdx.x & 15; - const unsigned lane_id{threadIdx.x & 31}; - const unsigned warp_id{threadIdx.x >> 5}; - const unsigned ntt_idx = blockIdx.x / blocks_per_ntt; - const unsigned block_idx_in_ntt = blockIdx.x - ntt_idx * blocks_per_ntt; - const base_field *gmem_input = gmem_inputs_matrix + ntt_idx * stride_between_input_arrays + 4096 * block_idx_in_ntt; - base_field *gmem_output = gmem_outputs_matrix + ntt_idx * stride_between_output_arrays + 4096 * block_idx_in_ntt; - - { - // maybe some memcpy_asyncs could micro-optimize this - // maybe an arrive-wait barrier could micro-optimize the start_stage > 0 case - base_field reg_vals[8]; -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned g = tile * tile_stride + lane_in_tile; - reg_vals[i] = memory::load_cs(gmem_input + g); + const unsigned bound = std::min(NTTS_PER_BLOCK, num_ntts - NTTS_PER_BLOCK * blockIdx.y); + for (unsigned ntt_idx = 0; ntt_idx < bound; + ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) { + if (skip_first_stage) { + auto val0_addr = gmem_in + TILES_PER_WARP * tile_stride * warp_id + 2 * tile_stride * (lane_id >> 4) + 2 * (threadIdx.x & 7) + (lane_id >> 3 & 1); +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + vals[2 * i] = memory::load_cs(val0_addr); + vals[2 * i + 1] = memory::load_cs(val0_addr + tile_stride); + val0_addr += 4 * tile_stride; + } + } else { + auto pair_addr = gmem_in + TILES_PER_WARP * tile_stride * warp_id + tile_stride * (lane_id >> 3) + 2 * (threadIdx.x & 7); +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const auto pair = memory::load_cs(reinterpret_cast(pair_addr)); + vals[2 * i][0] = pair.x; + vals[2 * i][1] = pair.y; + vals[2 * i + 1][0] = pair.z; + vals[2 * i + 1][1] = pair.w; + pair_addr += 4 * tile_stride; + } } - if (log_extension_degree && !inverse) { - const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; - const unsigned offset = coset_idx << shift; -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned idx = __brev(tile * tile_stride + lane_in_tile + 4096 * block_idx_in_ntt) >> (32 - log_n); - if (coset_idx) { - auto power_of_w = get_power_of_w(idx * offset, false); - reg_vals[i] = base_field::mul(reg_vals[i], power_of_w); + + unsigned lane_mask = 8; + base_field *twiddles_this_stage = twiddle_cache; + unsigned num_twiddles_this_stage = 1 << LOG_VALS_PER_THREAD; + for (unsigned s = 4; s < LOG_VALS_PER_THREAD + 3; s++) { + if (!skip_first_stage || s > 4) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const auto twiddle = twiddles_this_stage[(32 * i + lane_id) >> s]; + shfl_xor_bf(vals, i, lane_id, lane_mask); + exchg_dif(vals[2 * i], vals[2 * i + 1], twiddle); } - auto power_of_g = get_power_of_g(idx, false); - reg_vals[i] = base_field::mul(reg_vals[i], power_of_g); } + lane_mask <<= 
1; + twiddles_this_stage += num_twiddles_this_stage; + num_twiddles_this_stage >>= 1; } -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - // puts each warp's data in its assigned smem region - const unsigned s = (t >> 8) * PADDED_WARP_SCRATCH_SIZE + PAD(t & 255); - smem[s] = reg_vals[i]; - } - } - __syncwarp(); - - unsigned warp_exchg_region = block_idx_in_ntt * 2048 + warp_id * 128; - base_field reg_vals[8]; - base_field *warp_scratch = smem + PADDED_WARP_SCRATCH_SIZE * warp_id; - - unsigned thread_exchg_region = warp_exchg_region + lane_id * 4; - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(lane_id * 8 + i)]; - THREE_REGISTER_STAGES_B2N(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(lane_id * 8 + i)] = reg_vals[i]; - warp_exchg_region >>= 3; - - __syncwarp(); - - thread_exchg_region = warp_exchg_region + (lane_id >> 3) * 4; - const unsigned vals_start = 64 * (lane_id >> 3) + (lane_id & 7); - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(vals_start + 8 * i)]; - THREE_REGISTER_STAGES_B2N(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(vals_start + 8 * i)] = reg_vals[i]; - warp_exchg_region >>= 3; - - __syncwarp(); - - thread_exchg_region = warp_exchg_region; - for (int j = 0; j < 64; j += 32) { - for (int i = 0; i < 4; i++) - reg_vals[i] = warp_scratch[PAD(lane_id + 64 * i + j)]; - TWO_REGISTER_STAGES_B2N((stages_this_launch == 7)) - for (int i = 0; i < 4; i++) - warp_scratch[PAD(lane_id + 64 * i + j)] = reg_vals[i]; - } + for (unsigned i = 1; i < LOG_VALS_PER_THREAD; i++) { +#pragma unroll + for (unsigned j = 0; j < PAIRS_PER_THREAD >> i; j++) { + const unsigned exchg_tile_sz = 2 << i; + const unsigned half_exchg_tile_sz = 1 << i; + const auto twiddle = twiddles_this_stage[j]; +#pragma unroll + for (unsigned k = 0; k < half_exchg_tile_sz; k++) + exchg_dif(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + } + twiddles_this_stage += num_twiddles_this_stage; + num_twiddles_this_stage >>= 1; + } - __syncwarp(); + __syncwarp(); -// unroll 2 and unroll 4 give comparable perf. Stores don't incur a stall so it's not as important to ILP them. 
-#pragma unroll 1 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned s = (t >> 8) * PADDED_WARP_SCRATCH_SIZE + PAD(t & 255); - const auto val = smem[s]; - const unsigned g = tile * tile_stride + lane_in_tile; - memory::store_cs(gmem_output + g, val); - } -} + // there are at most 31 per-warp twiddles, so we only need 1 temporary per thread to stash them + base_field tmp{}; + if (ntt_idx < num_ntts - 1) + tmp = twiddle_cache[lane_id]; -extern "C" __launch_bounds__(512, 2) __global__ - void b2n_initial_9_to_12_stages(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, - const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, - const unsigned coset_idx) { - extern __shared__ base_field smem[]; // 4096 elems - - const unsigned tile_stride{16}; - const unsigned lane_in_tile = threadIdx.x & 15; - const unsigned lane_id{threadIdx.x & 31}; - const unsigned warp_id{threadIdx.x >> 5}; - const unsigned ntt_idx = blockIdx.x / blocks_per_ntt; - const unsigned block_idx_in_ntt = blockIdx.x - ntt_idx * blocks_per_ntt; - const base_field *gmem_input = gmem_inputs_matrix + ntt_idx * stride_between_input_arrays + 4096 * block_idx_in_ntt; - base_field *gmem_output = gmem_outputs_matrix + ntt_idx * stride_between_output_arrays + 4096 * block_idx_in_ntt; - - { - // maybe some memcpy_asyncs could further micro-optimize this - // maybe an arrive-wait barrier could further micro-optimize the start_stage > 0 case - base_field reg_vals[8]; -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned g = tile * tile_stride + lane_in_tile; - reg_vals[i] = memory::load_cs(gmem_input + g); - } - if (log_extension_degree && !inverse) { - const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; - const unsigned offset = coset_idx << shift; -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned idx = __brev(tile * tile_stride + lane_in_tile + 4096 * block_idx_in_ntt) >> (32 - log_n); - if (coset_idx) { - auto power_of_w = get_power_of_w(idx * offset, false); - reg_vals[i] = base_field::mul(reg_vals[i], power_of_w); - } - auto power_of_g = get_power_of_g(idx, false); - reg_vals[i] = base_field::mul(reg_vals[i], power_of_g); - } - } -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - // puts each warp's data in its assigned smem region - const unsigned s = (t >> 8) * PADDED_WARP_SCRATCH_SIZE + PAD(t & 255); - smem[s] = reg_vals[i]; + // annoyingly scrambled but should be bank-conflict-free + const unsigned smem_thread_offset = 16 * (lane_id >> 4) + 2 * (lane_id & 7) + ((lane_id >> 3) & 1); +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + twiddle_cache[64 * i + smem_thread_offset] = vals[2 * i]; + twiddle_cache[64 * i + smem_thread_offset + 32] = vals[2 * i + 1]; } - } - __syncwarp(); - - unsigned warp_exchg_region = block_idx_in_ntt * 2048 + warp_id * 128; - base_field reg_vals[8]; - base_field *warp_scratch = smem + PADDED_WARP_SCRATCH_SIZE * warp_id; - - unsigned thread_exchg_region = warp_exchg_region + lane_id * 4; - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(lane_id 
* 8 + i)]; - THREE_REGISTER_STAGES_B2N(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(lane_id * 8 + i)] = reg_vals[i]; - warp_exchg_region >>= 3; - - __syncwarp(); - - thread_exchg_region = warp_exchg_region + (lane_id >> 3) * 4; - const unsigned vals_start = 64 * (lane_id >> 3) + (lane_id & 7); - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(vals_start + 8 * i)]; - THREE_REGISTER_STAGES_B2N(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(vals_start + 8 * i)] = reg_vals[i]; - warp_exchg_region >>= 3; - - __syncwarp(); - - thread_exchg_region = warp_exchg_region; - for (int j = 0; j < 64; j += 32) { - for (int i = 0; i < 4; i++) - reg_vals[i] = warp_scratch[PAD(lane_id + 64 * i + j)]; - TWO_REGISTER_STAGES_B2N(false) - for (int i = 0; i < 4; i++) - warp_scratch[PAD(lane_id + 64 * i + j)] = reg_vals[i]; - } + __syncthreads(); - // This start_stage == 0 kernel can handle up to 11 stages if needed by the overall NTT. - __syncthreads(); - const unsigned stages_remaining = stages_this_launch - 8; - switch (stages_remaining) { - case 1: - ONE_EXTRA_STAGE_B2N - case 2: - TWO_EXTRA_STAGES_B2N - case 3: - THREE_OR_FOUR_EXTRA_STAGES_B2N(true) - case 4: - THREE_OR_FOUR_EXTRA_STAGES_B2N(false) - } -} + auto smem_pair_addr = smem + 16 * warp_id + VALS_PER_WARP * (lane_id >> 3) + 2 * (threadIdx.x & 7); +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++, smem_pair_addr += 4 * VALS_PER_WARP) { + const auto pair = *reinterpret_cast(smem_pair_addr); + vals[2 * i][0] = pair.x; + vals[2 * i][1] = pair.y; + vals[2 * i + 1][0] = pair.z; + vals[2 * i + 1][1] = pair.w; + } -extern "C" __launch_bounds__(512, 2) __global__ - void b2n_noninitial_7_or_8_stages(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, - const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, - const unsigned coset_idx) { - extern __shared__ base_field smem[]; // 4096 elems + if (ntt_idx < num_ntts - 1) { + __syncthreads(); + twiddle_cache[lane_id] = tmp; + __syncwarp(); // maybe unnecessary due to shfls below + // __syncthreads(); + } - // If we're only doing 7 stages, skip one stage by loading tiles with half the stride, - // such that the first exchange (between nearest-neighbor tiles) has already happened - const unsigned log_stride{(stages_this_launch == 7) ? 
start_stage - 1 : start_stage}; - const unsigned tile_stride{1u << log_stride}; - const unsigned lane_in_tile{threadIdx.x & 15}; - const unsigned lane_id{threadIdx.x & 31}; - const unsigned warp_id{threadIdx.x >> 5}; - const unsigned exchg_region_sz{tile_stride << 8}; - const unsigned log_blocks_per_region{log_stride - 4}; // tile_stride / 16 - const unsigned ntt_idx{blockIdx.x / blocks_per_ntt}; - const unsigned block_idx_in_ntt{blockIdx.x - ntt_idx * blocks_per_ntt}; - unsigned block_exchg_region{block_idx_in_ntt >> log_blocks_per_region}; - const unsigned block_exchg_region_start{block_exchg_region * exchg_region_sz}; - const unsigned block_start_in_exchg_region{16 * (block_idx_in_ntt & ((1 << log_blocks_per_region) - 1))}; - const base_field *gmem_input = gmem_inputs_matrix + ntt_idx * stride_between_input_arrays + block_exchg_region_start + block_start_in_exchg_region; - base_field *gmem_output = gmem_outputs_matrix + ntt_idx * stride_between_output_arrays + block_exchg_region_start + block_start_in_exchg_region; - - { - // maybe some memcpy_asyncs could further micro-optimize this - // maybe an arrive-wait barrier could further micro-optimize the start_stage > 0 case - base_field reg_vals[8]; -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned g = tile * tile_stride + lane_in_tile; - reg_vals[i] = memory::load_cs(gmem_input + g); + + lane_mask = 8; + unsigned exchg_region_offset = (block_exchg_region_offset >> (LOG_VALS_PER_THREAD + 1)) + (lane_id >> 4); + for (unsigned s = 0; s < 2; s++) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + // TODO: Handle these cooperatively? + const auto twiddle = get_twiddle(inverse, exchg_region_offset + ((2 * i) >> s)); + shfl_xor_bf(vals, i, lane_id, lane_mask); + exchg_dif(vals[2 * i], vals[2 * i + 1], twiddle); + } + lane_mask <<= 1; + exchg_region_offset >>= 1; } -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - // puts each warp's data in its assigned smem region - const unsigned tile = t >> 4; - const unsigned s = lane_in_tile * PADDED_WARP_SCRATCH_SIZE + PAD(tile); - smem[s] = reg_vals[i]; + + for (unsigned i = 1; i < LOG_VALS_PER_THREAD; i++) { +#pragma unroll + for (unsigned j = 0; j < PAIRS_PER_THREAD >> i; j++) { + const unsigned exchg_tile_sz = 2 << i; + const unsigned half_exchg_tile_sz = 1 << i; + const auto twiddle = get_twiddle(inverse, exchg_region_offset + (j >> (i - 1))); +#pragma unroll + for (unsigned k = 0; k < half_exchg_tile_sz; k++) + exchg_dif(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + } + exchg_region_offset >>= 1; } - } - __syncthreads(); - - base_field reg_vals[8]; - base_field *warp_scratch = smem + PADDED_WARP_SCRATCH_SIZE * warp_id; - - block_exchg_region *= 128; - unsigned thread_exchg_region = block_exchg_region + lane_id * 4; - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(lane_id * 8 + i)]; - THREE_REGISTER_STAGES_B2N((stages_this_launch == 7)) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(lane_id * 8 + i)] = reg_vals[i]; - block_exchg_region >>= 3; - - __syncwarp(); - - thread_exchg_region = block_exchg_region + (lane_id >> 3) * 4; - const unsigned vals_start = 64 * (lane_id >> 3) + (lane_id & 7); - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(vals_start + 8 * i)]; - THREE_REGISTER_STAGES_B2N(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(vals_start + 8 * i)] = 
reg_vals[i]; - block_exchg_region >>= 3; - - __syncwarp(); - - thread_exchg_region = block_exchg_region; - for (int j = 0; j < 64; j += 32) { - for (int i = 0; i < 4; i++) - reg_vals[i] = warp_scratch[PAD(lane_id + 64 * i + j)]; - TWO_REGISTER_STAGES_B2N(false) - for (int i = 0; i < 4; i++) - warp_scratch[PAD(lane_id + 64 * i + j)] = reg_vals[i]; - } + if (inverse && (start_stage + MAX_STAGES_THIS_LAUNCH - skip_first_stage == log_n)) { +#pragma unroll + for (unsigned i = 0; i < VALS_PER_THREAD; i++) + vals[i] = base_field::mul(vals[i], inv_sizes[log_n]); +// if (log_extension_degree) { +// if (coset_idx) { +// const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; +// const unsigned offset = coset_idx << shift; +// #pragma unroll +// for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { +// const unsigned idx0 = gmem_out_offset + 4 * i * tile_stride * WARPS_PER_BLOCK; +// const unsigned idx1 = gmem_out_offset + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK; +// auto power_of_w0 = get_power_of_w(idx0 * offset, true); +// auto power_of_w1 = get_power_of_w(idx1 * offset, true); +// vals[2 * i] = base_field::mul(vals[2 * i], power_of_w0); +// vals[2 * i + 1] = base_field::mul(vals[2 * i + 1], power_of_w1); +// } +// } +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx0 = gmem_out_offset + 4 * i * tile_stride * WARPS_PER_BLOCK; +// const unsigned idx1 = gmem_out_offset + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK; +// auto power_of_g0 = get_power_of_g(idx0, true); +// auto power_of_g1 = get_power_of_g(idx1, true); +// vals[2 * i] = base_field::mul(vals[2 * i], power_of_g0); +// vals[2 * i + 1] = base_field::mul(vals[2 * i + 1], power_of_g1); +// } +// } + } - __syncthreads(); - -// unroll 2 and unroll 4 give comparable perf. Stores don't incur a stall so it's not as important to ILP them. 
-#pragma unroll 1 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned s = lane_in_tile * PADDED_WARP_SCRATCH_SIZE + PAD(tile); - auto val = smem[s]; - const unsigned g = tile * tile_stride + lane_in_tile; - if (inverse && (start_stage + stages_this_launch == log_n)) { - val = base_field::mul(val, inv_sizes[log_n]); - if (log_extension_degree) { - const unsigned idx = g + block_exchg_region_start + block_start_in_exchg_region; - if (coset_idx) { - const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; - const unsigned offset = coset_idx << shift; - auto power_of_w = get_power_of_w(idx * offset, true); - val = base_field::mul(val, power_of_w); - } - auto power_of_g = get_power_of_g(idx, true); - val = base_field::mul(val, power_of_g); - } +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + memory::store_cs(gmem_out + 4 * i * tile_stride * WARPS_PER_BLOCK, vals[2 * i]); + memory::store_cs(gmem_out + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK, vals[2 * i + 1]); } - memory::store_cs(gmem_output + g, val); } } +extern "C" __launch_bounds__(512, 2) __global__ +void b2n_noninitial_7_or_8_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { + b2n_noninitial_stages_block<3>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, + stages_this_launch == 7, log_n, inverse, num_ntts, log_extension_degree, coset_idx); +} + // Simple, non-optimized kernel used for log_n < 16, to unblock debugging small proofs. 
extern "C" __launch_bounds__(512, 2) __global__ void b2n_1_stage(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, diff --git a/boojum-cuda/native/ntt_n2b.cuh b/boojum-cuda/native/ntt_n2b.cuh index 3202b6d..9ffbcf0 100644 --- a/boojum-cuda/native/ntt_n2b.cuh +++ b/boojum-cuda/native/ntt_n2b.cuh @@ -1,462 +1,523 @@ #pragma once // also, this file should only be compiled in one compile unit because it has __global__ definitions -#define THREE_REGISTER_STAGES_N2B(SKIP_THIRD) \ - { \ - /* first stage of this set-of-3 stages */ \ - const auto t0 = get_twiddle(inverse, thread_exchg_region); \ - exchg_dit(reg_vals[0], reg_vals[4], t0); \ - exchg_dit(reg_vals[1], reg_vals[5], t0); \ - exchg_dit(reg_vals[2], reg_vals[6], t0); \ - exchg_dit(reg_vals[3], reg_vals[7], t0); \ - /* second stage of this set-of-3 stages */ \ - thread_exchg_region *= 2; \ - const auto t1 = get_twiddle(inverse, thread_exchg_region); \ - const auto t2 = get_twiddle(inverse, thread_exchg_region + 1); \ - exchg_dit(reg_vals[0], reg_vals[2], t1); \ - exchg_dit(reg_vals[1], reg_vals[3], t1); \ - exchg_dit(reg_vals[4], reg_vals[6], t2); \ - exchg_dit(reg_vals[5], reg_vals[7], t2); \ - if (!(SKIP_THIRD)) { \ - /* third stage of this set-of-3 stages */ \ - thread_exchg_region *= 2; \ - const auto t3 = get_twiddle(inverse, thread_exchg_region); \ - const auto t4 = get_twiddle(inverse, thread_exchg_region + 1); \ - const auto t5 = get_twiddle(inverse, thread_exchg_region + 2); \ - const auto t6 = get_twiddle(inverse, thread_exchg_region + 3); \ - exchg_dit(reg_vals[0], reg_vals[1], t3); \ - exchg_dit(reg_vals[2], reg_vals[3], t4); \ - exchg_dit(reg_vals[4], reg_vals[5], t5); \ - exchg_dit(reg_vals[6], reg_vals[7], t6); \ - } \ - } +// This kernel basically reverses the pattern of the b2n_initial_stages_warp kernel. 
+template DEVICE_FORCEINLINE +void n2b_final_stages_warp(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { + constexpr unsigned VALS_PER_THREAD = 1 << LOG_VALS_PER_THREAD; + constexpr unsigned PAIRS_PER_THREAD = VALS_PER_THREAD >> 1; + constexpr unsigned VALS_PER_WARP = 32 * VALS_PER_THREAD; + constexpr unsigned LOG_VALS_PER_BLOCK = 5 + LOG_VALS_PER_THREAD + 2; + constexpr unsigned VALS_PER_BLOCK = 1 << LOG_VALS_PER_BLOCK; + + __shared__ base_field smem[VALS_PER_BLOCK]; -#define TWO_REGISTER_STAGES_N2B(SKIP_FIRST) \ - { \ - unsigned tmp_thread_exchg_region = thread_exchg_region; \ - if (!(SKIP_FIRST)) { \ - /* first stage of this set-of-2 stages */ \ - const auto t0 = get_twiddle(inverse, tmp_thread_exchg_region); \ - exchg_dit(reg_vals[0], reg_vals[2], t0); \ - exchg_dit(reg_vals[1], reg_vals[3], t0); \ - } \ - /* second stage of this set-of-2 stages */ \ - tmp_thread_exchg_region *= 2; \ - const auto t1 = get_twiddle(inverse, tmp_thread_exchg_region); \ - const auto t2 = get_twiddle(inverse, tmp_thread_exchg_region + 1); \ - exchg_dit(reg_vals[0], reg_vals[1], t1); \ - exchg_dit(reg_vals[2], reg_vals[3], t2); \ - } + const unsigned lane_id{threadIdx.x & 31}; + const unsigned warp_id{threadIdx.x >> 5}; + const unsigned gmem_offset = VALS_PER_BLOCK * blockIdx.x + VALS_PER_WARP * warp_id; + const base_field *gmem_in = gmem_inputs_matrix + gmem_offset + NTTS_PER_BLOCK * stride_between_input_arrays * blockIdx.y; + base_field *gmem_out = gmem_outputs_matrix + gmem_offset + NTTS_PER_BLOCK * stride_between_output_arrays * blockIdx.y; -#define ONE_EXTRA_STAGE_N2B \ - { \ - const unsigned intrablock_exchg_region = (warp_id >> 1); \ - const unsigned smem_logical_offset = (warp_id & 1) * 32 + lane_id; \ - const unsigned offset = intrablock_exchg_region * 512 + smem_logical_offset; \ - for (int i = 0; i < 4; i++) { \ - reg_vals[i] = memory::load_cs(gmem_input + offset + i * 64); \ - reg_vals[i + 4] = memory::load_cs(gmem_input + offset + i * 64 + 256); \ - } \ - const auto t0 = get_twiddle(inverse, 8 * block_idx_in_ntt + intrablock_exchg_region); \ - exchg_dit(reg_vals[0], reg_vals[4], t0); \ - exchg_dit(reg_vals[1], reg_vals[5], t0); \ - exchg_dit(reg_vals[2], reg_vals[6], t0); \ - exchg_dit(reg_vals[3], reg_vals[7], t0); \ - const unsigned offset_padded = intrablock_exchg_region * 2 * PADDED_WARP_SCRATCH_SIZE; \ - for (int i = 0; i < 4; i++) { \ - const unsigned idx = offset_padded + PAD(smem_logical_offset + i * 64); \ - smem[idx] = reg_vals[i]; \ - smem[idx + PADDED_WARP_SCRATCH_SIZE] = reg_vals[i + 4]; \ - } \ - } + auto twiddle_cache = smem + VALS_PER_WARP * warp_id; -#define TWO_EXTRA_STAGES_N2B \ - { \ - const unsigned intrablock_exchg_region = (warp_id >> 2); \ - const unsigned smem_logical_offset = (warp_id & 3) * 32 + lane_id; \ - const unsigned offset = intrablock_exchg_region * 1024 + smem_logical_offset; \ - for (int i = 0; i < 4; i++) { \ - reg_vals[i] = memory::load_cs(gmem_input + offset + i * 128); \ - reg_vals[i + 4] = memory::load_cs(gmem_input + offset + i * 128 + 512); \ - } \ - unsigned global_exchg_region = 4 * block_idx_in_ntt + intrablock_exchg_region; \ - const auto t0 = get_twiddle(inverse, global_exchg_region); \ - global_exchg_region *= 2; \ - const auto t1 = 
get_twiddle(inverse, global_exchg_region); \ - const auto t2 = get_twiddle(inverse, global_exchg_region + 1); \ - exchg_dit(reg_vals[0], reg_vals[4], t0); \ - exchg_dit(reg_vals[1], reg_vals[5], t0); \ - exchg_dit(reg_vals[2], reg_vals[6], t0); \ - exchg_dit(reg_vals[3], reg_vals[7], t0); \ - exchg_dit(reg_vals[0], reg_vals[2], t1); \ - exchg_dit(reg_vals[1], reg_vals[3], t1); \ - exchg_dit(reg_vals[4], reg_vals[6], t2); \ - exchg_dit(reg_vals[5], reg_vals[7], t2); \ - const unsigned offset_padded = intrablock_exchg_region * 4 * PADDED_WARP_SCRATCH_SIZE; \ - for (int i = 0; i < 2; i++) { \ - for (int j = 0; j < 2; j++) { \ - const unsigned idx = offset_padded + PAD(smem_logical_offset + j * 128) + i * PADDED_WARP_SCRATCH_SIZE; \ - smem[idx] = reg_vals[2 * i + j]; \ - smem[idx + 2 * PADDED_WARP_SCRATCH_SIZE] = reg_vals[2 * i + j + 4]; \ - } \ - } \ - } + base_field vals[VALS_PER_THREAD]; -#define THREE_OR_FOUR_EXTRA_STAGES_N2B(THREE_STAGES) \ - { \ - const unsigned intrablock_exchg_region = (warp_id >> 3); \ - const unsigned smem_logical_offset = (warp_id & 7) * 32 + lane_id; \ - const unsigned offset_padded = intrablock_exchg_region * 8 * PADDED_WARP_SCRATCH_SIZE + PAD(smem_logical_offset); \ - if ((THREE_STAGES)) { \ - const unsigned offset = intrablock_exchg_region * 2048 + smem_logical_offset; \ - for (int i = 0; i < 4; i++) { \ - reg_vals[i] = memory::load_cs(gmem_input + offset + i * 256); \ - reg_vals[i + 4] = memory::load_cs(gmem_input + offset + i * 256 + 1024); \ - } \ - } else { \ - const auto t0 = get_twiddle(inverse, block_idx_in_ntt); \ - int i = threadIdx.x; \ - int i_padded = (threadIdx.x >> 8) * PADDED_WARP_SCRATCH_SIZE + PAD(threadIdx.x & 255); \ - for (; i < 2048; i += 512, i_padded += 2 * PADDED_WARP_SCRATCH_SIZE) { \ - reg_vals[0] = memory::load_cs(gmem_output + i); \ - reg_vals[1] = memory::load_cs(gmem_output + i + 2048); \ - exchg_dit(reg_vals[0], reg_vals[1], t0); \ - smem[i_padded] = reg_vals[0]; \ - smem[i_padded + 8 * PADDED_WARP_SCRATCH_SIZE] = reg_vals[1]; \ - } \ - /* in theory it's possible to avoid full __syncthreads() here, see THREE_OR_FOUR_EXTRA_STAGES_B2N */ \ - __syncthreads(); \ - for (int i = 0; i < 4; i++) { \ - const unsigned idx = offset_padded + i * PADDED_WARP_SCRATCH_SIZE; \ - reg_vals[i] = smem[idx]; \ - reg_vals[i + 4] = smem[idx + 4 * PADDED_WARP_SCRATCH_SIZE]; \ - } \ - } \ - unsigned global_exchg_region = 2 * block_idx_in_ntt + intrablock_exchg_region; \ - const auto t0 = get_twiddle(inverse, global_exchg_region); \ - global_exchg_region *= 2; \ - const auto t1 = get_twiddle(inverse, global_exchg_region); \ - const auto t2 = get_twiddle(inverse, global_exchg_region + 1); \ - global_exchg_region *= 2; \ - const auto t3 = get_twiddle(inverse, global_exchg_region); \ - const auto t4 = get_twiddle(inverse, global_exchg_region + 1); \ - const auto t5 = get_twiddle(inverse, global_exchg_region + 2); \ - const auto t6 = get_twiddle(inverse, global_exchg_region + 3); \ - exchg_dit(reg_vals[0], reg_vals[4], t0); \ - exchg_dit(reg_vals[1], reg_vals[5], t0); \ - exchg_dit(reg_vals[2], reg_vals[6], t0); \ - exchg_dit(reg_vals[3], reg_vals[7], t0); \ - exchg_dit(reg_vals[0], reg_vals[2], t1); \ - exchg_dit(reg_vals[1], reg_vals[3], t1); \ - exchg_dit(reg_vals[4], reg_vals[6], t2); \ - exchg_dit(reg_vals[5], reg_vals[7], t2); \ - exchg_dit(reg_vals[0], reg_vals[1], t3); \ - exchg_dit(reg_vals[2], reg_vals[3], t4); \ - exchg_dit(reg_vals[4], reg_vals[5], t5); \ - exchg_dit(reg_vals[6], reg_vals[7], t6); \ - for (int i = 0; i < 4; i++) { \ - const unsigned 
idx = offset_padded + i * PADDED_WARP_SCRATCH_SIZE; \ - smem[idx] = reg_vals[i]; \ - smem[idx + 4 * PADDED_WARP_SCRATCH_SIZE] = reg_vals[i + 4]; \ - } \ - } + load_initial_twiddles_warp(twiddle_cache, lane_id, gmem_offset, inverse); -extern "C" __launch_bounds__(512, 2) __global__ - void n2b_final_7_or_8_stages(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, - const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, - const unsigned coset_idx) { - extern __shared__ base_field smem[]; // 4096 elems - - const unsigned tile_stride{16}; - const unsigned lane_in_tile = threadIdx.x & 15; - const unsigned lane_id{threadIdx.x & 31}; - const unsigned warp_id{threadIdx.x >> 5}; - const unsigned ntt_idx = blockIdx.x / blocks_per_ntt; - const unsigned block_idx_in_ntt = blockIdx.x - ntt_idx * blocks_per_ntt; - const base_field *gmem_input = gmem_inputs_matrix + ntt_idx * stride_between_input_arrays + 4096 * block_idx_in_ntt; - base_field *gmem_output = gmem_outputs_matrix + ntt_idx * stride_between_output_arrays + 4096 * block_idx_in_ntt; - - { - // maybe some memcpy_asyncs could micro-optimize this - // maybe an arrive-wait barrier could micro-optimize the start_stage > 0 case - base_field reg_vals[8]; -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned g = tile * tile_stride + lane_in_tile; - reg_vals[i] = memory::load_cs(gmem_input + g); - } -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - // puts each warp's data in its assigned smem region - const unsigned s = (t >> 8) * PADDED_WARP_SCRATCH_SIZE + PAD(t & 255); - memory::store(smem + s, reg_vals[i]); + const unsigned bound = std::min(NTTS_PER_BLOCK, num_ntts - NTTS_PER_BLOCK * blockIdx.y); + for (unsigned ntt_idx = 0; ntt_idx < bound; + ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + vals[2 * i] = memory::load_cs(gmem_in + 64 * i + lane_id); + vals[2 * i + 1] = memory::load_cs(gmem_in + 64 * i + lane_id + 32); } - } - __syncwarp(); + base_field *twiddles_this_stage = twiddle_cache + VALS_PER_WARP - 2; + unsigned num_twiddles_this_stage = 1; + for (unsigned i = 0; i < LOG_VALS_PER_THREAD - 1; i++) { +#pragma unroll + for (unsigned j = 0; j < (1 << i); j++) { + const unsigned exchg_tile_sz = VALS_PER_THREAD >> i; + const unsigned half_exchg_tile_sz = exchg_tile_sz >> 1; + const auto twiddle = twiddles_this_stage[j]; +#pragma unroll + for (unsigned k = 0; k < half_exchg_tile_sz; k++) { + exchg_dit(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + } + } + num_twiddles_this_stage <<= 1; + twiddles_this_stage -= num_twiddles_this_stage; + } - unsigned warp_exchg_region = block_idx_in_ntt * 16 + warp_id; - base_field reg_vals[8]; - base_field *warp_scratch = smem + PADDED_WARP_SCRATCH_SIZE * warp_id; + unsigned lane_mask = 16; + for (unsigned stage = 0, s = 5; stage < 6; stage++, s--) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const auto twiddle = twiddles_this_stage[(32 * i + lane_id) >> s]; + exchg_dit(vals[2 * i], vals[2 * i + 1], twiddle); + if (stage < 5) + shfl_xor_bf(vals, i, lane_id, lane_mask); + } + 
lane_mask >>= 1; + num_twiddles_this_stage <<= 1; + twiddles_this_stage -= num_twiddles_this_stage; + } - unsigned thread_exchg_region = warp_exchg_region; - for (int j = 0; j < 64; j += 32) { - for (int i = 0; i < 4; i++) - reg_vals[i] = warp_scratch[PAD(lane_id + 64 * i + j)]; - TWO_REGISTER_STAGES_N2B((stages_this_launch == 7)) - for (int i = 0; i < 4; i++) - warp_scratch[PAD(lane_id + 64 * i + j)] = reg_vals[i]; - } - warp_exchg_region *= 4; - - __syncwarp(); - - thread_exchg_region = warp_exchg_region + (lane_id >> 3); - const unsigned vals_start = 64 * (lane_id >> 3) + (lane_id & 7); - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(vals_start + 8 * i)]; - THREE_REGISTER_STAGES_N2B(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(vals_start + 8 * i)] = reg_vals[i]; - warp_exchg_region *= 8; - - __syncwarp(); - - thread_exchg_region = warp_exchg_region + lane_id; - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(lane_id * 8 + i)]; - THREE_REGISTER_STAGES_N2B(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(lane_id * 8 + i)] = reg_vals[i]; - - __syncwarp(); - -// unroll 2 and unroll 4 give comparable perf. Stores don't incur a stall so it's not as important to ILP them. -#pragma unroll 1 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned s = (t >> 8) * PADDED_WARP_SCRATCH_SIZE + PAD(t & 255); - auto val = smem[s]; - const unsigned g = tile * tile_stride + lane_in_tile; if (inverse) { - val = base_field::mul(val, inv_sizes[log_n]); - if (log_extension_degree) { - const unsigned idx = __brev(g + 4096 * block_idx_in_ntt) >> (32 - log_n); - if (coset_idx) { - const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; - const unsigned offset = coset_idx << shift; - auto power_of_w = get_power_of_w(idx * offset, true); - val = base_field::mul(val, power_of_w); - } - auto power_of_g = get_power_of_g(idx, true); - val = base_field::mul(val, power_of_g); - } +#pragma unroll + for (unsigned i = 0; i < VALS_PER_THREAD; i++) + vals[i] = base_field::mul(vals[i], inv_sizes[log_n]); +// if (log_extension_degree) { +// if (coset_idx) { +// const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; +// const unsigned offset = coset_idx << shift; +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx = __brev(gmem_offset + 64 * (i >> 1) + 2 * lane_id + (i & 1)) >> (32 - log_n); +// auto power_of_w = get_power_of_w(idx * offset, true); +// vals[i] = base_field::mul(vals[i], power_of_w); +// } +// } +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx = __brev(gmem_offset + 64 * (i >> 1) + 2 * lane_id + (i & 1)) >> (32 - log_n); +// auto power_of_g = get_power_of_g(idx, true); +// vals[i] = base_field::mul(vals[i], power_of_g); +// } +// } + } + +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const uint4 out{vals[2 * i][0], vals[2 * i][1], vals[2 * i + 1][0], vals[2 * i + 1][1]}; + memory::store_cs(reinterpret_cast(gmem_out + 64 * i + 2 * lane_id), out); } - memory::store_cs(gmem_output + g, val); } } -extern "C" __launch_bounds__(512, 2) __global__ - void n2b_final_9_to_12_stages(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, - const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned 
blocks_per_ntt, const unsigned log_extension_degree, - const unsigned coset_idx) { - extern __shared__ base_field smem[]; // 4096 elems - - const unsigned tile_stride{16}; - const unsigned lane_in_tile = threadIdx.x & 15; +// extern "C" __launch_bounds__(128, 8) __global__ +extern "C" __global__ +void n2b_final_8_stages_warp(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { + n2b_final_stages_warp<3>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, + stages_this_launch, log_n, inverse, num_ntts, log_extension_degree, coset_idx); +} + +extern "C" __global__ +void n2b_final_7_stages_warp(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { + n2b_final_stages_warp<2>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, + stages_this_launch, log_n, inverse, num_ntts, log_extension_degree, coset_idx); +} + +// This kernel basically reverses the pattern of the b2n_initial_stages_block kernel. +template DEVICE_FORCEINLINE +void n2b_final_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { + constexpr unsigned VALS_PER_THREAD = 1 << LOG_VALS_PER_THREAD; + constexpr unsigned PAIRS_PER_THREAD = VALS_PER_THREAD >> 1; + constexpr unsigned VALS_PER_WARP = 32 * VALS_PER_THREAD; + constexpr unsigned WARPS_PER_BLOCK = VALS_PER_WARP >> 4; + constexpr unsigned VALS_PER_BLOCK = 32 * VALS_PER_THREAD * WARPS_PER_BLOCK; + constexpr unsigned MAX_STAGES_THIS_LAUNCH = 2 * (LOG_VALS_PER_THREAD + 5) - 4; + + __shared__ base_field smem[VALS_PER_BLOCK]; + const unsigned lane_id{threadIdx.x & 31}; const unsigned warp_id{threadIdx.x >> 5}; - const unsigned ntt_idx = blockIdx.x / blocks_per_ntt; - const unsigned block_idx_in_ntt = blockIdx.x - ntt_idx * blocks_per_ntt; - const base_field *gmem_input = gmem_inputs_matrix + ntt_idx * stride_between_input_arrays + 4096 * block_idx_in_ntt; - base_field *gmem_output = gmem_outputs_matrix + ntt_idx * stride_between_output_arrays + 4096 * block_idx_in_ntt; - - base_field reg_vals[8]; - base_field *warp_scratch = smem + PADDED_WARP_SCRATCH_SIZE * warp_id; - - // This kernel can handle up to 11 stages if needed by the overall NTT. 
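// Hedged sketch (illustrative, not the project's memory:: helpers): the epilogues of the warp-level
// kernels above pack a thread's butterfly pair, two base_field values of two 32-bit limbs each,
// into one uint4 so the write-back is a single 16-byte transaction. The demo kernel below performs
// the same packing with a stand-in element type and a plain vectorized store; the real code routes
// the write through memory::store_cs to add a streaming cache hint.
#include <cassert>
#include <cuda_runtime.h>

struct elem { unsigned lo, hi; }; // stand-in for base_field's two 32-bit limbs

__global__ void pair_store_demo(elem *out) {
  const unsigned lane = threadIdx.x & 31;
  const elem a{2 * lane, 2 * lane + 1000};
  const elem b{2 * lane + 1, 2 * lane + 1001};
  const uint4 packed{a.lo, a.hi, b.lo, b.hi};
  // out + 2 * lane is 16-byte aligned: sizeof(elem) == 8 and cudaMalloc memory is over-aligned
  *reinterpret_cast<uint4 *>(out + 2 * lane) = packed;
}

int main() {
  elem *d = nullptr;
  cudaMalloc(&d, 64 * sizeof(elem));
  pair_store_demo<<<1, 32>>>(d);
  elem h[64];
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  for (unsigned i = 0; i < 64; i++)
    assert(h[i].lo == i && h[i].hi == i + 1000); // both halves of every pair landed in place
  cudaFree(d);
  return 0;
}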
- const unsigned extra_stages = stages_this_launch - 8; - switch (extra_stages) { - case 1: - ONE_EXTRA_STAGE_N2B - break; - case 2: - TWO_EXTRA_STAGES_N2B - break; - case 3: - THREE_OR_FOUR_EXTRA_STAGES_N2B(true) - break; - case 4: - THREE_OR_FOUR_EXTRA_STAGES_N2B(false) - break; - } + const unsigned gmem_block_offset = VALS_PER_BLOCK * blockIdx.x; + const unsigned gmem_offset = gmem_block_offset + VALS_PER_WARP * warp_id; + // annoyingly scrambled, but should be coalesced overall + const unsigned gmem_in_thread_offset = 16 * warp_id + VALS_PER_WARP * (lane_id >> 4) + 2 * (lane_id & 7) + ((lane_id >> 3) & 1); + const base_field *gmem_in = gmem_inputs_matrix + gmem_block_offset + gmem_in_thread_offset + + NTTS_PER_BLOCK * stride_between_input_arrays * blockIdx.y; + base_field *gmem_out = gmem_outputs_matrix + gmem_offset + NTTS_PER_BLOCK * stride_between_output_arrays * blockIdx.y; + + auto twiddle_cache = smem + VALS_PER_WARP * warp_id; + + base_field vals[VALS_PER_THREAD]; + + const unsigned bound = std::min(NTTS_PER_BLOCK, num_ntts - NTTS_PER_BLOCK * blockIdx.y); + for (unsigned ntt_idx = 0; ntt_idx < bound; + ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + vals[2 * i] = memory::load_cs(gmem_in + 4 * i * VALS_PER_WARP); + vals[2 * i + 1] = memory::load_cs(gmem_in + (4 * i + 2) * VALS_PER_WARP); + } - __syncthreads(); + const unsigned stages_to_skip = MAX_STAGES_THIS_LAUNCH - stages_this_launch; + unsigned exchg_region_offset = blockIdx.x; + for (unsigned i = 0; i < LOG_VALS_PER_THREAD - 1; i++) { + if (i >= stages_to_skip) { +#pragma unroll + for (unsigned j = 0; j < (1 << i); j++) { + const unsigned exchg_tile_sz = VALS_PER_THREAD >> i; + const unsigned half_exchg_tile_sz = exchg_tile_sz >> 1; + const auto twiddle = get_twiddle(inverse, exchg_region_offset + j); +#pragma unroll + for (unsigned k = 0; k < half_exchg_tile_sz; k++) + exchg_dit(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + } + } + exchg_region_offset <<= 1; + } - unsigned warp_exchg_region = block_idx_in_ntt * 16 + warp_id; - unsigned thread_exchg_region = warp_exchg_region; - for (int j = 0; j < 64; j += 32) { - for (int i = 0; i < 4; i++) - reg_vals[i] = warp_scratch[PAD(lane_id + 64 * i + j)]; - TWO_REGISTER_STAGES_N2B(false) - for (int i = 0; i < 4; i++) - warp_scratch[PAD(lane_id + 64 * i + j)] = reg_vals[i]; - } - warp_exchg_region *= 4; - - __syncwarp(); - - thread_exchg_region = warp_exchg_region + (lane_id >> 3); - const unsigned vals_start = 64 * (lane_id >> 3) + (lane_id & 7); - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(vals_start + 8 * i)]; - THREE_REGISTER_STAGES_N2B(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(vals_start + 8 * i)] = reg_vals[i]; - warp_exchg_region *= 8; - - __syncwarp(); - - thread_exchg_region = warp_exchg_region + lane_id; - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(lane_id * 8 + i)]; - THREE_REGISTER_STAGES_N2B(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(lane_id * 8 + i)] = reg_vals[i]; - - __syncwarp(); - -// unroll 2 and unroll 4 give comparable perf. Stores don't incur a stall so it's not as important to ILP them. 
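// Hedged sketch (host-side check, not project code): the "annoyingly scrambled, but should be
// coalesced overall" thread offset above interleaves 16-element tiles so neighbouring lanes still
// touch neighbouring addresses. This little program recomputes
//   16 * warp_id + VALS_PER_WARP * (lane >> 4) + 2 * (lane & 7) + ((lane >> 3) & 1)
// for every lane of a warp and checks that the 32 offsets form exactly two contiguous runs of 16
// elements, i.e. two 128-byte segments for 8-byte field elements.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  constexpr unsigned LOG_VALS_PER_THREAD = 3; // matches the <3> instantiation used above
  constexpr unsigned VALS_PER_WARP = 32u << LOG_VALS_PER_THREAD;
  constexpr unsigned WARPS_PER_BLOCK = VALS_PER_WARP >> 4;
  for (unsigned warp_id = 0; warp_id < WARPS_PER_BLOCK; warp_id++) {
    std::vector<unsigned> offs;
    for (unsigned lane = 0; lane < 32; lane++)
      offs.push_back(16 * warp_id + VALS_PER_WARP * (lane >> 4) + 2 * (lane & 7) + ((lane >> 3) & 1));
    std::sort(offs.begin(), offs.end());
    for (unsigned i = 0; i < 16; i++) {
      assert(offs[i] == 16 * warp_id + i);                      // first 128-byte segment
      assert(offs[16 + i] == 16 * warp_id + VALS_PER_WARP + i); // second 128-byte segment
    }
  }
  return 0;
}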
-#pragma unroll 1 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned s = (t >> 8) * PADDED_WARP_SCRATCH_SIZE + PAD(t & 255); - auto val = smem[s]; - const unsigned g = tile * tile_stride + lane_in_tile; - if (inverse) { - val = base_field::mul(val, inv_sizes[log_n]); - if (log_extension_degree) { - const unsigned idx = __brev(g + 4096 * block_idx_in_ntt) >> (32 - log_n); - if (coset_idx) { - const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; - const unsigned offset = coset_idx << shift; - auto power_of_w = get_power_of_w(idx * offset, true); - val = base_field::mul(val, power_of_w); + unsigned lane_mask = 16; + unsigned halfwarp_id = lane_id >> 4; + for (unsigned s = 0; s < 2; s++) { + if ((s + LOG_VALS_PER_THREAD - 1) >= stages_to_skip) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + // TODO: Handle these cooperatively? + const auto twiddle = get_twiddle(inverse, exchg_region_offset + ((2 * i + halfwarp_id) >> (1 - s))); + exchg_dit(vals[2 * i], vals[2 * i + 1], twiddle); + shfl_xor_bf(vals, i, lane_id, lane_mask); + } + } else { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) + shfl_xor_bf(vals, i, lane_id, lane_mask); + } + lane_mask >>= 1; + exchg_region_offset <<= 1; + } + + __syncwarp(); // maybe unnecessary but can't hurt + + { + base_field tmp[VALS_PER_THREAD]; + auto pair_addr = smem + 16 * warp_id + VALS_PER_WARP * (lane_id >> 3) + 2 * (threadIdx.x & 7); + if (ntt_idx > 0) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + tmp[2 * i] = twiddle_cache[64 * i + lane_id]; + tmp[2 * i + 1] = twiddle_cache[64 * i + lane_id + 32]; + } + + __syncthreads(); + +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++, pair_addr += 4 * VALS_PER_WARP) { + uint4* pair = reinterpret_cast(pair_addr); + const uint4 out{vals[2 * i][0], vals[2 * i][1], vals[2 * i + 1][0], vals[2 * i + 1][1]}; + *pair = out; + } + } else { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++, pair_addr += 4 * VALS_PER_WARP) { + uint4* pair = reinterpret_cast(pair_addr); + const uint4 out{vals[2 * i][0], vals[2 * i][1], vals[2 * i + 1][0], vals[2 * i + 1][1]}; + *pair = out; + } + } + + __syncthreads(); + + if (ntt_idx > 0) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + vals[2 * i] = twiddle_cache[64 * i + lane_id]; + vals[2 * i + 1] = twiddle_cache[64 * i + lane_id + 32]; + twiddle_cache[64 * i + lane_id] = tmp[2 * i]; + twiddle_cache[64 * i + lane_id + 32] = tmp[2 * i + 1]; } - auto power_of_g = get_power_of_g(idx, true); - val = base_field::mul(val, power_of_g); + + __syncwarp(); + } else { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + vals[2 * i] = twiddle_cache[64 * i + lane_id]; + vals[2 * i + 1] = twiddle_cache[64 * i + lane_id + 32]; + } + + __syncwarp(); + + load_initial_twiddles_warp(twiddle_cache, lane_id, gmem_offset, inverse); } } - memory::store_cs(gmem_output + g, val); + + base_field *twiddles_this_stage = twiddle_cache + VALS_PER_WARP - 2; + unsigned num_twiddles_this_stage = 1; + for (unsigned i = 0; i < LOG_VALS_PER_THREAD - 1; i++) { +#pragma unroll + for (unsigned j = 0; j < (1 << i); j++) { + const unsigned exchg_tile_sz = VALS_PER_THREAD >> i; + const unsigned half_exchg_tile_sz = exchg_tile_sz >> 1; + const auto twiddle = twiddles_this_stage[j]; +#pragma unroll + for (unsigned k = 0; k < half_exchg_tile_sz; k++) { + exchg_dit(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz 
* j + k + half_exchg_tile_sz], twiddle); + } + } + num_twiddles_this_stage <<= 1; + twiddles_this_stage -= num_twiddles_this_stage; + } + + lane_mask = 16; + for (unsigned stage = 0, s = 5; stage < 6; stage++, s--) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const auto twiddle = twiddles_this_stage[(32 * i + lane_id) >> s]; + exchg_dit(vals[2 * i], vals[2 * i + 1], twiddle); + if (stage < 5) + shfl_xor_bf(vals, i, lane_id, lane_mask); + } + lane_mask >>= 1; + num_twiddles_this_stage <<= 1; + twiddles_this_stage -= num_twiddles_this_stage; + } + + if (inverse) { +#pragma unroll + for (unsigned i = 0; i < VALS_PER_THREAD; i++) + vals[i] = base_field::mul(vals[i], inv_sizes[log_n]); +// if (log_extension_degree) { +// if (coset_idx) { +// const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; +// const unsigned offset = coset_idx << shift; +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx = __brev(gmem_offset + 64 * (i >> 1) + 2 * lane_id + (i & 1)) >> (32 - log_n); +// auto power_of_w = get_power_of_w(idx * offset, true); +// vals[i] = base_field::mul(vals[i], power_of_w); +// } +// } +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx = __brev(gmem_offset + 64 * (i >> 1) + 2 * lane_id + (i & 1)) >> (32 - log_n); +// auto power_of_g = get_power_of_g(idx, true); +// vals[i] = base_field::mul(vals[i], power_of_g); +// } +// } + } + +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const uint4 out{vals[2 * i][0], vals[2 * i][1], vals[2 * i + 1][0], vals[2 * i + 1][1]}; + memory::store_cs(reinterpret_cast(gmem_out + 64 * i + 2 * lane_id), out); + } } } extern "C" __launch_bounds__(512, 2) __global__ - void n2b_nonfinal_7_or_8_stages(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, +void n2b_final_9_to_12_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, - const unsigned log_n, const bool inverse, const unsigned blocks_per_ntt, const unsigned log_extension_degree, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, const unsigned coset_idx) { - extern __shared__ base_field smem[]; // 4096 elems + n2b_final_stages_block<3>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, + stages_this_launch, log_n, inverse, num_ntts, log_extension_degree, coset_idx); +} + +// This kernel basically reverses the pattern of the b2n_noninitial_stages_block kernel. 
+template DEVICE_FORCEINLINE +void n2b_nonfinal_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const bool skip_last_stage, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { + constexpr unsigned VALS_PER_THREAD = 1 << LOG_VALS_PER_THREAD; + constexpr unsigned PAIRS_PER_THREAD = VALS_PER_THREAD >> 1; + constexpr unsigned VALS_PER_WARP = 32 * VALS_PER_THREAD; + constexpr unsigned TILES_PER_WARP = VALS_PER_WARP >> 4; + constexpr unsigned WARPS_PER_BLOCK = VALS_PER_WARP >> 4; + constexpr unsigned VALS_PER_BLOCK = VALS_PER_WARP * WARPS_PER_BLOCK; + constexpr unsigned TILES_PER_BLOCK = VALS_PER_BLOCK >> 4; + constexpr unsigned EXCHG_REGIONS_PER_BLOCK = TILES_PER_BLOCK >> 1; + constexpr unsigned MAX_STAGES_THIS_LAUNCH = 2 * (LOG_VALS_PER_THREAD + 5) - 8; + + __shared__ base_field smem[VALS_PER_BLOCK]; - const unsigned log_stride{log_n - start_stage - 1}; - const unsigned tile_stride{1u << (log_stride - 7)}; - const unsigned lane_in_tile = threadIdx.x & 15; const unsigned lane_id{threadIdx.x & 31}; const unsigned warp_id{threadIdx.x >> 5}; - const unsigned exchg_region_sz{1u << (log_stride + 1)}; - const unsigned log_blocks_per_region = log_stride - 11; // tile_stride / 16 - const unsigned ntt_idx = blockIdx.x / blocks_per_ntt; - const unsigned block_idx_in_ntt = blockIdx.x - ntt_idx * blocks_per_ntt; - unsigned block_exchg_region = block_idx_in_ntt >> log_blocks_per_region; - const unsigned block_exchg_region_start = block_exchg_region * exchg_region_sz; - const unsigned block_start_in_exchg_region = 16 * (block_idx_in_ntt & ((1 << log_blocks_per_region) - 1)); - const base_field *gmem_input = gmem_inputs_matrix + ntt_idx * stride_between_input_arrays + block_exchg_region_start + block_start_in_exchg_region; - base_field *gmem_output = gmem_outputs_matrix + ntt_idx * stride_between_output_arrays + block_exchg_region_start + block_start_in_exchg_region; - - { - // maybe some memcpy_asyncs could further micro-optimize this - // maybe an arrive-wait barrier could further micro-optimize the start_stage > 0 case - base_field reg_vals[8]; -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned g = tile * tile_stride + lane_in_tile; - reg_vals[i] = memory::load_cs(gmem_input + g); + const unsigned log_tile_stride = log_n - start_stage - MAX_STAGES_THIS_LAUNCH; + const unsigned tile_stride = 1 << log_tile_stride; + const unsigned log_blocks_per_region = log_tile_stride - 4; // tile size is always 16 + const unsigned block_bfly_region_size = TILES_PER_BLOCK * tile_stride; + const unsigned block_bfly_region = blockIdx.x >> log_blocks_per_region; + const unsigned block_bfly_region_start = block_bfly_region * block_bfly_region_size; + const unsigned block_start_in_bfly_region = 16 * (blockIdx.x & ((1 << log_blocks_per_region) - 1)); + // annoyingly scrambled, but should be coalesced overall + const unsigned gmem_in_thread_offset = tile_stride * warp_id + tile_stride * WARPS_PER_BLOCK * (lane_id >> 4) + + 2 * (lane_id & 7) + ((lane_id >> 3) & 1); + const unsigned gmem_in_offset = block_bfly_region_start + block_start_in_bfly_region + gmem_in_thread_offset; + const base_field *gmem_in = gmem_inputs_matrix + gmem_in_offset + NTTS_PER_BLOCK * stride_between_input_arrays * blockIdx.y; + 
base_field *gmem_out = gmem_outputs_matrix + block_bfly_region_start + block_start_in_bfly_region + + NTTS_PER_BLOCK * stride_between_output_arrays * blockIdx.y; + + auto twiddle_cache = smem + VALS_PER_WARP * warp_id; + + base_field vals[VALS_PER_THREAD]; + + const unsigned bound = std::min(NTTS_PER_BLOCK, num_ntts - NTTS_PER_BLOCK * blockIdx.y); + for (unsigned ntt_idx = 0; ntt_idx < bound; + ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + vals[2 * i] = memory::load_cs(gmem_in + 4 * i * tile_stride * WARPS_PER_BLOCK); + vals[2 * i + 1] = memory::load_cs(gmem_in + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK); } - if ((start_stage == 0) && log_extension_degree && !inverse) { - const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; - const unsigned offset = coset_idx << shift; -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned idx = tile * tile_stride + lane_in_tile + block_exchg_region_start + block_start_in_exchg_region; - if (coset_idx) { - auto power_of_w = get_power_of_w(idx * offset, false); - reg_vals[i] = base_field::mul(reg_vals[i], power_of_w); - } - auto power_of_g = get_power_of_g(idx, false); - reg_vals[i] = base_field::mul(reg_vals[i], power_of_g); + +// if ((start_stage == 0) && log_extension_degree && !inverse) { +// if (coset_idx) { +// const unsigned shift = OMEGA_LOG_ORDER - log_n - log_extension_degree; +// const unsigned offset = coset_idx << shift; +// #pragma unroll +// for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { +// const unsigned idx0 = gmem_in_offset + 4 * i * tile_stride * WARPS_PER_BLOCK; +// const unsigned idx1 = gmem_in_offset + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK; +// auto power_of_w0 = get_power_of_w(idx0 * offset, false); +// auto power_of_w1 = get_power_of_w(idx1 * offset, false); +// vals[2 * i] = base_field::mul(vals[2 * i], power_of_w0); +// vals[2 * i + 1] = base_field::mul(vals[2 * i + 1], power_of_w1); +// } +// } +// #pragma unroll +// for (unsigned i = 0; i < VALS_PER_THREAD; i++) { +// const unsigned idx0 = gmem_in_offset + 4 * i * tile_stride * WARPS_PER_BLOCK; +// const unsigned idx1 = gmem_in_offset + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK; +// auto power_of_g0 = get_power_of_g(idx0, false); +// auto power_of_g1 = get_power_of_g(idx1, false); +// vals[2 * i] = base_field::mul(vals[2 * i], power_of_g0); +// vals[2 * i + 1] = base_field::mul(vals[2 * i + 1], power_of_g1); +// } +// } + + unsigned block_exchg_region_offset = block_bfly_region; + for (unsigned i = 0; i < LOG_VALS_PER_THREAD - 1; i++) { +#pragma unroll + for (unsigned j = 0; j < (1 << i); j++) { + const unsigned exchg_tile_sz = VALS_PER_THREAD >> i; + const unsigned half_exchg_tile_sz = exchg_tile_sz >> 1; + const auto twiddle = get_twiddle(inverse, block_exchg_region_offset + j); +#pragma unroll + for (unsigned k = 0; k < half_exchg_tile_sz; k++) + exchg_dit(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + } + block_exchg_region_offset <<= 1; + } + + unsigned lane_mask = 16; + unsigned halfwarp_id = lane_id >> 4; + for (unsigned s = 0; s < 2; s++) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + // TODO: Handle these cooperatively? 
+ const auto twiddle = get_twiddle(inverse, block_exchg_region_offset + ((2 * i + halfwarp_id) >> (1 - s))); + exchg_dit(vals[2 * i], vals[2 * i + 1], twiddle); + shfl_xor_bf(vals, i, lane_id, lane_mask); } + lane_mask >>= 1; + block_exchg_region_offset <<= 1; } -#pragma unroll 8 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - // puts each warp's data in its assigned smem region - const unsigned tile = t >> 4; - const unsigned s = lane_in_tile * PADDED_WARP_SCRATCH_SIZE + PAD(tile); - memory::store(smem + s, reg_vals[i]); + + __syncwarp(); // maybe unnecessary but can't hurt + + // there are at most 31 per-warp twiddles, so we only need 1 temporary per thread to stash them + base_field tmp; + if (ntt_idx > 0) { + tmp = twiddle_cache[lane_id]; + __syncthreads(); } - } - __syncthreads(); + auto smem_pair_addr = smem + 16 * warp_id + VALS_PER_WARP * (lane_id >> 3) + 2 * (threadIdx.x & 7); +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++, smem_pair_addr += 4 * VALS_PER_WARP) { + uint4* pair = reinterpret_cast(smem_pair_addr); + const uint4 out{vals[2 * i][0], vals[2 * i][1], vals[2 * i + 1][0], vals[2 * i + 1][1]}; + *pair = out; + } - base_field reg_vals[8]; - base_field *warp_scratch = smem + PADDED_WARP_SCRATCH_SIZE * warp_id; + __syncthreads(); - unsigned thread_exchg_region = block_exchg_region; - for (int j = 0; j < 64; j += 32) { - for (int i = 0; i < 4; i++) - reg_vals[i] = warp_scratch[PAD(lane_id + 64 * i + j)]; - TWO_REGISTER_STAGES_N2B(false) - for (int i = 0; i < 4; i++) - warp_scratch[PAD(lane_id + 64 * i + j)] = reg_vals[i]; - } - block_exchg_region *= 4; - - __syncwarp(); - - thread_exchg_region = block_exchg_region + (lane_id >> 3); - const unsigned vals_start = 64 * (lane_id >> 3) + (lane_id & 7); - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(vals_start + 8 * i)]; - THREE_REGISTER_STAGES_N2B(false) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(vals_start + 8 * i)] = reg_vals[i]; - block_exchg_region *= 8; - - __syncwarp(); - - thread_exchg_region = block_exchg_region + lane_id; - for (int i = 0; i < 8; i++) - reg_vals[i] = warp_scratch[PAD(lane_id * 8 + i)]; - THREE_REGISTER_STAGES_N2B((stages_this_launch == 7)) - for (int i = 0; i < 8; i++) - warp_scratch[PAD(lane_id * 8 + i)] = reg_vals[i]; - - __syncthreads(); - -// unroll 2 and unroll 4 give comparable perf. Stores don't incur a stall so it's not as important to ILP them. 
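// Hedged sketch (illustrative only): the loop above reuses the warp's shared-memory region both as
// twiddle cache and as the scratch buffer for the inter-thread value exchange. Because at most 31
// cached twiddles live there, each lane can park its cache word in one register (tmp), let the warp
// clobber the region, and then write the word back. The toy single-warp kernel below shows the
// stash/restore pattern with plain unsigned data instead of base_field.
#include <cassert>
#include <cuda_runtime.h>

__global__ void stash_restore_demo(unsigned *good_lanes) {
  __shared__ unsigned cache[32];
  const unsigned lane = threadIdx.x & 31;
  cache[lane] = 1000 + lane;        // pretend this is the per-warp twiddle cache
  __syncwarp();

  const unsigned tmp = cache[lane]; // stash one cached word per lane in a register
  __syncwarp();

  cache[lane ^ 1] = lane;           // reuse the same region as an exchange scratch buffer
  __syncwarp();
  const unsigned exchanged = cache[lane]; // each lane now sees its neighbour's value
  __syncwarp();

  cache[lane] = tmp;                // restore the cache for the next iteration
  __syncwarp();

  const bool good = (exchanged == (lane ^ 1)) && (cache[lane] == 1000 + lane);
  if (good) atomicAdd(good_lanes, 1);
}

int main() {
  unsigned *d = nullptr;
  cudaMalloc(&d, sizeof(unsigned));
  cudaMemset(d, 0, sizeof(unsigned));
  stash_restore_demo<<<1, 32>>>(d);
  unsigned good = 0;
  cudaMemcpy(&good, d, sizeof(unsigned), cudaMemcpyDeviceToHost);
  assert(good == 32); // every lane saw the exchanged value and an intact cache afterwards
  cudaFree(d);
  return 0;
}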
-#pragma unroll 1 - for (unsigned i = 0, t = warp_id * 256 + lane_id; i < 8; i++, t += 32) { - const unsigned tile = t >> 4; - const unsigned s = lane_in_tile * PADDED_WARP_SCRATCH_SIZE + PAD(tile); - const auto val = smem[s]; - const unsigned g = tile * tile_stride + lane_in_tile; - memory::store_cs(gmem_output + g, val); + // annoyingly scrambled but should be bank-conflict-free + const unsigned smem_thread_offset = 16 * (lane_id >> 4) + 2 * (lane_id & 7) + ((lane_id >> 3) & 1); +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + vals[2 * i] = twiddle_cache[64 * i + smem_thread_offset]; + vals[2 * i + 1] = twiddle_cache[64 * i + smem_thread_offset + 32]; + } + + __syncwarp(); + + if (ntt_idx > 0) { + twiddle_cache[lane_id] = tmp; + __syncwarp(); + } else { + load_noninitial_twiddles_warp(twiddle_cache, lane_id, warp_id, + block_bfly_region * EXCHG_REGIONS_PER_BLOCK, inverse); + } + + base_field *twiddles_this_stage = twiddle_cache + 2 * VALS_PER_THREAD - 2; + unsigned num_twiddles_this_stage = 1; + for (unsigned i = 0; i < LOG_VALS_PER_THREAD - 1; i++) { +#pragma unroll + for (unsigned j = 0; j < (1 << i); j++) { + const unsigned exchg_tile_sz = VALS_PER_THREAD >> i; + const unsigned half_exchg_tile_sz = exchg_tile_sz >> 1; + const auto twiddle = twiddles_this_stage[j]; +#pragma unroll + for (unsigned k = 0; k < half_exchg_tile_sz; k++) { + exchg_dit(vals[exchg_tile_sz * j + k], vals[exchg_tile_sz * j + k + half_exchg_tile_sz], twiddle); + } + } + num_twiddles_this_stage <<= 1; + twiddles_this_stage -= num_twiddles_this_stage; + } + + lane_mask = 16; + for (unsigned s = 0; s < 2; s++) { + if (!skip_last_stage || s < 1) { +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + // TODO: Handle these cooperatively? + const auto twiddle = twiddles_this_stage[(2 * i + halfwarp_id) >> (1 - s)]; + exchg_dit(vals[2 * i], vals[2 * i + 1], twiddle); + shfl_xor_bf(vals, i, lane_id, lane_mask); + } + lane_mask >>= 1; + num_twiddles_this_stage <<= 1; + twiddles_this_stage -= num_twiddles_this_stage; + } + } + + if (skip_last_stage) { + auto val0_addr = gmem_out + TILES_PER_WARP * tile_stride * warp_id + 2 * tile_stride * (lane_id >> 4) + 2 * (threadIdx.x & 7) + (lane_id >> 3 & 1); +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + memory::store_cs(val0_addr, vals[2 * i]); + memory::store_cs(val0_addr + tile_stride, vals[2 * i + 1]); + val0_addr += 4 * tile_stride; + } + } else { + auto pair_addr = gmem_out + TILES_PER_WARP * tile_stride * warp_id + tile_stride * (lane_id >> 3) + 2 * (threadIdx.x & 7); +#pragma unroll + for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) { + const uint4 out{vals[2 * i][0], vals[2 * i][1], vals[2 * i + 1][0], vals[2 * i + 1][1]}; + memory::store_cs(reinterpret_cast(pair_addr), out); + pair_addr += 4 * tile_stride; + } + } } } +extern "C" __launch_bounds__(512, 2) __global__ +void n2b_nonfinal_7_or_8_stages_block(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, + const unsigned stride_between_output_arrays, const unsigned start_stage, const unsigned stages_this_launch, + const unsigned log_n, const bool inverse, const unsigned num_ntts, const unsigned log_extension_degree, + const unsigned coset_idx) { + n2b_nonfinal_stages_block<3>(gmem_inputs_matrix, gmem_outputs_matrix, stride_between_input_arrays, stride_between_output_arrays, start_stage, + stages_this_launch == 7, log_n, inverse, num_ntts, log_extension_degree, coset_idx); +} + // Simple, non-optimized 
kernel used for log_n < 16, to unblock debugging small proofs. extern "C" __launch_bounds__(512, 2) __global__ void n2b_1_stage(const base_field *gmem_inputs_matrix, base_field *gmem_outputs_matrix, const unsigned stride_between_input_arrays, diff --git a/boojum-cuda/src/ntt.rs b/boojum-cuda/src/ntt.rs index e597ff8..f368f31 100644 --- a/boojum-cuda/src/ntt.rs +++ b/boojum-cuda/src/ntt.rs @@ -7,7 +7,7 @@ use cudart::kernel_args; use cudart::result::{CudaResult, CudaResultWrap}; use cudart::slice::DeviceSlice; use cudart::stream::CudaStream; -use cudart_sys::cudaLaunchKernel; +use cudart_sys::{cudaLaunchKernel, dim3}; use crate::context::OMEGA_LOG_ORDER; @@ -25,7 +25,8 @@ extern "C" { log_extension_degree: u32, coset_index: u32, ); - fn n2b_final_7_or_8_stages( + + fn n2b_final_7_stages_warp( inputs_matrix: *const GoldilocksField, outputs_matrix: *mut GoldilocksField, stride_between_input_arrays: u32, @@ -34,11 +35,12 @@ extern "C" { stages_this_launch: u32, log_n: u32, inverse: bool, - blocks_per_ntt: u32, + num_ntts: u32, log_extension_degree: u32, coset_index: u32, ); - fn n2b_final_9_to_12_stages( + + fn n2b_final_8_stages_warp( inputs_matrix: *const GoldilocksField, outputs_matrix: *mut GoldilocksField, stride_between_input_arrays: u32, @@ -47,11 +49,12 @@ extern "C" { stages_this_launch: u32, log_n: u32, inverse: bool, - blocks_per_ntt: u32, + num_ntts: u32, log_extension_degree: u32, coset_index: u32, ); - fn n2b_nonfinal_7_or_8_stages( + + fn n2b_final_9_to_12_stages_block( inputs_matrix: *const GoldilocksField, outputs_matrix: *mut GoldilocksField, stride_between_input_arrays: u32, @@ -60,10 +63,25 @@ extern "C" { stages_this_launch: u32, log_n: u32, inverse: bool, - blocks_per_ntt: u32, + num_ntts: u32, + log_extension_degree: u32, + coset_index: u32, + ); + + fn n2b_nonfinal_7_or_8_stages_block( + inputs_matrix: *const GoldilocksField, + outputs_matrix: *mut GoldilocksField, + stride_between_input_arrays: u32, + stride_between_output_arrays: u32, + start_stage: u32, + stages_this_launch: u32, + log_n: u32, + inverse: bool, + num_ntts: u32, log_extension_degree: u32, coset_index: u32, ); + fn b2n_1_stage( inputs_matrix: *const GoldilocksField, outputs_matrix: *mut GoldilocksField, @@ -77,7 +95,8 @@ extern "C" { log_extension_degree: u32, coset_index: u32, ); - fn b2n_initial_7_or_8_stages( + + fn b2n_initial_7_stages_warp( inputs_matrix: *const GoldilocksField, outputs_matrix: *mut GoldilocksField, stride_between_input_arrays: u32, @@ -86,11 +105,12 @@ extern "C" { stages_this_launch: u32, log_n: u32, inverse: bool, - blocks_per_ntt: u32, + num_ntts: u32, log_extension_degree: u32, coset_index: u32, ); - fn b2n_initial_9_to_12_stages( + + fn b2n_initial_8_stages_warp( inputs_matrix: *const GoldilocksField, outputs_matrix: *mut GoldilocksField, stride_between_input_arrays: u32, @@ -99,11 +119,12 @@ extern "C" { stages_this_launch: u32, log_n: u32, inverse: bool, - blocks_per_ntt: u32, + num_ntts: u32, log_extension_degree: u32, coset_index: u32, ); - fn b2n_noninitial_7_or_8_stages( + + fn b2n_initial_9_to_12_stages_block( inputs_matrix: *const GoldilocksField, outputs_matrix: *mut GoldilocksField, stride_between_input_arrays: u32, @@ -112,7 +133,21 @@ extern "C" { stages_this_launch: u32, log_n: u32, inverse: bool, - blocks_per_ntt: u32, + num_ntts: u32, + log_extension_degree: u32, + coset_index: u32, + ); + + fn b2n_noninitial_7_or_8_stages_block( + inputs_matrix: *const GoldilocksField, + outputs_matrix: *mut GoldilocksField, + stride_between_input_arrays: u32, + 
stride_between_output_arrays: u32, + start_stage: u32, + stages_this_launch: u32, + log_n: u32, + inverse: bool, + num_ntts: u32, log_extension_degree: u32, coset_index: u32, ); @@ -121,12 +156,14 @@ extern "C" { #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] enum KERN { - N2B_FINAL_7_OR_8(u32), - N2B_FINAL_9_TO_12(u32), - N2B_NONFINAL_7_OR_8(u32), - B2N_INITIAL_7_OR_8(u32), - B2N_INITIAL_9_TO_12(u32), - B2N_NONINITIAL_7_OR_8(u32), + N2B_FINAL_7_WARP(u32), + N2B_FINAL_8_WARP(u32), + N2B_FINAL_9_TO_12_BLOCK(u32), + N2B_NONFINAL_7_OR_8_BLOCK(u32), + B2N_INITIAL_7_WARP(u32), + B2N_INITIAL_8_WARP(u32), + B2N_INITIAL_9_TO_12_BLOCK(u32), + B2N_NONINITIAL_7_OR_8_BLOCK(u32), SKIP, } @@ -137,103 +174,103 @@ enum KERN { const PLANS: [[[KERN; 3]; 9]; 2] = [ [ [ - KERN::N2B_NONFINAL_7_OR_8(8), - KERN::N2B_FINAL_7_OR_8(8), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(8), + KERN::N2B_FINAL_8_WARP(8), KERN::SKIP, ], [ - KERN::N2B_NONFINAL_7_OR_8(8), - KERN::N2B_FINAL_9_TO_12(9), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(8), + KERN::N2B_FINAL_9_TO_12_BLOCK(9), KERN::SKIP, ], [ - KERN::N2B_NONFINAL_7_OR_8(8), - KERN::N2B_FINAL_9_TO_12(10), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(8), + KERN::N2B_FINAL_9_TO_12_BLOCK(10), KERN::SKIP, ], [ - KERN::N2B_NONFINAL_7_OR_8(8), - KERN::N2B_FINAL_9_TO_12(11), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(8), + KERN::N2B_FINAL_9_TO_12_BLOCK(11), KERN::SKIP, ], [ - KERN::N2B_NONFINAL_7_OR_8(8), - KERN::N2B_FINAL_9_TO_12(12), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(8), + KERN::N2B_FINAL_9_TO_12_BLOCK(12), KERN::SKIP, ], [ - KERN::N2B_NONFINAL_7_OR_8(7), - KERN::N2B_NONFINAL_7_OR_8(7), - KERN::N2B_FINAL_7_OR_8(7), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(7), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(7), + KERN::N2B_FINAL_7_WARP(7), ], [ - KERN::N2B_NONFINAL_7_OR_8(7), - KERN::N2B_NONFINAL_7_OR_8(7), - KERN::N2B_FINAL_7_OR_8(8), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(7), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(7), + KERN::N2B_FINAL_8_WARP(8), ], [ - KERN::N2B_NONFINAL_7_OR_8(7), - KERN::N2B_NONFINAL_7_OR_8(8), - KERN::N2B_FINAL_7_OR_8(8), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(7), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(8), + KERN::N2B_FINAL_8_WARP(8), ], [ - KERN::N2B_NONFINAL_7_OR_8(8), - KERN::N2B_NONFINAL_7_OR_8(8), - KERN::N2B_FINAL_7_OR_8(8), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(8), + KERN::N2B_NONFINAL_7_OR_8_BLOCK(8), + KERN::N2B_FINAL_8_WARP(8), ], ], [ [ - KERN::B2N_INITIAL_7_OR_8(8), - KERN::B2N_NONINITIAL_7_OR_8(8), + KERN::B2N_INITIAL_8_WARP(8), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(8), KERN::SKIP, ], [ - KERN::B2N_INITIAL_9_TO_12(9), - KERN::B2N_NONINITIAL_7_OR_8(8), + KERN::B2N_INITIAL_9_TO_12_BLOCK(9), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(8), KERN::SKIP, ], [ - KERN::B2N_INITIAL_9_TO_12(10), - KERN::B2N_NONINITIAL_7_OR_8(8), + KERN::B2N_INITIAL_9_TO_12_BLOCK(10), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(8), KERN::SKIP, ], [ - KERN::B2N_INITIAL_9_TO_12(11), - KERN::B2N_NONINITIAL_7_OR_8(8), + KERN::B2N_INITIAL_9_TO_12_BLOCK(11), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(8), KERN::SKIP, ], [ - KERN::B2N_INITIAL_9_TO_12(12), - KERN::B2N_NONINITIAL_7_OR_8(8), + KERN::B2N_INITIAL_9_TO_12_BLOCK(12), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(8), KERN::SKIP, ], [ - KERN::B2N_INITIAL_7_OR_8(7), - KERN::B2N_NONINITIAL_7_OR_8(7), - KERN::B2N_NONINITIAL_7_OR_8(7), + KERN::B2N_INITIAL_7_WARP(7), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(7), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(7), ], [ - KERN::B2N_INITIAL_7_OR_8(8), - KERN::B2N_NONINITIAL_7_OR_8(7), - KERN::B2N_NONINITIAL_7_OR_8(7), + KERN::B2N_INITIAL_8_WARP(8), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(7), + 
KERN::B2N_NONINITIAL_7_OR_8_BLOCK(7), ], [ - KERN::B2N_INITIAL_7_OR_8(8), - KERN::B2N_NONINITIAL_7_OR_8(8), - KERN::B2N_NONINITIAL_7_OR_8(7), + KERN::B2N_INITIAL_8_WARP(8), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(8), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(7), ], [ - KERN::B2N_INITIAL_7_OR_8(8), - KERN::B2N_NONINITIAL_7_OR_8(8), - KERN::B2N_NONINITIAL_7_OR_8(8), + KERN::B2N_INITIAL_8_WARP(8), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(8), + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(8), ], ], ]; #[allow(clippy::too_many_arguments)] fn launch( - nblocks: u32, + nblocks: dim3, nthreads: u32, smem_bytes: usize, stream: &CudaStream, @@ -276,6 +313,11 @@ fn launch( } } +// Each block strides across up to at most NTTS_PER_BLOCK ntts in a batch. +// This is just to boost occupancy and reduce tail effect for larger batches. +// The value here must match NTTS_PER_BLOCK in native/ntt.cu. +const NTTS_PER_BLOCK: u32 = 8; + // Carries out LDE for all cosets in a single launch, which improves saturation for smaller sizes. // results must contain 2^log_extension_degree DeviceAllocationSlices, to hold all the output cosets. #[allow(clippy::too_many_arguments)] @@ -292,7 +334,6 @@ pub fn batch_ntt_internal( coset_index: u32, stream: &CudaStream, ) -> CudaResult<()> { - const PADDED_WARP_SCRATCH_SIZE: usize = (256 / 16) * 17 + 1; assert!(log_n >= 1); assert!((log_n + log_extension_degree) <= OMEGA_LOG_ORDER); let n = 1 << log_n; @@ -326,7 +367,7 @@ pub fn batch_ntt_internal( stride_between_output_arrays }; launch( - blocks, + blocks.into(), threads, 0, stream, @@ -350,11 +391,6 @@ pub fn batch_ntt_internal( let mut stage: u32 = 0; for kernel in plan { let start_stage = stage; - // grid and block size for smem kernels - let blocks_per_ntt_smem: u32 = n / 4096; - let nthreads_smem: u32 = 512; - let smem_bytes: usize = (512 / 32) * PADDED_WARP_SCRATCH_SIZE * 8; - let total_blocks_smem: u32 = blocks_per_ntt_smem * num_ntts; // Raw input pointers let inputs_ptr = if stage == 0 { inputs_ptr_in @@ -366,15 +402,16 @@ pub fn batch_ntt_internal( } else { stride_between_output_arrays }; + let num_chunks = (num_ntts + NTTS_PER_BLOCK - 1) / NTTS_PER_BLOCK; match kernel { - KERN::N2B_FINAL_7_OR_8(stages) => { + KERN::N2B_FINAL_7_WARP(stages) => { stage += stages; launch( - total_blocks_smem, - nthreads_smem, - smem_bytes, + (n / (4 * 128), num_chunks).into(), + 128, + 0, stream, - n2b_final_7_or_8_stages as *const c_void, + n2b_final_7_stages_warp as *const c_void, inputs_ptr, outputs_ptr, stride_between_input_arrays, @@ -383,19 +420,19 @@ pub fn batch_ntt_internal( *stages, log_n, inverse, - blocks_per_ntt_smem, + num_ntts, log_extension_degree, coset_index, ) } - KERN::N2B_FINAL_9_TO_12(stages) => { + KERN::N2B_FINAL_8_WARP(stages) => { stage += stages; launch( - total_blocks_smem, - nthreads_smem, - smem_bytes, + (n / (4 * 256), num_chunks).into(), + 128, + 0, stream, - n2b_final_9_to_12_stages as *const c_void, + n2b_final_8_stages_warp as *const c_void, inputs_ptr, outputs_ptr, stride_between_input_arrays, @@ -404,19 +441,19 @@ pub fn batch_ntt_internal( *stages, log_n, inverse, - blocks_per_ntt_smem, + num_ntts, log_extension_degree, coset_index, ) } - KERN::N2B_NONFINAL_7_OR_8(stages) => { + KERN::N2B_FINAL_9_TO_12_BLOCK(stages) => { stage += stages; launch( - total_blocks_smem, - nthreads_smem, - smem_bytes, + (n / 4096, num_chunks).into(), + 512, + 0, stream, - n2b_nonfinal_7_or_8_stages as *const c_void, + n2b_final_9_to_12_stages_block as *const c_void, inputs_ptr, outputs_ptr, stride_between_input_arrays, @@ -425,19 +462,40 @@ pub 
fn batch_ntt_internal( *stages, log_n, inverse, - blocks_per_ntt_smem, + num_ntts, log_extension_degree, coset_index, ) } - KERN::B2N_INITIAL_7_OR_8(stages) => { + KERN::N2B_NONFINAL_7_OR_8_BLOCK(stages) => { stage += stages; launch( - n / (4 * 256), + (n / 4096, num_chunks).into(), + 512, + 0, + stream, + n2b_nonfinal_7_or_8_stages_block as *const c_void, + inputs_ptr, + outputs_ptr, + stride_between_input_arrays, + stride_between_output_arrays, + start_stage, + *stages, + log_n, + inverse, + num_ntts, + log_extension_degree, + coset_index, + ) + } + KERN::B2N_INITIAL_7_WARP(stages) => { + stage += stages; + launch( + (n / (4 * 128), num_chunks).into(), 128, 0, stream, - b2n_initial_7_or_8_stages as *const c_void, + b2n_initial_7_stages_warp as *const c_void, inputs_ptr, outputs_ptr, stride_between_input_arrays, @@ -446,19 +504,19 @@ pub fn batch_ntt_internal( *stages, log_n, inverse, - blocks_per_ntt_smem, + num_ntts, log_extension_degree, coset_index, ) } - KERN::B2N_INITIAL_9_TO_12(stages) => { + KERN::B2N_INITIAL_8_WARP(stages) => { stage += stages; launch( - total_blocks_smem, - nthreads_smem, - smem_bytes, + (n / (4 * 256), num_chunks).into(), + 128, + 0, stream, - b2n_initial_9_to_12_stages as *const c_void, + b2n_initial_8_stages_warp as *const c_void, inputs_ptr, outputs_ptr, stride_between_input_arrays, @@ -467,19 +525,40 @@ pub fn batch_ntt_internal( *stages, log_n, inverse, - blocks_per_ntt_smem, + num_ntts, log_extension_degree, coset_index, ) } - KERN::B2N_NONINITIAL_7_OR_8(stages) => { + KERN::B2N_INITIAL_9_TO_12_BLOCK(stages) => { stage += stages; launch( - total_blocks_smem, - nthreads_smem, - smem_bytes, + (n / 4096, num_chunks).into(), + 512, + 0, + stream, + b2n_initial_9_to_12_stages_block as *const c_void, + inputs_ptr, + outputs_ptr, + stride_between_input_arrays, + stride_between_output_arrays, + start_stage, + *stages, + log_n, + inverse, + num_ntts, + log_extension_degree, + coset_index, + ) + } + KERN::B2N_NONINITIAL_7_OR_8_BLOCK(stages) => { + stage += stages; + launch( + (n / 4096, num_chunks).into(), + 512, + 0, stream, - b2n_noninitial_7_or_8_stages as *const c_void, + b2n_noninitial_7_or_8_stages_block as *const c_void, inputs_ptr, outputs_ptr, stride_between_input_arrays, @@ -488,7 +567,7 @@ pub fn batch_ntt_internal( *stages, log_n, inverse, - blocks_per_ntt_smem, + num_ntts, log_extension_degree, coset_index, ) @@ -696,7 +775,7 @@ mod tests { memory_copy_async(&mut outputs_n2b_out_of_place_host, &outputs_device, &stream) .unwrap(); - // // Bitrev to nonbitrev, in-place + // Bitrev to nonbitrev, in-place memory_copy_async(&mut inputs_device, &inputs_bitrev_host, &stream).unwrap(); batch_ntt_in_place( &mut inputs_device, @@ -825,7 +904,7 @@ mod tests { #[test] #[serial] fn correctness_batch_ntt_fwd() { - correctness(24..25, false, 0, 0, 1); + correctness(1..17, false, 0, 0, 2 * NTTS_PER_BLOCK + 3); } #[test]
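// Hedged sketch (host-side model, not project code): the batching change above hands each block a
// blockIdx.y chunk of up to NTTS_PER_BLOCK ntts. The Rust launcher sizes the grid's y dimension as
// ceil(num_ntts / NTTS_PER_BLOCK) and the kernels clamp their per-block loop with
// bound = min(NTTS_PER_BLOCK, num_ntts - NTTS_PER_BLOCK * blockIdx.y). The check below confirms
// that, for a range of batch sizes, every ntt in the batch is visited by exactly one y-chunk.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  constexpr unsigned NTTS_PER_BLOCK = 8; // must match native/ntt.cu, as the Rust comment notes
  for (unsigned num_ntts = 1; num_ntts <= 3 * NTTS_PER_BLOCK + 5; num_ntts++) {
    const unsigned num_chunks = (num_ntts + NTTS_PER_BLOCK - 1) / NTTS_PER_BLOCK; // gridDim.y
    std::vector<unsigned> visits(num_ntts, 0);
    for (unsigned block_y = 0; block_y < num_chunks; block_y++) {
      const unsigned bound = std::min(NTTS_PER_BLOCK, num_ntts - NTTS_PER_BLOCK * block_y);
      for (unsigned ntt_idx = 0; ntt_idx < bound; ntt_idx++)
        visits[NTTS_PER_BLOCK * block_y + ntt_idx]++;
    }
    for (unsigned v : visits)
      assert(v == 1); // each ntt handled by exactly one block in the y dimension
  }
  return 0;
}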