From 9bc5915e6c48ac96434c11f6aa9cdeb5c896e927 Mon Sep 17 00:00:00 2001 From: whichqua Date: Mon, 13 Jan 2025 17:47:09 +0300 Subject: [PATCH] fix: handle kzg compression correctly --- crates/starknet-os/src/hints/compression.rs | 312 ++++++++++---------- crates/starknet-os/src/hints/dict.rs | 3 +- crates/starknet-os/src/hints/math.rs | 3 +- crates/starknet-os/src/hints/mod.rs | 2 +- 4 files changed, 156 insertions(+), 164 deletions(-) diff --git a/crates/starknet-os/src/hints/compression.rs b/crates/starknet-os/src/hints/compression.rs index e8de96b0..8204c2ef 100644 --- a/crates/starknet-os/src/hints/compression.rs +++ b/crates/starknet-os/src/hints/compression.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use cairo_vm::hint_processor::builtin_hint_processor::hint_utils::{get_integer_from_var_name, get_ptr_from_var_name}; use cairo_vm::hint_processor::hint_processor_definition::HintReference; -use cairo_vm::hint_processor::hint_processor_utils::felt_to_usize; use cairo_vm::serde::deserialize_program::ApTracking; use cairo_vm::types::exec_scope::ExecutionScopes; use cairo_vm::types::relocatable::MaybeRelocatable; @@ -10,46 +9,60 @@ use cairo_vm::vm::errors::hint_errors::HintError; use cairo_vm::vm::vm_core::VirtualMachine; use cairo_vm::Felt252; use indoc::indoc; +use lazy_static::lazy_static; use num_bigint::BigUint; +use num_traits::{One, ToPrimitive, Zero}; -use crate::hints::math::log2_ceil; use crate::hints::vars; -use crate::utils::get_constant; const COMPRESSION_VERSION: u8 = 0; const MAX_N_BITS: usize = 251; +const HEADER_ELM_N_BITS: usize = 20; -use super::constants::TOTAL_N_BUCKETS; +const N_BITS_PER_BUCKET: [usize; 6] = [252, 125, 83, 62, 31, 15]; +const TOTAL_N_BUCKETS: usize = N_BITS_PER_BUCKET.len() + 1; -#[derive(Debug, Clone)] +lazy_static! { + static ref HEADER_ELM_BOUND: BigUint = BigUint::one() << HEADER_ELM_N_BITS; +} + +#[derive(Default, Clone, Debug)] struct UniqueValueBucket { - n_bits: Felt252, - value_to_index: HashMap, + n_bits: usize, + value_to_index: indexmap::IndexMap, } impl UniqueValueBucket { - fn new(n_bits: Felt252) -> Self { - Self { n_bits, value_to_index: HashMap::new() } + fn new(n_bits: usize) -> Self { + Self { n_bits, value_to_index: Default::default() } } - fn add(&mut self, value: &Felt252) { - if !self.value_to_index.contains_key(value) { + fn contains(&self, value: &BigUint) -> bool { + self.value_to_index.contains_key(value) + } + + fn len(&self) -> usize { + self.value_to_index.len() + } + + fn add(&mut self, value: BigUint) { + if !self.contains(&value) { let next_index = self.value_to_index.len(); - self.value_to_index.insert(*value, next_index); + self.value_to_index.insert(value, next_index); } } - fn get_index(&self, value: &Felt252) -> Option { - self.value_to_index.get(value).copied() + fn get_index(&self, value: &BigUint) -> usize { + *self.value_to_index.get(value).unwrap() } - fn pack_in_felts(&self) -> Vec<&Felt252> { - let mut values: Vec<&Felt252> = self.value_to_index.keys().collect(); - values.sort_by_key(|&v| self.value_to_index[v]); - values + fn pack_in_felts(&self) -> Vec { + let values: Vec = self.value_to_index.keys().cloned().collect(); + pack_in_felts(&values, &(BigUint::one() << self.n_bits)) } } +#[derive(Default, Clone, Debug)] struct CompressionSet { buckets: Vec, sorted_buckets: Vec<(usize, UniqueValueBucket)>, @@ -59,203 +72,163 @@ struct CompressionSet { } impl CompressionSet { - fn new(n_bits_per_bucket: Vec) -> Self { + fn new(n_bits_per_bucket: &[usize]) -> Self { let buckets: Vec = n_bits_per_bucket.iter().map(|&n_bits| UniqueValueBucket::new(n_bits)).collect(); - let mut indexed_buckets: Vec<(usize, UniqueValueBucket)> = Vec::new(); - for (index, bucket) in buckets.iter().enumerate() { - indexed_buckets.push((index, bucket.clone())); - } - indexed_buckets.sort_by(|a, b| a.1.n_bits.cmp(&b.1.n_bits)); + let mut sorted_buckets: Vec<(usize, UniqueValueBucket)> = + buckets.clone().into_iter().enumerate().map(|(i, bucket)| (i, bucket)).collect(); - CompressionSet { + sorted_buckets.sort_by_key(|(_, bucket)| bucket.n_bits); + Self { buckets, - sorted_buckets: indexed_buckets, + sorted_buckets, repeating_value_locations: Vec::new(), bucket_index_per_elm: Vec::new(), finalized: false, } } - fn update(&mut self, values: Vec) { + fn get_bucket_index_per_elm(&self) -> Vec { + assert!(self.finalized, "Cannot get bucket_index_per_elm before finalizing."); + self.bucket_index_per_elm.clone() + } + + fn repeating_values_bucket_index(&self) -> usize { + self.buckets.len() + } + + fn update(&mut self, values: &[BigUint]) { assert!(!self.finalized, "Cannot add values after finalizing."); - let buckets_len = self.buckets.len(); - for value in values.iter() { - for (bucket_index, bucket) in self.sorted_buckets.iter_mut() { - if Felt252::from(value.bits()) <= bucket.n_bits { - if bucket.value_to_index.contains_key(value) { - // Repeated value; add the location of the first added copy. - if let Some(index) = bucket.get_index(value) { - self.repeating_value_locations.push((*bucket_index, index)); - self.bucket_index_per_elm.push(buckets_len); - } + + for value in values { + for (bucket_index, bucket) in &mut self.sorted_buckets { + if value.bits() as usize <= bucket.n_bits { + if bucket.contains(value) { + self.repeating_value_locations.push((*bucket_index, bucket.get_index(value))); + self.bucket_index_per_elm.push(self.repeating_values_bucket_index()); } else { - // First appearance of this value. - bucket.add(value); + self.buckets[*bucket_index].add(value.clone()); + bucket.add(value.clone()); self.bucket_index_per_elm.push(*bucket_index); } + break; } } } } - fn finalize(&mut self) { - self.finalized = true; - } - pub fn get_bucket_index_per_elm(&self) -> Vec { - assert!(self.finalized, "Cannot get bucket_index_per_elm before finalizing."); - self.bucket_index_per_elm.clone() - } - - pub fn get_unique_value_bucket_lengths(&self) -> Vec { - self.sorted_buckets.iter().map(|elem| elem.1.value_to_index.len()).collect() + fn get_unique_value_bucket_lengths(&self) -> Vec { + self.buckets.iter().map(|bucket| bucket.len()).collect() } - pub fn get_repeating_value_bucket_length(&self) -> usize { + fn get_repeating_value_bucket_length(&self) -> usize { self.repeating_value_locations.len() } - pub fn pack_unique_values(&self) -> Vec { - assert!(self.finalized, "Cannot pack before finalizing."); - // Chain the packed felts from each bucket into a single vector. - self.buckets.iter().flat_map(|bucket| bucket.pack_in_felts()).cloned().collect() - } - - /// Returns a list of pointers corresponding to the repeating values. - /// The pointers point to the chained unique value buckets. - pub fn get_repeating_value_pointers(&self) -> Vec { + fn get_repeating_value_pointers(&self) -> Vec { assert!(self.finalized, "Cannot get pointers before finalizing."); let unique_value_bucket_lengths = self.get_unique_value_bucket_lengths(); - let bucket_offsets = get_bucket_offsets(unique_value_bucket_lengths); - - let mut pointers = Vec::new(); - for (bucket_index, index_in_bucket) in self.repeating_value_locations.iter() { - pointers.push(bucket_offsets[*bucket_index] + index_in_bucket); - } + let bucket_offsets = get_bucket_offsets(&unique_value_bucket_lengths); - pointers + self.repeating_value_locations + .iter() + .map(|&(bucket_index, index_in_bucket)| &bucket_offsets[bucket_index] + BigUint::from(index_in_bucket)) + .collect() } -} -fn pack_in_felt(elms: Vec, elm_bound: usize) -> Felt252 { - let mut res = Felt252::ZERO; - let elm_bound_felt = Felt252::from(elm_bound); - - for (i, &elm) in elms.iter().enumerate() { - let power = elm_bound_felt.pow(i as u128); - let term = Felt252::from(elm) * power; - res += term; + fn pack_unique_values(&self) -> Vec { + assert!(self.finalized, "Cannot pack before finalizing."); + self.buckets.iter().flat_map(|bucket| bucket.pack_in_felts()).collect() } - assert!(res.to_biguint() < Felt252::prime(), "Out of bound packing."); - res -} - -fn pack_in_felts(elms: Vec, elm_bound: usize) -> Vec { - assert!(elms.iter().all(|&elm| elm < elm_bound), "Element out of bound."); - - elms.chunks(get_n_elms_per_felt(elm_bound)).map(|chunk| pack_in_felt(chunk.to_vec(), elm_bound)).collect() -} - -fn get_bucket_offsets(bucket_lengths: Vec) -> Vec { - let mut offsets = Vec::new(); - let mut sum = 0; - for length in bucket_lengths { - offsets.push(sum); - sum += length; + fn finalize(&mut self) { + self.finalized = true; } - offsets } -fn get_n_elms_per_felt(elm_bound: usize) -> usize { - if elm_bound <= 1 { - return MAX_N_BITS; - } - - // 2 ** 251 - let max_value = Felt252::ELEMENT_UPPER_BOUND; - if Felt252::from(elm_bound) > max_value { - return 1; - } - - let log2_result = log2_ceil(&BigUint::from(elm_bound)) as usize; - assert!(log2_result > 0, "log2_ceil(elm_bound) returned 0, which would cause division by zero."); - - MAX_N_BITS / log2_result -} +pub fn compress(data: &[BigUint]) -> Vec { + assert!(data.len() < HEADER_ELM_BOUND.to_usize().unwrap(), "Data is too long."); -fn compression( - data: Vec, - data_size: usize, - constants: &HashMap, -) -> Result, HintError> { - let n_bits_per_bucket = vec![ - Felt252::from(252), - Felt252::from(125), - Felt252::from(83), - Felt252::from(62), - Felt252::from(31), - Felt252::from(15), - ]; - let header_elm_n_bits = felt_to_usize(get_constant(vars::constants::HEADER_ELM_N_BITS, constants)?)?; - let header_elm_bound = 1usize << header_elm_n_bits; - - assert!(data_size < header_elm_bound, "Data length exceeds the header element bound"); - - let mut compression_set = CompressionSet::new(n_bits_per_bucket); + let mut compression_set = CompressionSet::new(&N_BITS_PER_BUCKET); compression_set.update(data); compression_set.finalize(); let bucket_index_per_elm = compression_set.get_bucket_index_per_elm(); - let unique_value_bucket_lengths = compression_set.get_unique_value_bucket_lengths(); - let n_unique_values = unique_value_bucket_lengths.iter().sum::(); - - let mut header = vec![COMPRESSION_VERSION as usize, data_size]; - header.extend(unique_value_bucket_lengths.iter().cloned()); - header.push(compression_set.get_repeating_value_bucket_length()); + let n_unique_values: usize = unique_value_bucket_lengths.iter().sum(); - let packed_header = vec![pack_in_felt(header, header_elm_bound)]; + let mut header: Vec = vec![BigUint::from(COMPRESSION_VERSION), BigUint::from(data.len())]; + header.extend(unique_value_bucket_lengths.iter().map(|&len| BigUint::from(len))); + header.push(BigUint::from(compression_set.get_repeating_value_bucket_length())); + let packed_header = pack_in_felts(&header, &HEADER_ELM_BOUND); let packed_repeating_value_pointers = - pack_in_felts(compression_set.get_repeating_value_pointers(), n_unique_values); + pack_in_felts(&compression_set.get_repeating_value_pointers(), &BigUint::from(n_unique_values)); + let packed_bucket_index_per_elm = pack_in_felts( + &bucket_index_per_elm.into_iter().map(BigUint::from).collect::>(), + &BigUint::from(TOTAL_N_BUCKETS), + ); + + let unique_values = compression_set.pack_unique_values(); + let mut result = Vec::new(); + result.extend(packed_header); + result.extend(unique_values); + result.extend(packed_repeating_value_pointers); + result.extend(packed_bucket_index_per_elm); + result +} - let packed_bucket_index_per_elm = pack_in_felts(bucket_index_per_elm, TOTAL_N_BUCKETS as usize); +fn pack_in_felts(elms: &[BigUint], elm_bound: &BigUint) -> Vec { + elms.chunks(get_n_elms_per_felt(elm_bound)).map(|chunk| pack_in_felt(chunk, elm_bound)).collect() +} - let compressed_data = packed_header - .into_iter() - .chain(compression_set.pack_unique_values().into_iter()) - .chain(packed_repeating_value_pointers.into_iter()) - .chain(packed_bucket_index_per_elm.into_iter()) - .collect::>(); +fn pack_in_felt(elms: &[BigUint], elm_bound: &BigUint) -> BigUint { + elms.iter().enumerate().fold(BigUint::zero(), |acc, (i, elm)| acc + elm * elm_bound.pow(i as u32)) +} + +fn get_bucket_offsets(bucket_lengths: &[usize]) -> Vec { + let mut offsets = Vec::with_capacity(bucket_lengths.len()); + let mut current = BigUint::zero(); + + for &length in bucket_lengths { + offsets.push(current.clone()); + current += BigUint::from(length); + } - Ok(compressed_data) + offsets } -pub const COMPRESS: &str = indoc! {r#"from starkware.starknet.core.os.data_availability.compression import compress +fn get_n_elms_per_felt(elm_bound: &BigUint) -> usize { + if elm_bound <= &BigUint::one() { + return MAX_N_BITS; + } + if elm_bound > &(BigUint::one() << MAX_N_BITS) { + return 1; + } + MAX_N_BITS / ((elm_bound.bits() as f64).log2().ceil() as usize) +} + +pub const COMPRESSION_HINT: &str = indoc! {r#"from starkware.starknet.core.os.data_availability.compression import compress data = memory.get_range_as_ints(addr=ids.data_start, size=ids.data_end - ids.data_start) segments.write_arg(ids.compressed_dst, compress(data))"#}; -pub fn compress( +pub fn compression_hint( vm: &mut VirtualMachine, _exec_scopes: &mut ExecutionScopes, ids_data: &HashMap, ap_tracking: &ApTracking, - constants: &HashMap, + _constants: &HashMap, ) -> Result<(), HintError> { let data_start = get_ptr_from_var_name(vars::ids::DATA_START, vm, ids_data, ap_tracking)?; let data_end = get_ptr_from_var_name(vars::ids::DATA_END, vm, ids_data, ap_tracking)?; let data_size = (data_end - data_start)?; let compressed_dst = get_ptr_from_var_name(vars::ids::COMPRESSED_DST, vm, ids_data, ap_tracking)?; - - let data: Vec = vm.get_integer_range(data_start, data_size)?.into_iter().map(|s| *s).collect(); - let compress_result = compression(data, data_size, constants)? - .into_iter() - .map(MaybeRelocatable::Int) - .collect::>(); + let data: Vec = vm.get_integer_range(data_start, data_size)?.iter().map(|s| s.to_biguint()).collect(); + let compress_result = + compress(&data).into_iter().map(|s| MaybeRelocatable::Int(Felt252::from(s))).collect::>(); vm.write_arg(compressed_dst, &compress_result)?; @@ -284,6 +257,8 @@ pub fn set_decompressed_dst( #[cfg(test)] mod tests { + use std::str::FromStr; + use rstest::rstest; use super::*; @@ -291,13 +266,32 @@ mod tests { #[rstest] #[case(0, MAX_N_BITS)] #[case(1, MAX_N_BITS)] - #[case(16, 62)] - #[case(10, 62)] - #[case(100, 35)] - #[case(500, 27)] - #[case(10000, 17)] - #[case(125789, 14)] + #[case(16, 83)] + #[case(10, 125)] + #[case(100, 83)] + #[case(500, 62)] + #[case(10000, 62)] + #[case(125789, 50)] fn test_get_n_elms_per_felt(#[case] input: usize, #[case] expected: usize) { - assert_eq!(get_n_elms_per_felt(input), expected); + assert_eq!(get_n_elms_per_felt(&BigUint::from(input)), expected); + } + + #[rstest] + #[case::single_value_1(vec![1u32], vec!["1393796574908163946345982392040522595172352", "1", "5"])] + #[case::single_value_2(vec![2u32], vec!["1393796574908163946345982392040522595172352", "2", "5"])] + #[case::single_value_3(vec![10u32], vec!["1393796574908163946345982392040522595172352", "10", "5"])] + #[case::two_values(vec![1u32, 2], vec!["2787593149816327892691964784081045190344704", "65537", "40"])] + #[case::three_values(vec![2u32, 3, 1], vec!["4181389724724491839037947176121567785517056", "1073840130", "285"])] + #[case::four_values(vec![1u32, 2, 3, 4], vec!["5575186299632655785383929568162090380689408", "140740709646337", "2000"])] + #[case::extracted_kzg_example(vec![1u32, 1, 6, 1991, 66, 0], vec!["1461508606313777459023416562628243222268909453312", "2324306378031105", "0", "98047"])] + + fn test_compress(#[case] input: Vec, #[case] expected: Vec<&str>) { + let data: Vec = input.into_iter().map(BigUint::from).collect(); + + let compressed = compress(&data); + + let expected: Vec<_> = expected.iter().map(|s| BigUint::from_str(s).unwrap()).collect(); + + assert_eq!(compressed, expected); } } diff --git a/crates/starknet-os/src/hints/dict.rs b/crates/starknet-os/src/hints/dict.rs index 4ac4bd19..b9379331 100644 --- a/crates/starknet-os/src/hints/dict.rs +++ b/crates/starknet-os/src/hints/dict.rs @@ -31,8 +31,7 @@ pub fn dictionary_from_bucket( Ok(()) } -pub const GET_PREV_OFFSET: &str = indoc! {r#" - dict_tracker = __dict_manager.get_tracker(ids.dict_ptr) +pub const GET_PREV_OFFSET: &str = indoc! {r#"dict_tracker = __dict_manager.get_tracker(ids.dict_ptr) ids.prev_offset = dict_tracker.data[ids.bucket_index]"# }; diff --git a/crates/starknet-os/src/hints/math.rs b/crates/starknet-os/src/hints/math.rs index 565a4fd5..99715974 100644 --- a/crates/starknet-os/src/hints/math.rs +++ b/crates/starknet-os/src/hints/math.rs @@ -15,8 +15,7 @@ use num_traits::{One, Zero}; use crate::hints::vars; -pub const LOG2_CEIL: &str = indoc! {r#" - from starkware.python.math_utils import log2_ceil +pub const LOG2_CEIL: &str = indoc! {r#"from starkware.python.math_utils import log2_ceil ids.res = log2_ceil(ids.value)"# }; pub fn log2_ceil_hint( diff --git a/crates/starknet-os/src/hints/mod.rs b/crates/starknet-os/src/hints/mod.rs index 7614e044..6e6a18cd 100644 --- a/crates/starknet-os/src/hints/mod.rs +++ b/crates/starknet-os/src/hints/mod.rs @@ -259,7 +259,7 @@ fn hints() -> HashMap where hints.insert(secp::READ_EC_POINT_ADDRESS.into(), secp::read_ec_point_from_address); hints.insert(execute_transactions::SHA2_FINALIZE.into(), execute_transactions::sha2_finalize); hints.insert(math::LOG2_CEIL.into(), math::log2_ceil_hint); - hints.insert(compression::COMPRESS.into(), compression::compress); + hints.insert(compression::COMPRESSION_HINT.into(), compression::compression_hint); hints.insert(compression::SET_DECOMPRESSED_DST.into(), compression::set_decompressed_dst); hints.insert(dict::DICTIONARY_FROM_BUCKET.into(), dict::dictionary_from_bucket); hints.insert(dict::GET_PREV_OFFSET.into(), dict::get_prev_offset);