Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
fix h2d copy in precompute_barycentric_bases
Browse files Browse the repository at this point in the history
  • Loading branch information
mcarilli committed Dec 22, 2023
1 parent 5b021fb commit 729a832
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
30 changes: 19 additions & 11 deletions src/primitives/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,15 @@ pub fn mul_assign_complex(
assert_eq!(c0_this_ptr.add(domain_size), c1_this.as_ptr());
}
let this_ptr = c0_this_ptr as *mut VEF;
let mut this_slice: &mut [VEF] =
unsafe { slice::from_raw_parts_mut(this_ptr, domain_size) };
let mut this_slice: &mut [VEF] = unsafe { slice::from_raw_parts_mut(this_ptr, domain_size) };
let this_vector = unsafe { DeviceSlice::from_mut_slice(&mut this_slice) };

let c0_other_ptr = c0_other.as_ptr();
unsafe {
assert_eq!(c0_other_ptr.add(domain_size), c1_other.as_ptr());
}
let other_ptr = c0_other_ptr as *const VEF;
let other_slice: &[VEF] =
unsafe { slice::from_raw_parts(other_ptr, domain_size) };
let other_slice: &[VEF] = unsafe { slice::from_raw_parts(other_ptr, domain_size) };
let other_vector = unsafe { DeviceSlice::from_slice(&other_slice) };

mul_into_x(this_vector, other_vector, get_stream())
Expand Down Expand Up @@ -435,24 +433,34 @@ pub fn precompute_barycentric_bases(
// (X^N - 1)/ N
// evaluations are elems of first coset of the lde
// shift is k*w^0=k where k is multiplicative generator
let mut d_point = svec!(1);
d_point.copy_from_slice(&[point])?;

// The following does NOT work ("the trait `SetByVal` is not implemented" for EF)
// let mut d_point: DVec<EF, _> = svec!(1);
// helpers::set_by_value(d_point.as_mut(), point, get_stream());

// The following sequence creates a &DeviceVariable<EF> for point.
// It's a little ugly but avoids potentially synchronizing copies.
let mut coeffs: DVec<F, _> = svec!(2);
let [c0, c1] = point.into_coeffs_in_base();
helpers::set_by_value(coeffs[..1].as_mut(), c0, get_stream());
helpers::set_by_value(coeffs[1..].as_mut(), c1, get_stream());
let d_point_ef = unsafe { std::slice::from_raw_parts(coeffs.as_ptr() as *const EF, 1) };
let d_point_ef_var = unsafe { DeviceVariable::from_ref(&d_point_ef[0]) };

let mut d_tmp: SVec<EF> = svec!(1);

let (bases, point, common_factor_storage) = unsafe {
let point = &d_point[0];
let tmp_point = &mut d_tmp[0];
let (bases, common_factor_storage) = unsafe {
let v_bases = std::slice::from_raw_parts_mut(bases.as_ptr() as *mut _, domain_size);
let tmp_point = &mut d_tmp[0];
(
DeviceSlice::from_mut_slice(v_bases),
DeviceVariable::from_ref(point),
DeviceVariable::from_mut(tmp_point),
)
};

use boojum_cuda::barycentric::PrecomputeAtExt;
boojum_cuda::barycentric::precompute_lagrange_coeffs::<PrecomputeAtExt>(
point,
d_point_ef_var,
common_factor_storage,
coset,
bases,
Expand Down
2 changes: 1 addition & 1 deletion src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::{

use super::*;

use nvtx::{range_push, range_pop};
use nvtx::{range_pop, range_push};

pub fn gpu_prove_from_external_witness_data<
P: boojum::field::traits::field_like::PrimeFieldLikeVectorized<Base = F>,
Expand Down

0 comments on commit 729a832

Please sign in to comment.