Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Scalar conversions explicit #335

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/base/bit/abs_bit_mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::base::scalar::Scalar;

pub fn make_abs_bit_mask<S: Scalar>(x: S) -> [u64; 4] {
let (sign, x) = if S::MAX_SIGNED < x { (1, -x) } else { (0, x) };
let mut res: [u64; 4] = x.into();
let mut res: [u64; 4] = x.to_limbs();
res[3] |= sign << 63;
res
}
11 changes: 5 additions & 6 deletions crates/proof-of-sql/src/base/commitment/committable_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ impl<'a, S: Scalar> From<&Column<'a, S>> for CommittableColumn<'a> {
Column::BigInt(ints) => CommittableColumn::BigInt(ints),
Column::Int128(ints) => CommittableColumn::Int128(ints),
Column::Decimal75(precision, scale, decimals) => {
let as_limbs: Vec<_> = decimals.iter().map(RefInto::<[u64; 4]>::ref_into).collect();
let as_limbs: Vec<_> = decimals.iter().map(S::to_limbs).collect();
CommittableColumn::Decimal75(*precision, *scale, as_limbs)
}
Column::Scalar(scalars) => (scalars as &[_]).into(),
Column::VarChar((_, scalars)) => {
let as_limbs: Vec<_> = scalars.iter().map(RefInto::<[u64; 4]>::ref_into).collect();
let as_limbs: Vec<_> = scalars.iter().map(S::to_limbs).collect();
CommittableColumn::VarChar(as_limbs)
}
Column::TimestampTZ(tu, tz, times) => CommittableColumn::TimestampTZ(*tu, *tz, times),
Expand Down Expand Up @@ -144,16 +144,15 @@ impl<'a, S: Scalar> From<&'a OwnedColumn<S>> for CommittableColumn<'a> {
*scale,
decimals
.iter()
.map(Into::<S>::into)
.map(Into::<[u64; 4]>::into)
.map(S::to_limbs)
.collect(),
),
OwnedColumn::Scalar(scalars) => (scalars as &[_]).into(),
OwnedColumn::VarChar(strings) => CommittableColumn::VarChar(
strings
.iter()
.map(Into::<S>::into)
.map(Into::<[u64; 4]>::into)
.map(|val| { val.to_limbs() })
.collect(),
),
OwnedColumn::TimestampTZ(tu, tz, times) => {
Expand Down Expand Up @@ -197,7 +196,7 @@ impl<'a> From<&'a [i128]> for CommittableColumn<'a> {
}
impl<'a, S: Scalar> From<&'a [S]> for CommittableColumn<'a> {
fn from(value: &'a [S]) -> Self {
CommittableColumn::Scalar(value.iter().map(RefInto::<[u64; 4]>::ref_into).collect())
CommittableColumn::Scalar(value.iter().map(S::to_limbs).collect())
}
}
impl<'a> From<&'a [bool]> for CommittableColumn<'a> {
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/src/base/proof/transcript_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ impl<T: TranscriptCore> Transcript for T {
&mut self,
messages: impl IntoIterator<Item = &'a S>,
) {
self.extend_as_be::<[u64; 4]>(messages.into_iter().map(RefInto::ref_into));
self.extend_as_be::<[u64; 4]>(messages.into_iter().map(S::to_limbs));
}
fn scalar_challenge_as_be<S: Scalar>(&mut self) -> S {
receive_challenge_as_be::<[u64; 4]>(self).into()
let x = receive_challenge_as_be::<[u64; 4]>(self);
Scalar::from_limbs(x)
}
fn challenge_as_le(&mut self) -> [u8; 32] {
self.raw_challenge()
Expand Down
34 changes: 21 additions & 13 deletions crates/proof-of-sql/src/base/scalar/mont_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,14 @@ impl super::Scalar for Curve25519Scalar {
const ONE: Self = Self(ark_ff::MontFp!("1"));
const TWO: Self = Self(ark_ff::MontFp!("2"));
const TEN: Self = Self(ark_ff::MontFp!("10"));

fn from_limbs(val: [u64; 4]) -> Self {
Self::from(val)
}

fn to_limbs(&self) -> [u64; 4] {
self.into()
}
}

impl<T> TryFrom<MontScalar<T>> for bool
Expand All @@ -441,9 +449,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, (value).to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -469,9 +477,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -493,9 +501,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -517,9 +525,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -541,9 +549,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -567,9 +575,9 @@ where
#[allow(clippy::cast_possible_wrap)]
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand Down Expand Up @@ -601,7 +609,7 @@ where
} else {
num_bigint::Sign::Plus
};
let value_abs: [u64; 4] = (if is_negative { -value } else { value }).into();
let value_abs: [u64; 4] = (if is_negative { -value } else { value }).to_limbs();
let bits: &[u8] = bytemuck::cast_slice(&value_abs);
BigInt::from_bytes_le(sign, bits)
}
Expand Down
9 changes: 6 additions & 3 deletions crates/proof-of-sql/src/base/scalar/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ pub trait Scalar:
+ core::convert::TryInto <i32>
+ core::convert::TryInto <i64>
+ core::convert::TryInto <i128>
+ core::convert::Into<[u64; 4]>
+ core::convert::From<[u64; 4]>
// + core::convert::Into<[u64; 4]>
// + core::convert::From<[u64; 4]>
+ core::cmp::Ord
+ core::ops::Neg<Output = Self>
+ num_traits::Zero
Expand All @@ -47,7 +47,7 @@ pub trait Scalar:
+ ark_std::UniformRand //This enables us to get `Scalar`s as challenges from the transcript
+ num_traits::Inv<Output = Option<Self>> // Note: `inv` should return `None` exactly when the element is zero.
+ core::ops::SubAssign
+ RefInto<[u64; 4]>
// + RefInto<[u64; 4]>
+ for<'a> core::convert::From<&'a String>
+ VarInt
+ core::convert::From<String>
Expand All @@ -71,4 +71,7 @@ pub trait Scalar:
const TWO: Self;
/// 2 + 2 + 2 + 2 + 2
const TEN: Self;

fn from_limbs(val: [u64; 4]) -> Self;
fn to_limbs(&self) -> [u64; 4];
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ impl Scalar for DoryScalar {
const ONE: Self = Self(ark_ff::MontFp!("1"));
const TWO: Self = Self(ark_ff::MontFp!("2"));
const TEN: Self = Self(ark_ff::MontFp!("10"));

fn from_limbs(val: [u64; 4]) -> Self {
Self::from(val)
}

fn to_limbs(&self) -> [u64; 4] {
self.into()
}
}

#[derive(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ pub fn verify_constant_sign_decomposition<S: Scalar>(
&& !dist.has_varying_sign_bit()
);
let lhs = if dist.sign_bit() { -eval } else { eval };
let mut rhs = S::from(dist.constant_part()) * one_eval;
let mut rhs = S::from_limbs(dist.constant_part()) * one_eval;
let mut vary_index = 0;
dist.for_each_abs_varying_bit(|int_index: usize, bit_index: usize| {
let mut mult = [0u64; 4];
mult[int_index] = 1u64 << bit_index;
rhs += S::from(mult) * bit_evals[vary_index];
rhs += S::from_limbs(mult) * bit_evals[vary_index];
vary_index += 1;
});
if lhs == rhs {
Expand Down Expand Up @@ -72,7 +72,7 @@ pub fn verify_constant_abs_decomposition<S: Scalar>(
&& dist.has_varying_sign_bit()
);
let t = one_eval - S::TWO * sign_eval;
if S::from(dist.constant_part()) * t == eval {
if S::from_limbs(dist.constant_part()) * t == eval {
Ok(())
} else {
Err(ProofError::VerificationError {
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof_exprs/range_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn decompose_scalar_to_words<'a, S: Scalar + 'a>(
byte_counts: &mut [u64],
) {
for (i, scalar) in scalars.iter().enumerate() {
let scalar_array: [u64; 4] = (*scalar).into(); // Convert scalar to u64 array
let scalar_array: [u64; 4] = scalar.to_limbs(); // Convert scalar to u64 array
let scalar_bytes_full = cast_slice::<u64, u8>(&scalar_array); // Cast u64 array to u8 slice
let scalar_bytes = &scalar_bytes_full[..31];

Expand Down
8 changes: 4 additions & 4 deletions crates/proof-of-sql/src/sql/proof_exprs/sign_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ fn prove_bit_decomposition<'a, S: Scalar>(
terms.push((S::one(), vec![Box::new(expr)]));

// expr bit decomposition
let const_part = S::from(dist.constant_part());
let const_part = S::from_limbs(dist.constant_part());
if !const_part.is_zero() {
terms.push((-const_part, vec![Box::new(sign_mle)]));
}
Expand All @@ -230,7 +230,7 @@ fn prove_bit_decomposition<'a, S: Scalar>(
let mut mult = [0u64; 4];
mult[int_index] = 1u64 << bit_index;
terms.push((
-S::from(mult),
-S::from_limbs(mult),
vec![Box::new(sign_mle), Box::new(bits[vary_index])],
));
vary_index += 1;
Expand All @@ -252,12 +252,12 @@ fn verify_bit_decomposition<C: Commitment>(
let sign_eval = bit_evals.last().unwrap();
let sign_eval = builder.mle_evaluations.input_one_evaluation - C::Scalar::TWO * *sign_eval;
let mut vary_index = 0;
eval -= sign_eval * C::Scalar::from(dist.constant_part());
eval -= sign_eval * C::Scalar::from_limbs(dist.constant_part());
dist.for_each_abs_varying_bit(|int_index: usize, bit_index: usize| {
let mut mult = [0u64; 4];
mult[int_index] = 1u64 << bit_index;
let bit_eval = bit_evals[vary_index];
eval -= C::Scalar::from(mult) * sign_eval * bit_eval;
eval -= C::Scalar::from_limbs(mult) * sign_eval * bit_eval;
vary_index += 1;
});
builder.produce_sumcheck_subpolynomial_evaluation(&SumcheckSubpolynomialType::Identity, eval);
Expand Down
Loading