diff --git a/src/primitives/arith.rs b/src/primitives/arith.rs index 0b8a44e..fd01808 100644 --- a/src/primitives/arith.rs +++ b/src/primitives/arith.rs @@ -56,8 +56,7 @@ pub fn mul_assign_complex( assert_eq!(c0_this_ptr.add(domain_size), c1_this.as_ptr()); } let this_ptr = c0_this_ptr as *mut VEF; - let mut this_slice: &mut [VEF] = - unsafe { slice::from_raw_parts_mut(this_ptr, domain_size) }; + let mut this_slice: &mut [VEF] = unsafe { slice::from_raw_parts_mut(this_ptr, domain_size) }; let this_vector = unsafe { DeviceSlice::from_mut_slice(&mut this_slice) }; let c0_other_ptr = c0_other.as_ptr(); @@ -65,8 +64,7 @@ pub fn mul_assign_complex( assert_eq!(c0_other_ptr.add(domain_size), c1_other.as_ptr()); } let other_ptr = c0_other_ptr as *const VEF; - let other_slice: &[VEF] = - unsafe { slice::from_raw_parts(other_ptr, domain_size) }; + let other_slice: &[VEF] = unsafe { slice::from_raw_parts(other_ptr, domain_size) }; let other_vector = unsafe { DeviceSlice::from_slice(&other_slice) }; mul_into_x(this_vector, other_vector, get_stream()) @@ -435,24 +433,34 @@ pub fn precompute_barycentric_bases( // (X^N - 1)/ N // evaluations are elems of first coset of the lde // shift is k*w^0=k where k is multiplicative generator - let mut d_point = svec!(1); - d_point.copy_from_slice(&[point])?; + + // The following does NOT work ("the trait `SetByVal` is not implemented" for EF) + // let mut d_point: DVec = svec!(1); + // helpers::set_by_value(d_point.as_mut(), point, get_stream()); + + // The following sequence creates a &DeviceVariable for point. + // It's a little ugly but avoids potentially synchronizing copies. + let mut coeffs: DVec = svec!(2); + let [c0, c1] = point.into_coeffs_in_base(); + helpers::set_by_value(coeffs[..1].as_mut(), c0, get_stream()); + helpers::set_by_value(coeffs[1..].as_mut(), c1, get_stream()); + let d_point_ef = unsafe { std::slice::from_raw_parts(coeffs.as_ptr() as *const EF, 1) }; + let d_point_ef_var = unsafe { DeviceVariable::from_ref(&d_point_ef[0]) }; let mut d_tmp: SVec = svec!(1); - let (bases, point, common_factor_storage) = unsafe { - let point = &d_point[0]; - let tmp_point = &mut d_tmp[0]; + let (bases, common_factor_storage) = unsafe { let v_bases = std::slice::from_raw_parts_mut(bases.as_ptr() as *mut _, domain_size); + let tmp_point = &mut d_tmp[0]; ( DeviceSlice::from_mut_slice(v_bases), - DeviceVariable::from_ref(point), DeviceVariable::from_mut(tmp_point), ) }; + use boojum_cuda::barycentric::PrecomputeAtExt; boojum_cuda::barycentric::precompute_lagrange_coeffs::( - point, + d_point_ef_var, common_factor_storage, coset, bases, diff --git a/src/prover.rs b/src/prover.rs index d6bfe6f..1aded3f 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -32,7 +32,7 @@ use crate::{ use super::*; -use nvtx::{range_push, range_pop}; +use nvtx::{range_pop, range_push}; pub fn gpu_prove_from_external_witness_data< P: boojum::field::traits::field_like::PrimeFieldLikeVectorized,