Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
Supporting diffs for batched NTT chunking and twiddle persistence (#25)
Browse files Browse the repository at this point in the history
Required by matter-labs/era-shivini#31

## Checklist

- [x] PR title corresponds to the body of the PR (we generate changelog
entries from PRs).
- [x] Tests for the changes have been added / updated.
- [x] Documentation comments have been added / updated.
- [x] Code has been formatted via `cargo fmt` and linted via `cargo clippy`.
  • Loading branch information
mcarilli authored Feb 13, 2024
1 parent 97169a9 commit 34f3a51
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 44 deletions.
34 changes: 17 additions & 17 deletions native/ntt_b2n.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ DEVICE_FORCEINLINE void b2n_initial_stages_warp(const base_field *gmem_inputs_ma
for (unsigned ntt_idx = 0; ntt_idx < bound; ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) {
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
const auto in = memory::load_cs(reinterpret_cast<const uint4 *>(gmem_in + 64 * i + 2 * lane_id));
const auto in = memory::load_cg(reinterpret_cast<const uint4 *>(gmem_in + 64 * i + 2 * lane_id));
vals[2 * i][0] = in.x;
vals[2 * i][1] = in.y;
vals[2 * i + 1][0] = in.z;
Expand Down Expand Up @@ -89,8 +89,8 @@ DEVICE_FORCEINLINE void b2n_initial_stages_warp(const base_field *gmem_inputs_ma
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
// This output pattern (resulting from the above shfls) is nice, but not obvious.
// To see why it works, sketch the shfl stages on paper.
memory::store_cs(gmem_out + 64 * i + lane_id, vals[2 * i]);
memory::store_cs(gmem_out + 64 * i + lane_id + 32, vals[2 * i + 1]);
memory::store_cg(gmem_out + 64 * i + lane_id, vals[2 * i]);
memory::store_cg(gmem_out + 64 * i + lane_id + 32, vals[2 * i + 1]);
}
}
}
Expand Down Expand Up @@ -146,7 +146,7 @@ DEVICE_FORCEINLINE void b2n_initial_stages_block(const base_field *gmem_inputs_m
for (unsigned ntt_idx = 0; ntt_idx < bound; ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) {
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
const auto pair = memory::load_cs(reinterpret_cast<const uint4 *>(gmem_in + 64 * i + 2 * lane_id));
const auto pair = memory::load_cg(reinterpret_cast<const uint4 *>(gmem_in + 64 * i + 2 * lane_id));
vals[2 * i][0] = pair.x;
vals[2 * i][1] = pair.y;
vals[2 * i + 1][0] = pair.z;
Expand Down Expand Up @@ -306,8 +306,8 @@ DEVICE_FORCEINLINE void b2n_initial_stages_block(const base_field *gmem_inputs_m

#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
memory::store_cs(gmem_out + 4 * i * VALS_PER_WARP, vals[2 * i]);
memory::store_cs(gmem_out + (4 * i + 2) * VALS_PER_WARP, vals[2 * i + 1]);
memory::store_cg(gmem_out + 4 * i * VALS_PER_WARP, vals[2 * i]);
memory::store_cg(gmem_out + (4 * i + 2) * VALS_PER_WARP, vals[2 * i + 1]);
}
}
}
Expand Down Expand Up @@ -373,15 +373,15 @@ DEVICE_FORCEINLINE void b2n_noninitial_stages_block(const base_field *gmem_input
auto val0_addr = gmem_in + TILES_PER_WARP * tile_stride * warp_id + 2 * tile_stride * (lane_id >> 4) + 2 * (threadIdx.x & 7) + (lane_id >> 3 & 1);
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
vals[2 * i] = memory::load_cs(val0_addr);
vals[2 * i + 1] = memory::load_cs(val0_addr + tile_stride);
vals[2 * i] = memory::load_cg(val0_addr);
vals[2 * i + 1] = memory::load_cg(val0_addr + tile_stride);
val0_addr += 4 * tile_stride;
}
} else {
auto pair_addr = gmem_in + TILES_PER_WARP * tile_stride * warp_id + tile_stride * (lane_id >> 3) + 2 * (threadIdx.x & 7);
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
const auto pair = memory::load_cs(reinterpret_cast<const uint4 *>(pair_addr));
const auto pair = memory::load_cg(reinterpret_cast<const uint4 *>(pair_addr));
vals[2 * i][0] = pair.x;
vals[2 * i][1] = pair.y;
vals[2 * i + 1][0] = pair.z;
Expand Down Expand Up @@ -512,16 +512,16 @@ DEVICE_FORCEINLINE void b2n_noninitial_stages_block(const base_field *gmem_input
}
auto power_of_g0 = get_power_of_g(idx0, true);
auto power_of_g1 = get_power_of_g(idx1, true);
memory::store_cs(gmem_out - gmem_out_offset + idx0, base_field::mul(val0, power_of_g0));
memory::store_cs(gmem_out - gmem_out_offset + idx1, base_field::mul(val1, power_of_g1));
memory::store_cg(gmem_out - gmem_out_offset + idx0, base_field::mul(val0, power_of_g0));
memory::store_cg(gmem_out - gmem_out_offset + idx1, base_field::mul(val1, power_of_g1));
}
*scratch = tmp;
__syncwarp();
} else {
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
memory::store_cs(gmem_out + 4 * i * tile_stride * WARPS_PER_BLOCK, vals[2 * i]);
memory::store_cs(gmem_out + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK, vals[2 * i + 1]);
memory::store_cg(gmem_out + 4 * i * tile_stride * WARPS_PER_BLOCK, vals[2 * i]);
memory::store_cg(gmem_out + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK, vals[2 * i + 1]);
}
}
}
Expand Down Expand Up @@ -556,8 +556,8 @@ extern "C" __launch_bounds__(512, 2) __global__
base_field *gmem_output = gmem_outputs_matrix + ntt_idx * stride_between_output_arrays;

const auto twiddle = get_twiddle(inverse, exchg_region);
auto a = memory::load_cs(gmem_input + a_idx);
auto b = memory::load_cs(gmem_input + b_idx);
auto a = memory::load_cg(gmem_input + a_idx);
auto b = memory::load_cg(gmem_input + b_idx);

if ((start_stage == 0) && log_extension_degree && !inverse) {
const unsigned a_idx_brev = __brev(a_idx) >> (32 - log_n);
Expand Down Expand Up @@ -589,6 +589,6 @@ extern "C" __launch_bounds__(512, 2) __global__
}
}

memory::store_cs(gmem_output + a_idx, a);
memory::store_cs(gmem_output + b_idx, b);
memory::store_cg(gmem_output + a_idx, a);
memory::store_cg(gmem_output + b_idx, b);
}
34 changes: 17 additions & 17 deletions native/ntt_n2b.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ DEVICE_FORCEINLINE void n2b_final_stages_warp(const base_field *gmem_inputs_matr
for (unsigned ntt_idx = 0; ntt_idx < bound; ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) {
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
vals[2 * i] = memory::load_cs(gmem_in + 64 * i + lane_id);
vals[2 * i + 1] = memory::load_cs(gmem_in + 64 * i + lane_id + 32);
vals[2 * i] = memory::load_cg(gmem_in + 64 * i + lane_id);
vals[2 * i + 1] = memory::load_cg(gmem_in + 64 * i + lane_id + 32);
}

base_field *twiddles_this_stage = twiddle_cache + VALS_PER_WARP - 2;
Expand Down Expand Up @@ -84,7 +84,7 @@ DEVICE_FORCEINLINE void n2b_final_stages_warp(const base_field *gmem_inputs_matr
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
const uint4 out{scratch[2 * i][0], scratch[2 * i][1], scratch[2 * i + 1][0], scratch[2 * i + 1][1]};
memory::store_cs(reinterpret_cast<uint4 *>(gmem_out + 64 * i + 2 * lane_id), out);
memory::store_cg(reinterpret_cast<uint4 *>(gmem_out + 64 * i + 2 * lane_id), out);
}
#pragma unroll
for (unsigned i = 0; i < VALS_PER_THREAD; i++)
Expand All @@ -94,7 +94,7 @@ DEVICE_FORCEINLINE void n2b_final_stages_warp(const base_field *gmem_inputs_matr
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
const uint4 out{vals[2 * i][0], vals[2 * i][1], vals[2 * i + 1][0], vals[2 * i + 1][1]};
memory::store_cs(reinterpret_cast<uint4 *>(gmem_out + 64 * i + 2 * lane_id), out);
memory::store_cg(reinterpret_cast<uint4 *>(gmem_out + 64 * i + 2 * lane_id), out);
}
}
}
Expand Down Expand Up @@ -150,8 +150,8 @@ DEVICE_FORCEINLINE void n2b_final_stages_block(const base_field *gmem_inputs_mat
for (unsigned ntt_idx = 0; ntt_idx < bound; ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) {
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
vals[2 * i] = memory::load_cs(gmem_in + 4 * i * VALS_PER_WARP);
vals[2 * i + 1] = memory::load_cs(gmem_in + (4 * i + 2) * VALS_PER_WARP);
vals[2 * i] = memory::load_cg(gmem_in + 4 * i * VALS_PER_WARP);
vals[2 * i + 1] = memory::load_cg(gmem_in + (4 * i + 2) * VALS_PER_WARP);
}

const unsigned stages_to_skip = MAX_STAGES_THIS_LAUNCH - stages_this_launch;
Expand Down Expand Up @@ -295,7 +295,7 @@ DEVICE_FORCEINLINE void n2b_final_stages_block(const base_field *gmem_inputs_mat
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
const uint4 out{scratch[2 * i][0], scratch[2 * i][1], scratch[2 * i + 1][0], scratch[2 * i + 1][1]};
memory::store_cs(reinterpret_cast<uint4 *>(gmem_out + 64 * i + 2 * lane_id), out);
memory::store_cg(reinterpret_cast<uint4 *>(gmem_out + 64 * i + 2 * lane_id), out);
}
#pragma unroll
for (unsigned i = 0; i < VALS_PER_THREAD; i++)
Expand All @@ -305,7 +305,7 @@ DEVICE_FORCEINLINE void n2b_final_stages_block(const base_field *gmem_inputs_mat
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
const uint4 out{vals[2 * i][0], vals[2 * i][1], vals[2 * i + 1][0], vals[2 * i + 1][1]};
memory::store_cs(reinterpret_cast<uint4 *>(gmem_out + 64 * i + 2 * lane_id), out);
memory::store_cg(reinterpret_cast<uint4 *>(gmem_out + 64 * i + 2 * lane_id), out);
}
}
}
Expand Down Expand Up @@ -362,8 +362,8 @@ DEVICE_FORCEINLINE void n2b_nonfinal_stages_block(const base_field *gmem_inputs_
for (unsigned ntt_idx = 0; ntt_idx < bound; ntt_idx++, gmem_in += stride_between_input_arrays, gmem_out += stride_between_output_arrays) {
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
vals[2 * i] = memory::load_cs(gmem_in + 4 * i * tile_stride * WARPS_PER_BLOCK);
vals[2 * i + 1] = memory::load_cs(gmem_in + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK);
vals[2 * i] = memory::load_cg(gmem_in + 4 * i * tile_stride * WARPS_PER_BLOCK);
vals[2 * i + 1] = memory::load_cg(gmem_in + (4 * i + 2) * tile_stride * WARPS_PER_BLOCK);
}

if ((start_stage == 0) && log_extension_degree && !inverse) {
Expand Down Expand Up @@ -500,16 +500,16 @@ DEVICE_FORCEINLINE void n2b_nonfinal_stages_block(const base_field *gmem_inputs_
auto val0_addr = gmem_out + TILES_PER_WARP * tile_stride * warp_id + 2 * tile_stride * (lane_id >> 4) + 2 * (threadIdx.x & 7) + (lane_id >> 3 & 1);
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
memory::store_cs(val0_addr, vals[2 * i]);
memory::store_cs(val0_addr + tile_stride, vals[2 * i + 1]);
memory::store_cg(val0_addr, vals[2 * i]);
memory::store_cg(val0_addr + tile_stride, vals[2 * i + 1]);
val0_addr += 4 * tile_stride;
}
} else {
auto pair_addr = gmem_out + TILES_PER_WARP * tile_stride * warp_id + tile_stride * (lane_id >> 3) + 2 * (threadIdx.x & 7);
#pragma unroll
for (unsigned i = 0; i < PAIRS_PER_THREAD; i++) {
const uint4 out{vals[2 * i][0], vals[2 * i][1], vals[2 * i + 1][0], vals[2 * i + 1][1]};
memory::store_cs(reinterpret_cast<uint4 *>(pair_addr), out);
memory::store_cg(reinterpret_cast<uint4 *>(pair_addr), out);
pair_addr += 4 * tile_stride;
}
}
Expand Down Expand Up @@ -545,8 +545,8 @@ extern "C" __launch_bounds__(512, 2) __global__
base_field *gmem_output = gmem_outputs_matrix + ntt_idx * stride_between_output_arrays;

const auto twiddle = get_twiddle(inverse, exchg_region);
auto a = memory::load_cs(gmem_input + a_idx);
auto b = memory::load_cs(gmem_input + b_idx);
auto a = memory::load_cg(gmem_input + a_idx);
auto b = memory::load_cg(gmem_input + b_idx);

if ((start_stage == 0) && log_extension_degree && !inverse) {
if (coset_idx) {
Expand Down Expand Up @@ -578,6 +578,6 @@ extern "C" __launch_bounds__(512, 2) __global__
}
}

memory::store_cs(gmem_output + a_idx, a);
memory::store_cs(gmem_output + b_idx, b);
memory::store_cg(gmem_output + a_idx, a);
memory::store_cg(gmem_output + b_idx, b);
}
20 changes: 10 additions & 10 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,16 @@ fn generate_powers_dev(
}

pub struct Context {
powers_of_w_fine: DeviceAllocation<GoldilocksField>,
powers_of_w_coarse: DeviceAllocation<GoldilocksField>,
powers_of_w_fine_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
powers_of_w_coarse_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
powers_of_w_inv_fine_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
powers_of_w_inv_coarse_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
powers_of_g_f_fine: DeviceAllocation<GoldilocksField>,
powers_of_g_f_coarse: DeviceAllocation<GoldilocksField>,
powers_of_g_i_fine: DeviceAllocation<GoldilocksField>,
powers_of_g_i_coarse: DeviceAllocation<GoldilocksField>,
pub powers_of_w_fine: DeviceAllocation<GoldilocksField>,
pub powers_of_w_coarse: DeviceAllocation<GoldilocksField>,
pub powers_of_w_fine_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
pub powers_of_w_coarse_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
pub powers_of_w_inv_fine_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
pub powers_of_w_inv_coarse_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
pub powers_of_g_f_fine: DeviceAllocation<GoldilocksField>,
pub powers_of_g_f_coarse: DeviceAllocation<GoldilocksField>,
pub powers_of_g_i_fine: DeviceAllocation<GoldilocksField>,
pub powers_of_g_i_coarse: DeviceAllocation<GoldilocksField>,
}

impl Context {
Expand Down

0 comments on commit 34f3a51

Please sign in to comment.