use merlin::Transcript;
+use rand_core::{CryptoRng, RngCore};
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Serialize};
+
+use core::{fmt, mem};
+
+#[cfg(feature = "serde")]
+use crate::serde::{ScalarHelper, VecHelper};
+use crate::{
+ alloc::{vec, Vec},
+ encryption::ExtendedCiphertext,
+ group::Group,
+ proofs::{TranscriptForGroup, VerificationError},
+ Ciphertext, PublicKey, SecretKey,
+};
+
+struct Ring<'a, G: Group> {
+ index: usize,
+ admissible_values: &'a [G::Element],
+ ciphertext: Ciphertext<G>,
+
+ transcript: Transcript,
+ responses: &'a mut [G::Scalar],
+ terminal_commitments: (G::Element, G::Element),
+
+ value_index: usize,
+ discrete_log: SecretKey<G>,
+ random_scalar: SecretKey<G>,
+}
+
+impl<G: Group> fmt::Debug for Ring<'_, G> {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ formatter
+ .debug_struct("Ring")
+ .field("index", &self.index)
+ .field("admissible_values", &self.admissible_values)
+ .field("ciphertext", &self.ciphertext)
+ .field("responses", &self.responses)
+ .field("terminal_commitments", &self.terminal_commitments)
+ .finish()
+ }
+}
+
+impl<'a, G: Group> Ring<'a, G> {
+ #[allow(clippy::too_many_arguments)] fn new<R: CryptoRng + RngCore>(
+ index: usize,
+ log_base: G::Element,
+ ciphertext: ExtendedCiphertext<G>,
+ admissible_values: &'a [G::Element],
+ value_index: usize,
+ transcript: &Transcript,
+ responses: &'a mut [G::Scalar],
+ rng: &mut R,
+ ) -> Self {
+ assert!(
+ !admissible_values.is_empty(),
+ "No admissible values supplied"
+ );
+ assert!(
+ value_index < admissible_values.len(),
+ "Specified value index is out of bounds"
+ );
+ debug_assert_eq!(
+ responses.len(),
+ admissible_values.len(),
+ "Number of responses doesn't match number of admissible values"
+ );
+
+ let random_element = ciphertext.inner.random_element;
+ let blinded_value = ciphertext.inner.blinded_element;
+ debug_assert!(
+ {
+ let expected_blinded_value = log_base * ciphertext.random_scalar.expose_scalar()
+ + admissible_values[value_index];
+ expected_blinded_value == blinded_value
+ },
+ "Specified ciphertext does not match the specified `value_index`"
+ );
+
+ let mut transcript = transcript.clone();
+ transcript.start_proof(b"ring_enc");
+ transcript.append_message(b"enc", &ciphertext.inner.to_bytes());
+ transcript.append_u64(b"i", index as u64);
+
+ let random_scalar = SecretKey::<G>::generate(rng);
+ let mut commitments = (
+ G::mul_generator(random_scalar.expose_scalar()),
+ log_base * random_scalar.expose_scalar(),
+ );
+
+ let it = admissible_values.iter().enumerate().skip(value_index + 1);
+ for (eq_index, &admissible_value) in it {
+ let mut eq_transcript = transcript.clone();
+ eq_transcript.append_u64(b"j", eq_index as u64 - 1);
+ eq_transcript.append_element::<G>(b"R_G", &commitments.0);
+ eq_transcript.append_element::<G>(b"R_K", &commitments.1);
+ let challenge = eq_transcript.challenge_scalar::<G>(b"c");
+
+ let response = G::generate_scalar(rng);
+ responses[eq_index] = response;
+ let dh_element = blinded_value - admissible_value;
+ commitments = (
+ G::mul_generator(&response) - random_element * &challenge,
+ G::multi_mul([&response, &-challenge], [log_base, dh_element]),
+ );
+ }
+
+ Self {
+ index,
+ value_index,
+ admissible_values,
+ ciphertext: ciphertext.inner,
+ transcript,
+ responses,
+ terminal_commitments: commitments,
+ discrete_log: ciphertext.random_scalar,
+ random_scalar,
+ }
+ }
+
+ fn aggregate<R: CryptoRng + RngCore>(
+ rings: Vec<Self>,
+ log_base: G::Element,
+ transcript: &mut Transcript,
+ rng: &mut R,
+ ) -> G::Scalar {
+ debug_assert!(
+ rings.iter().enumerate().all(|(i, ring)| i == ring.index),
+ "Rings have bogus indexes"
+ );
+
+ for ring in &rings {
+ let commitments = &ring.terminal_commitments;
+ transcript.append_element::<G>(b"R_G", &commitments.0);
+ transcript.append_element::<G>(b"R_K", &commitments.1);
+ }
+
+ let common_challenge = transcript.challenge_scalar::<G>(b"c");
+ for ring in rings {
+ ring.finalize(log_base, common_challenge, rng);
+ }
+ common_challenge
+ }
+
+ fn finalize<R: CryptoRng + RngCore>(
+ self,
+ log_base: G::Element,
+ common_challenge: G::Scalar,
+ rng: &mut R,
+ ) {
+ let mut challenge = common_challenge;
+ let it = self.admissible_values[..self.value_index]
+ .iter()
+ .enumerate();
+ for (eq_index, &admissible_value) in it {
+ let response = G::generate_scalar(rng);
+ self.responses[eq_index] = response;
+ let dh_element = self.ciphertext.blinded_element - admissible_value;
+ let commitments = (
+ G::mul_generator(&response) - self.ciphertext.random_element * &challenge,
+ G::multi_mul([&response, &-challenge], [log_base, dh_element]),
+ );
+
+ let mut eq_transcript = self.transcript.clone();
+ eq_transcript.append_u64(b"j", eq_index as u64);
+ eq_transcript.append_element::<G>(b"R_G", &commitments.0);
+ eq_transcript.append_element::<G>(b"R_K", &commitments.1);
+ challenge = eq_transcript.challenge_scalar::<G>(b"c");
+ }
+
+ debug_assert_eq!(self.responses[self.value_index], G::Scalar::from(0_u64));
+ self.responses[self.value_index] =
+ challenge * self.discrete_log.expose_scalar() + self.random_scalar.expose_scalar();
+ }
+}
+
+#[derive(Debug, Clone)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "serde", serde(bound = ""))]
+pub struct RingProof<G: Group> {
+ #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
+ common_challenge: G::Scalar,
+ #[cfg_attr(feature = "serde", serde(with = "VecHelper::<ScalarHelper<G>, 2>"))]
+ ring_responses: Vec<G::Scalar>,
+}
+
+impl<G: Group> RingProof<G> {
+ fn initialize_transcript(transcript: &mut Transcript, receiver: &PublicKey<G>) {
+ transcript.start_proof(b"multi_ring_enc");
+ transcript.append_element_bytes(b"K", receiver.as_bytes());
+ }
+
+ pub(crate) fn new(common_challenge: G::Scalar, ring_responses: Vec<G::Scalar>) -> Self {
+ Self {
+ common_challenge,
+ ring_responses,
+ }
+ }
+
+ pub(crate) fn verify<'a>(
+ &self,
+ receiver: &PublicKey<G>,
+ admissible_values: impl Iterator<Item = &'a [G::Element]> + Clone,
+ ciphertexts: impl Iterator<Item = Ciphertext<G>>,
+ transcript: &mut Transcript,
+ ) -> Result<(), VerificationError> {
+ let total_rings_size: usize = admissible_values.clone().map(<[_]>::len).sum();
+ VerificationError::check_lengths(
+ "items in all rings",
+ self.total_rings_size(),
+ total_rings_size,
+ )?;
+
+ Self::initialize_transcript(transcript, receiver);
+ let initial_ring_transcript = transcript.clone();
+
+ let it = admissible_values.zip(ciphertexts).enumerate();
+ let mut starting_response = 0;
+ for (ring_index, (values, ciphertext)) in it {
+ let mut challenge = self.common_challenge;
+ let mut commitments = (G::generator(), G::generator());
+
+ let mut ring_transcript = initial_ring_transcript.clone();
+ ring_transcript.start_proof(b"ring_enc");
+ ring_transcript.append_message(b"enc", &ciphertext.to_bytes());
+ ring_transcript.append_u64(b"i", ring_index as u64);
+
+ for (eq_index, (&admissible_value, response)) in values
+ .iter()
+ .zip(&self.ring_responses[starting_response..])
+ .enumerate()
+ {
+ let dh_element = ciphertext.blinded_element - admissible_value;
+ let neg_challenge = -challenge;
+
+ commitments = (
+ G::vartime_double_mul_generator(
+ &neg_challenge,
+ ciphertext.random_element,
+ response,
+ ),
+ G::vartime_multi_mul(
+ [response, &neg_challenge],
+ [receiver.as_element(), dh_element],
+ ),
+ );
+
+ if eq_index + 1 < values.len() {
+ let mut eq_transcript = ring_transcript.clone();
+ eq_transcript.append_u64(b"j", eq_index as u64);
+ eq_transcript.append_element::<G>(b"R_G", &commitments.0);
+ eq_transcript.append_element::<G>(b"R_K", &commitments.1);
+ challenge = eq_transcript.challenge_scalar::<G>(b"c");
+ }
+ }
+
+ starting_response += values.len();
+ transcript.append_element::<G>(b"R_G", &commitments.0);
+ transcript.append_element::<G>(b"R_K", &commitments.1);
+ }
+
+ let expected_challenge = transcript.challenge_scalar::<G>(b"c");
+ if expected_challenge == self.common_challenge {
+ Ok(())
+ } else {
+ Err(VerificationError::ChallengeMismatch)
+ }
+ }
+
+ pub(crate) fn total_rings_size(&self) -> usize {
+ self.ring_responses.len()
+ }
+
+ pub fn to_bytes(&self) -> Vec<u8> {
+ let mut bytes = vec![0_u8; G::SCALAR_SIZE * (1 + self.total_rings_size())];
+ G::serialize_scalar(&self.common_challenge, &mut bytes[..G::SCALAR_SIZE]);
+
+ let chunks = bytes[G::SCALAR_SIZE..].chunks_mut(G::SCALAR_SIZE);
+ for (response, buffer) in self.ring_responses.iter().zip(chunks) {
+ G::serialize_scalar(response, buffer);
+ }
+ bytes
+ }
+
+ #[allow(clippy::missing_panics_doc)] pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
+ if bytes.len() % G::SCALAR_SIZE != 0 || bytes.len() < 3 * G::SCALAR_SIZE {
+ return None;
+ }
+ let common_challenge = G::deserialize_scalar(&bytes[..G::SCALAR_SIZE])?;
+
+ let ring_responses: Option<Vec<_>> = bytes[G::SCALAR_SIZE..]
+ .chunks(G::SCALAR_SIZE)
+ .map(G::deserialize_scalar)
+ .collect();
+ let ring_responses = ring_responses?;
+ debug_assert!(ring_responses.len() >= 2);
+
+ Some(Self {
+ common_challenge,
+ ring_responses,
+ })
+ }
+}
+
+#[doc(hidden)] pub struct RingProofBuilder<'a, G: Group, R> {
+ receiver: &'a PublicKey<G>,
+ transcript: &'a mut Transcript,
+ rings: Vec<Ring<'a, G>>,
+ ring_responses: &'a mut [G::Scalar],
+ rng: &'a mut R,
+}
+
+impl<G: Group, R: fmt::Debug> fmt::Debug for RingProofBuilder<'_, G, R> {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ formatter
+ .debug_struct("RingProofBuilder")
+ .field("receiver", self.receiver)
+ .field("rings", &self.rings)
+ .field("rng", self.rng)
+ .finish()
+ }
+}
+
+impl<'a, G: Group, R: RngCore + CryptoRng> RingProofBuilder<'a, G, R> {
+ pub fn new(
+ receiver: &'a PublicKey<G>,
+ ring_count: usize,
+ ring_responses: &'a mut [G::Scalar],
+ transcript: &'a mut Transcript,
+ rng: &'a mut R,
+ ) -> Self {
+ RingProof::<G>::initialize_transcript(transcript, receiver);
+ Self {
+ receiver,
+ transcript,
+ rings: Vec::with_capacity(ring_count),
+ ring_responses,
+ rng,
+ }
+ }
+
+ pub fn add_value(
+ &mut self,
+ admissible_values: &'a [G::Element],
+ value_index: usize,
+ ) -> ExtendedCiphertext<G> {
+ let ext_ciphertext =
+ ExtendedCiphertext::new(admissible_values[value_index], self.receiver, self.rng);
+ self.add_precomputed_value(ext_ciphertext.clone(), admissible_values, value_index);
+ ext_ciphertext
+ }
+
+ pub(crate) fn add_precomputed_value(
+ &mut self,
+ ciphertext: ExtendedCiphertext<G>,
+ admissible_values: &'a [G::Element],
+ value_index: usize,
+ ) {
+ let ring_responses = mem::take(&mut self.ring_responses);
+ let (responses_for_ring, rest) = ring_responses.split_at_mut(admissible_values.len());
+ self.ring_responses = rest;
+
+ let ring = Ring::new(
+ self.rings.len(),
+ self.receiver.as_element(),
+ ciphertext,
+ admissible_values,
+ value_index,
+ &*self.transcript,
+ responses_for_ring,
+ self.rng,
+ );
+ self.rings.push(ring);
+ }
+
+ pub fn build(self) -> G::Scalar {
+ debug_assert!(
+ self.ring_responses.is_empty(),
+ "Not all ring_responses were used"
+ );
+ Ring::aggregate(
+ self.rings,
+ self.receiver.as_element(),
+ self.transcript,
+ self.rng,
+ )
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use rand::{thread_rng, Rng};
+ use test_casing::test_casing;
+
+ use core::iter;
+
+ use super::*;
+ use crate::{
+ curve25519::{ristretto::RistrettoPoint, scalar::Scalar as Scalar25519, traits::Identity},
+ group::{ElementOps, Ristretto},
+ };
+
+ type Keypair = crate::Keypair<Ristretto>;
+
+ #[test]
+ fn single_ring_with_2_elements_works() {
+ let mut rng = thread_rng();
+ let keypair = Keypair::generate(&mut rng);
+ let log_base = keypair.public().as_element();
+ let admissible_values = [RistrettoPoint::identity(), Ristretto::generator()];
+
+ let value = RistrettoPoint::identity();
+ let ext_ciphertext = ExtendedCiphertext::new(value, keypair.public(), &mut rng);
+ let ciphertext = ext_ciphertext.inner;
+
+ let mut transcript = Transcript::new(b"test_ring_encryption");
+ RingProof::initialize_transcript(&mut transcript, keypair.public());
+
+ let mut ring_responses = vec![Scalar25519::default(); 2];
+ let signature_ring = Ring::new(
+ 0,
+ log_base,
+ ext_ciphertext,
+ &admissible_values,
+ 0,
+ &transcript,
+ &mut ring_responses,
+ &mut rng,
+ );
+ let common_challenge =
+ Ring::aggregate(vec![signature_ring], log_base, &mut transcript, &mut rng);
+
+ RingProof::new(common_challenge, ring_responses)
+ .verify(
+ keypair.public(),
+ iter::once(&admissible_values as &[_]),
+ iter::once(ciphertext),
+ &mut Transcript::new(b"test_ring_encryption"),
+ )
+ .unwrap();
+
+ let value = Ristretto::generator();
+ let ext_ciphertext = ExtendedCiphertext::new(value, keypair.public(), &mut rng);
+ let ciphertext = ext_ciphertext.inner;
+
+ let mut transcript = Transcript::new(b"test_ring_encryption");
+ RingProof::initialize_transcript(&mut transcript, keypair.public());
+ let mut ring_responses = vec![Scalar25519::default(); 2];
+ let signature_ring = Ring::new(
+ 0,
+ log_base,
+ ext_ciphertext,
+ &admissible_values,
+ 1,
+ &transcript,
+ &mut ring_responses,
+ &mut rng,
+ );
+ let common_challenge =
+ Ring::aggregate(vec![signature_ring], log_base, &mut transcript, &mut rng);
+
+ RingProof::new(common_challenge, ring_responses)
+ .verify(
+ keypair.public(),
+ iter::once(&admissible_values as &[_]),
+ iter::once(ciphertext),
+ &mut Transcript::new(b"test_ring_encryption"),
+ )
+ .unwrap();
+ }
+
+ #[test]
+ fn single_ring_with_4_elements_works() {
+ let mut rng = thread_rng();
+ let keypair = Keypair::generate(&mut rng);
+ let log_base = keypair.public().as_element();
+ let admissible_values: Vec<_> = (0_u32..4)
+ .map(|i| Ristretto::mul_generator(&Scalar25519::from(i)))
+ .collect();
+
+ for _ in 0..100 {
+ let val: u32 = rng.gen_range(0..4);
+ let element_val = Ristretto::mul_generator(&Scalar25519::from(val));
+ let ext_ciphertext = ExtendedCiphertext::new(element_val, keypair.public(), &mut rng);
+ let ciphertext = ext_ciphertext.inner;
+
+ let mut transcript = Transcript::new(b"test_ring_encryption");
+ RingProof::initialize_transcript(&mut transcript, keypair.public());
+
+ let mut ring_responses = vec![Scalar25519::default(); 4];
+ let signature_ring = Ring::new(
+ 0,
+ log_base,
+ ext_ciphertext,
+ &admissible_values,
+ val as usize,
+ &transcript,
+ &mut ring_responses,
+ &mut rng,
+ );
+ let common_challenge =
+ Ring::aggregate(vec![signature_ring], log_base, &mut transcript, &mut rng);
+
+ RingProof::new(common_challenge, ring_responses)
+ .verify(
+ keypair.public(),
+ iter::once(admissible_values.as_slice()),
+ iter::once(ciphertext),
+ &mut Transcript::new(b"test_ring_encryption"),
+ )
+ .unwrap();
+ }
+ }
+
+ #[test_casing(5, 3..=7)]
+ fn multiple_rings_with_boolean_flags_work(ring_count: usize) {
+ let mut rng = thread_rng();
+ let keypair = Keypair::generate(&mut rng);
+ let log_base = keypair.public().as_element();
+ let admissible_values = [RistrettoPoint::identity(), Ristretto::generator()];
+
+ for _ in 0..20 {
+ let mut transcript = Transcript::new(b"test_ring_encryption");
+ RingProof::initialize_transcript(&mut transcript, keypair.public());
+
+ let mut ring_responses = vec![Scalar25519::default(); ring_count * 2];
+
+ let (ciphertexts, rings): (Vec<_>, Vec<_>) = ring_responses
+ .chunks_mut(2)
+ .enumerate()
+ .map(|(ring_index, ring_responses)| {
+ let val: u32 = rng.gen_range(0..=1);
+ let element_val = Ristretto::mul_generator(&Scalar25519::from(val));
+ let ext_ciphertext =
+ ExtendedCiphertext::new(element_val, keypair.public(), &mut rng);
+ let ciphertext = ext_ciphertext.inner;
+
+ let signature_ring = Ring::new(
+ ring_index,
+ log_base,
+ ext_ciphertext,
+ &admissible_values,
+ val as usize,
+ &transcript,
+ ring_responses,
+ &mut rng,
+ );
+
+ (ciphertext, signature_ring)
+ })
+ .unzip();
+
+ let common_challenge = Ring::aggregate(rings, log_base, &mut transcript, &mut rng);
+
+ RingProof::new(common_challenge, ring_responses)
+ .verify(
+ keypair.public(),
+ iter::repeat(&admissible_values as &[_]).take(ring_count),
+ ciphertexts.into_iter(),
+ &mut Transcript::new(b"test_ring_encryption"),
+ )
+ .unwrap();
+ }
+ }
+
+ #[test]
+ fn multiple_rings_with_base4_value_encoding_work() {
+ const RING_COUNT: u8 = 4;
+
+ let admissible_values: Vec<_> = (0..RING_COUNT)
+ .map(|ring_index| {
+ let power: u32 = 1 << (2 * u32::from(ring_index));
+ [
+ RistrettoPoint::identity(),
+ Ristretto::mul_generator(&Scalar25519::from(power)),
+ Ristretto::mul_generator(&Scalar25519::from(power * 2)),
+ Ristretto::mul_generator(&Scalar25519::from(power * 3)),
+ ]
+ })
+ .collect();
+
+ let mut rng = thread_rng();
+ let keypair = Keypair::generate(&mut rng);
+ let log_base = keypair.public().as_element();
+
+ for _ in 0..20 {
+ let overall_value: u8 = rng.gen();
+ let mut transcript = Transcript::new(b"test_ring_encryption");
+ RingProof::initialize_transcript(&mut transcript, keypair.public());
+
+ let mut ring_responses = vec![Scalar25519::default(); RING_COUNT as usize * 4];
+
+ let (ciphertexts, rings): (Vec<_>, Vec<_>) = ring_responses
+ .chunks_mut(4)
+ .enumerate()
+ .map(|(ring_index, ring_responses)| {
+ let mask = 3 << (2 * ring_index);
+ let val = overall_value & mask;
+ let val_index = (val >> (2 * ring_index)) as usize;
+ assert!(val_index < 4);
+
+ let element_val = Ristretto::mul_generator(&Scalar25519::from(val));
+ let ext_ciphertext =
+ ExtendedCiphertext::new(element_val, keypair.public(), &mut rng);
+ let ciphertext = ext_ciphertext.inner;
+
+ let signature_ring = Ring::new(
+ ring_index,
+ log_base,
+ ext_ciphertext,
+ &admissible_values[ring_index],
+ val_index,
+ &transcript,
+ ring_responses,
+ &mut rng,
+ );
+
+ (ciphertext, signature_ring)
+ })
+ .unzip();
+
+ let common_challenge = Ring::aggregate(rings, log_base, &mut transcript, &mut rng);
+ let admissible_values = admissible_values.iter().map(|values| values as &[_]);
+
+ RingProof::new(common_challenge, ring_responses)
+ .verify(
+ keypair.public(),
+ admissible_values,
+ ciphertexts.into_iter(),
+ &mut Transcript::new(b"test_ring_encryption"),
+ )
+ .unwrap();
+ }
+ }
+
+ #[test_casing(5, 3..=7)]
+ #[allow(clippy::needless_collect)]
+ fn proof_builder_works(ring_count: usize) {
+ let mut rng = thread_rng();
+ let keypair = Keypair::generate(&mut rng);
+ let mut transcript = Transcript::new(b"test_ring_encryption");
+ let admissible_values = [RistrettoPoint::identity(), Ristretto::generator()];
+ let mut ring_responses = vec![Scalar25519::default(); ring_count * 2];
+
+ let mut builder = RingProofBuilder::new(
+ keypair.public(),
+ ring_count,
+ &mut ring_responses,
+ &mut transcript,
+ &mut rng,
+ );
+ let ciphertexts: Vec<_> = (0..ring_count)
+ .map(|i| builder.add_value(&admissible_values, i & 1).inner)
+ .collect();
+
+ RingProof::new(builder.build(), ring_responses)
+ .verify(
+ keypair.public(),
+ iter::repeat(&admissible_values as &[_]).take(ring_count),
+ ciphertexts.into_iter(),
+ &mut Transcript::new(b"test_ring_encryption"),
+ )
+ .unwrap();
+ }
+}
+