This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

CUDA execution refactoring (#21)
# What ❔

This PR cuts down the amount of boilerplate code related to launching CUDA kernels. It depends on the functionality introduced in matter-labs/era-cuda#2.

## Why ❔

Less code, especially unsafe code, is a win.
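
For illustration, the pattern being removed versus the pattern being adopted might look like the following. This is a hypothetical sketch: `cuda_kernel!` and `gather_rows_kernel` are made-up names standing in for the era-cuda helpers from matter-labs/era-cuda#2, whose real API may differ.

```rust
// Before: each kernel repeats an unsafe extern declaration, and every call
// site hand-rolls an unsafe launch with grid/block dims and raw pointers.
// extern "C" {
//     fn gather_rows_kernel(src: *const u64, dst: *mut u64, count: u32);
// }

// After: one declarative macro invocation per kernel generates the binding
// (and, in the real helpers, a typed launch wrapper), keeping call sites
// free of unsafe code.
macro_rules! cuda_kernel {
    ($name:ident($($arg:ident : $ty:ty),* $(,)?)) => {
        extern "C" {
            fn $name($($arg: $ty),*);
        }
        // A real implementation would also emit a safe wrapper taking a
        // stream and launch dimensions and checking the CUDA error code;
        // omitted to keep the sketch short.
    };
}

cuda_kernel!(gather_rows_kernel(src: *const u64, dst: *mut u64, count: u32));
```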

## Checklist

- [x] PR title corresponds to the body of the PR (we generate changelog entries from PRs).
- [x] Tests for the changes have been added / updated.
- [x] Documentation comments have been added / updated.
- [x] Code has been formatted via `cargo fmt` and checked for warnings with `cargo clippy`.
robik75 authored Dec 12, 2023
1 parent b42966a commit 9df96d3
Showing 19 changed files with 2,086 additions and 3,495 deletions.
2 changes: 1 addition & 1 deletion benches/poseidon.rs
@@ -171,7 +171,7 @@ fn poseidon2_cooperative_nodes(c: &mut Criterion<CudaMeasurement>) {
     );
 }
 
-fn merkle_tree<PoseidonVariant: PoseidonRunnable>(
+fn merkle_tree<PoseidonVariant: PoseidonImpl>(
     c: &mut Criterion<CudaMeasurement>,
     group_name: String,
 ) {
3 changes: 1 addition & 2 deletions build/gates.rs
@@ -195,8 +195,7 @@ fn generate_rust(descriptions: &[Description]) {
         h.push_str(format!("(\"{name}\", {id}),").as_str());
         new_line(h);
         let kernel_name = format!("evaluate_{name}_kernel");
-        indent(b, 1);
-        b.push_str(format!("kernel_binding!({kernel_name});").as_str());
+        b.push_str(format!("gate_eval_kernel!({kernel_name});").as_str());
         new_line(b);
         new_line(m);
         indent(m, 2);
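For context, `generate_rust` writes Rust source at build time. After this change, each gate contributes a single `gate_eval_kernel!` invocation to the generated bindings, one line instead of an indented `kernel_binding!` line. A plausible slice of that output, with illustrative gate names:

```rust
// Sketch of generated output; actual gate names come from the circuit
// descriptions processed at build time.
gate_eval_kernel!(evaluate_fma_gate_kernel);
gate_eval_kernel!(evaluate_selection_gate_kernel);
```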
76 changes: 35 additions & 41 deletions native/barycentric.cu
@@ -4,42 +4,42 @@ namespace goldilocks {
 
 using namespace memory;
 
+using bf = base_field;
+using ef = extension_field;
+
 // Helper functions to compute common_factor for precompute_lagrange_coeffs
 template <typename COEF_T>
-DEVICE_FORCEINLINE void barycentric_precompute_common_factor_impl(const COEF_T *x_ref, COEF_T *common_factor_ref, const base_field coset,
-                                                                  const unsigned count) {
+DEVICE_FORCEINLINE void barycentric_precompute_common_factor_impl(const COEF_T *x_ref, COEF_T *common_factor_ref, const bf coset, const unsigned count) {
   // common_factor = coset * (X^N - coset^N) / (N * coset^N)
   // some math could be done on the CPU, but this is a 1-thread kernel so hopefully negligible
   const auto x = *x_ref;
-  const auto cosetN = base_field::pow(coset, count);
+  const auto cosetN = bf::pow(coset, count);
   const auto xN = COEF_T::pow(x, count);
   const auto num = COEF_T::mul(COEF_T::sub(xN, cosetN), coset);
-  const auto denom = base_field::mul({count, 0}, cosetN);
-  const auto common_factor = COEF_T::mul(num, base_field::inv(denom));
+  const auto denom = bf::mul({count, 0}, cosetN);
+  const auto common_factor = COEF_T::mul(num, bf::inv(denom));
   *common_factor_ref = common_factor;
 }
 
-EXTERN __global__ void barycentric_precompute_common_factor_at_base_kernel(const base_field *x_ref, base_field *common_factor_ref, const base_field coset,
-                                                                           const unsigned count) {
+EXTERN __global__ void barycentric_precompute_common_factor_bf_kernel(const bf *x_ref, bf *common_factor_ref, const bf coset, const unsigned count) {
   barycentric_precompute_common_factor_impl(x_ref, common_factor_ref, coset, count);
 }
 
-EXTERN __global__ void barycentric_precompute_common_factor_at_ext_kernel(const extension_field *x_ref, extension_field *common_factor_ref,
-                                                                          const base_field coset, const unsigned count) {
+EXTERN __global__ void barycentric_precompute_common_factor_ef_kernel(const ef *x_ref, ef *common_factor_ref, const bf coset, const unsigned count) {
   barycentric_precompute_common_factor_impl(x_ref, common_factor_ref, coset, count);
 }
 
 template <typename T> struct InvBatch {};
-template <> struct InvBatch<base_field> {
+template <> struct InvBatch<bf> {
   static constexpr unsigned INV_BATCH = 10;
 };
-template <> struct InvBatch<extension_field> {
+template <> struct InvBatch<ef> {
   static constexpr unsigned INV_BATCH = 6;
 };
 
 template <typename COEF_T, typename COEF_SETTER_T>
-DEVICE_FORCEINLINE void barycentric_precompute_lagrange_coeffs_impl(const COEF_T *x_ref, const COEF_T *common_factor_ref, const base_field w_inv_step,
-                                                                    const base_field coset, COEF_SETTER_T lagrange_coeffs, const unsigned log_count) {
+DEVICE_FORCEINLINE void barycentric_precompute_lagrange_coeffs_impl(const COEF_T *x_ref, const COEF_T *common_factor_ref, const bf w_inv_step, const bf coset,
+                                                                    COEF_SETTER_T lagrange_coeffs, const unsigned log_count) {
   constexpr unsigned INV_BATCH = InvBatch<COEF_T>::INV_BATCH;
 
   // per_elem_factor = w^i / (X - coset * w^i)
@@ -68,7 +68,7 @@ DEVICE_FORCEINLINE void barycentric_precompute_lagrange_coeffs_impl(const COEF_T
     if (g < count) {
       per_elem_factor_invs[i] = COEF_T::sub(COEF_T::mul(x, w_inv), coset);
       if (g + grid_size < count)
-        w_inv = base_field::mul(w_inv, w_inv_step);
+        w_inv = bf::mul(w_inv, w_inv_step);
       runtime_batch_size++;
     }
@@ -84,25 +84,21 @@ DEVICE_FORCEINLINE void barycentric_precompute_lagrange_coeffs_impl(const COEF_T
     lagrange_coeffs.set(g, COEF_T::mul(per_elem_factors[i], common_factor));
 }
 
-EXTERN __global__ void barycentric_precompute_lagrange_coeffs_at_base_kernel(const base_field *x_ref, const base_field *common_factor_ref,
-                                                                             const base_field w_inv_step, const base_field coset,
-                                                                             vector_setter<base_field, st_modifier::cs> lagrange_coeffs,
-                                                                             const unsigned log_count) {
-  barycentric_precompute_lagrange_coeffs_impl<base_field>(x_ref, common_factor_ref, w_inv_step, coset, lagrange_coeffs, log_count);
+EXTERN __global__ void barycentric_precompute_lagrange_coeffs_bf_kernel(const bf *x_ref, const bf *common_factor_ref, const bf w_inv_step, const bf coset,
+                                                                        vector_setter<bf, st_modifier::cs> lagrange_coeffs, const unsigned log_count) {
+  barycentric_precompute_lagrange_coeffs_impl<bf>(x_ref, common_factor_ref, w_inv_step, coset, lagrange_coeffs, log_count);
 }
 
-EXTERN __global__ void barycentric_precompute_lagrange_coeffs_at_ext_kernel(const extension_field *x_ref, const extension_field *common_factor_ref,
-                                                                            const base_field w_inv_step, const base_field coset,
-                                                                            ef_double_vector_setter<st_modifier::cs> lagrange_coeffs,
-                                                                            const unsigned log_count) {
-  barycentric_precompute_lagrange_coeffs_impl<extension_field>(x_ref, common_factor_ref, w_inv_step, coset, lagrange_coeffs, log_count);
+EXTERN __global__ void barycentric_precompute_lagrange_coeffs_ef_kernel(const ef *x_ref, const ef *common_factor_ref, const bf w_inv_step, const bf coset,
+                                                                        ef_double_vector_setter<st_modifier::cs> lagrange_coeffs, const unsigned log_count) {
+  barycentric_precompute_lagrange_coeffs_impl<ef>(x_ref, common_factor_ref, w_inv_step, coset, lagrange_coeffs, log_count);
 }
 
 template <typename T> struct ElemsPerThread {};
-template <> struct ElemsPerThread<base_field> {
+template <> struct ElemsPerThread<bf> {
   static constexpr unsigned ELEMS_PER_THREAD = 12;
 };
-template <> struct ElemsPerThread<extension_field> {
+template <> struct ElemsPerThread<ef> {
   static constexpr unsigned ELEMS_PER_THREAD = 6;
 };
 
@@ -141,25 +137,23 @@ DEVICE_FORCEINLINE void batch_barycentric_partial_reduce_impl(const EVAL_GETTER_
 
 // We could also potentially do these with some custom functors passed to cub segmented reduce
 EXTERN __launch_bounds__(1024, 1) __global__
-    void batch_barycentric_partial_reduce_base_at_base_kernel(matrix_getter<base_field, ld_modifier::cs> batch_ys,
-                                                              vector_getter<base_field, ld_modifier::ca> lagrange_coeffs,
-                                                              matrix_setter<base_field, st_modifier::cs> partial_sums, const unsigned log_count,
-                                                              const unsigned num_polys) {
-  batch_barycentric_partial_reduce_impl<base_field>(batch_ys, lagrange_coeffs, partial_sums, log_count, num_polys);
+    void batch_barycentric_partial_reduce_bf_bf_kernel(matrix_getter<bf, ld_modifier::cs> batch_ys, vector_getter<bf, ld_modifier::ca> lagrange_coeffs,
+                                                       matrix_setter<bf, st_modifier::cs> partial_sums, const unsigned log_count, const unsigned num_polys) {
+  batch_barycentric_partial_reduce_impl<bf>(batch_ys, lagrange_coeffs, partial_sums, log_count, num_polys);
 }
 
-EXTERN __launch_bounds__(1024, 1) __global__ void batch_barycentric_partial_reduce_base_at_ext_kernel(matrix_getter<base_field, ld_modifier::cs> batch_ys,
-                                                                                                      ef_double_vector_getter<ld_modifier::ca> lagrange_coeffs,
-                                                                                                      ef_double_matrix_setter<st_modifier::cs> partial_sums,
-                                                                                                      const unsigned log_count, const unsigned num_polys) {
-  batch_barycentric_partial_reduce_impl<extension_field>(batch_ys, lagrange_coeffs, partial_sums, log_count, num_polys);
+EXTERN __launch_bounds__(1024, 1) __global__
+    void batch_barycentric_partial_reduce_bf_ef_kernel(matrix_getter<bf, ld_modifier::cs> batch_ys, ef_double_vector_getter<ld_modifier::ca> lagrange_coeffs,
+                                                       ef_double_matrix_setter<st_modifier::cs> partial_sums, const unsigned log_count,
+                                                       const unsigned num_polys) {
+  batch_barycentric_partial_reduce_impl<ef>(batch_ys, lagrange_coeffs, partial_sums, log_count, num_polys);
 }
 
-EXTERN __launch_bounds__(1024, 1) __global__ void batch_barycentric_partial_reduce_ext_at_ext_kernel(ef_double_matrix_getter<ld_modifier::cs> batch_ys,
-                                                                                                     ef_double_vector_getter<ld_modifier::ca> lagrange_coeffs,
-                                                                                                     ef_double_matrix_setter<st_modifier::cs> partial_sums,
-                                                                                                     const unsigned log_count, const unsigned num_polys) {
-  batch_barycentric_partial_reduce_impl<extension_field>(batch_ys, lagrange_coeffs, partial_sums, log_count, num_polys);
+EXTERN __launch_bounds__(1024, 1) __global__ void batch_barycentric_partial_reduce_ef_ef_kernel(ef_double_matrix_getter<ld_modifier::cs> batch_ys,
+                                                                                                ef_double_vector_getter<ld_modifier::ca> lagrange_coeffs,
+                                                                                                ef_double_matrix_setter<st_modifier::cs> partial_sums,
+                                                                                                const unsigned log_count, const unsigned num_polys) {
+  batch_barycentric_partial_reduce_impl<ef>(batch_ys, lagrange_coeffs, partial_sums, log_count, num_polys);
 }
 
 } // namespace goldilocks
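Written out, the two comments in this file describe standard barycentric evaluation over the coset $\{c \cdot w^i\}$ of the size-$N$ domain ($w$ an $N$-th root of unity, $c$ the coset shift): the Lagrange coefficient at index $i$ factors into the one-off `common_factor` computed by the single-thread kernels and the per-element term whose denominators the grid batch-inverts:

```math
L_i(X) \;=\; \underbrace{\frac{c\,(X^N - c^N)}{N\,c^N}}_{\texttt{common\_factor}} \cdot \underbrace{\frac{w^i}{X - c\,w^i}}_{\texttt{per\_elem\_factor}}, \qquad p(X) \;=\; \sum_{i=0}^{N-1} y_i\,L_i(X)
```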
8 changes: 4 additions & 4 deletions native/gates_poseidon.cuh
@@ -138,15 +138,15 @@ DEVICE_FORCEINLINE void poseidon2_internal_matrix(const base_field *variables, c
 #define GATE_POSEIDON(variables_offset, witnesses_offset) \
   { \
     poseidon_repetition<witnesses_offset>(variables, witnesses, challenge_bases, challenge_powers, quotient_sums, challenges_count, inputs_stride); \
-    variables += (variables_offset)*inputs_stride; \
-    witnesses += (witnesses_offset)*inputs_stride; \
+    variables += (variables_offset) * inputs_stride; \
+    witnesses += (witnesses_offset) * inputs_stride; \
   }
 
 #define GATE_POSEIDON2(variables_offset, witnesses_offset) \
   { \
     poseidon2_repetition<witnesses_offset>(variables, witnesses, challenge_bases, challenge_powers, quotient_sums, challenges_count, inputs_stride); \
-    variables += (variables_offset)*inputs_stride; \
-    witnesses += (witnesses_offset)*inputs_stride; \
+    variables += (variables_offset) * inputs_stride; \
+    witnesses += (witnesses_offset) * inputs_stride; \
   }
 
 #define GATE_POSEIDON2_EXTERNAL_MATRIX \
12 changes: 12 additions & 0 deletions native/goldilocks_extension.cuh
@@ -91,6 +91,18 @@ struct __align__(16) extension_field {
     }
     return result;
   }
+
+  static DEVICE_FORCEINLINE extension_field shr(const extension_field &x, const unsigned &shift) {
+    auto a = base_field::shr(x[0], shift);
+    auto b = base_field::shr(x[1], shift);
+    return {a, b};
+  }
+
+  static DEVICE_FORCEINLINE extension_field shl(const extension_field &x, const unsigned &shift) {
+    auto a = base_field::shl(x[0], shift);
+    auto b = base_field::shl(x[1], shift);
+    return {a, b};
+  }
 };
 
 template <memory::ld_modifier LD_MODIFIER = memory::ld_modifier::none>
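The new `shr`/`shl` simply lift the base-field shift to the extension componentwise. That is valid because the extension is a two-dimensional vector space over the base field, so multiplying by a base-field scalar acts on each coordinate independently (this assumes, as the delegation to `base_field::shl`/`shr` suggests, that those implement multiplication and division by a power of two in the field):

```math
2^{\pm k} \cdot (x_0 + x_1\,u) \;=\; (2^{\pm k} x_0) + (2^{\pm k} x_1)\,u
```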
6 changes: 2 additions & 4 deletions native/ops_complex.cu
@@ -173,8 +173,7 @@ template <> struct InvBatch<extension_field> {
   static constexpr unsigned INV_BATCH = 6;
 };
 
-template<typename T, typename GETTER, typename SETTER>
-DEVICE_FORCEINLINE void batch_inv_impl(GETTER src, SETTER dst, const unsigned count) {
+template <typename T, typename GETTER, typename SETTER> DEVICE_FORCEINLINE void batch_inv_impl(GETTER src, SETTER dst, const unsigned count) {
   constexpr unsigned INV_BATCH = InvBatch<T>::INV_BATCH;
 
   // ints for indexing because some bounds checks count down and check if an index drops below 0
@@ -216,8 +215,7 @@ EXTERN __global__ void batch_inv_bf_kernel(vector_getter<base_field, ld_modifier
   batch_inv_impl<base_field>(src, dst, count);
 }
 
-EXTERN __global__ void batch_inv_ef_kernel(ef_double_vector_getter<ld_modifier::cs> src, ef_double_vector_setter<st_modifier::cs> dst,
-                                           const unsigned count) {
+EXTERN __global__ void batch_inv_ef_kernel(ef_double_vector_getter<ld_modifier::cs> src, ef_double_vector_setter<st_modifier::cs> dst, const unsigned count) {
   batch_inv_impl<extension_field>(src, dst, count);
 }
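The `InvBatch` specializations here (and in `barycentric.cu`) size per-thread batches for Montgomery's batch-inversion trick, which replaces $n$ field inversions with one inversion plus roughly $3n$ multiplications. Below is a minimal host-side sketch of the idea, not of the kernel's register-array implementation; `mul` and `inv` stand in for the field operations:

```rust
// Montgomery batch inversion: invert every element of `xs` using a single
// true inversion. Assumes all elements are nonzero.
fn batch_inv<F: Copy>(xs: &mut [F], mul: impl Fn(F, F) -> F, inv: impl Fn(F) -> F) {
    if xs.is_empty() {
        return;
    }
    // Forward pass: prefix[i] = xs[0] * xs[1] * ... * xs[i].
    let mut prefix = Vec::with_capacity(xs.len());
    let mut acc = xs[0];
    prefix.push(acc);
    for &x in xs[1..].iter() {
        acc = mul(acc, x);
        prefix.push(acc);
    }
    // The only true inversion: invert the product of all elements.
    let mut inv_acc = inv(acc);
    // Backward pass: peel one factor off at a time, recovering each inverse.
    for i in (1..xs.len()).rev() {
        let x_inv = mul(inv_acc, prefix[i - 1]); // = xs[i]^-1
        inv_acc = mul(inv_acc, xs[i]);           // = (xs[0] * ... * xs[i-1])^-1
        xs[i] = x_inv;
    }
    xs[0] = inv_acc;
}
```

With `f64` standing in for the field, `batch_inv(&mut xs, |a, b| a * b, |a| 1.0 / a)` inverts every element with a single division.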
