From d7f698ec98b924b4cb90e4cd6028fc00b28ed2dc Mon Sep 17 00:00:00 2001 From: ftheirs Date: Fri, 20 Dec 2024 10:40:58 -0300 Subject: [PATCH] compression hint --- crates/starknet-os/src/hints/compression.rs | 255 ++++++++++++++++++++ crates/starknet-os/src/hints/mod.rs | 2 + crates/starknet-os/src/hints/output.rs | 3 +- crates/starknet-os/src/hints/vars.rs | 4 + 4 files changed, 262 insertions(+), 2 deletions(-) create mode 100644 crates/starknet-os/src/hints/compression.rs diff --git a/crates/starknet-os/src/hints/compression.rs b/crates/starknet-os/src/hints/compression.rs new file mode 100644 index 00000000..133a0be9 --- /dev/null +++ b/crates/starknet-os/src/hints/compression.rs @@ -0,0 +1,255 @@ +use std::collections::HashMap; + +use cairo_vm::hint_processor::builtin_hint_processor::hint_utils::get_ptr_from_var_name; +use cairo_vm::hint_processor::hint_processor_definition::HintReference; +use cairo_vm::hint_processor::hint_processor_utils::felt_to_usize; +use cairo_vm::serde::deserialize_program::ApTracking; +use cairo_vm::types::exec_scope::ExecutionScopes; +use cairo_vm::types::relocatable::MaybeRelocatable; +use cairo_vm::vm::errors::hint_errors::HintError; +use cairo_vm::vm::vm_core::VirtualMachine; +use cairo_vm::Felt252; +use indoc::indoc; + +use crate::hints::vars; +use crate::utils::get_constant; + +const COMPRESSION_VERSION: u8 = 0; +const MAX_N_BITS: usize = 251; +const N_UNIQUE_VALUE_BUCKETS: usize = 6; +const TOTAL_N_BUCKETS: usize = N_UNIQUE_VALUE_BUCKETS + 1; + +#[derive(Debug, Clone)] +struct UniqueValueBucket { + n_bits: Felt252, + value_to_index: HashMap, +} + +impl UniqueValueBucket { + fn new(n_bits: Felt252) -> Self { + Self { n_bits, value_to_index: HashMap::new() } + } + + fn add(&mut self, value: &Felt252) { + if !self.value_to_index.contains_key(value) { + let next_index = self.value_to_index.len(); + self.value_to_index.insert(*value, next_index); + } + } + + fn get_index(&self, value: &Felt252) -> Option { + self.value_to_index.get(value).copied() + } + + fn pack_in_felts(&self) -> Vec<&Felt252> { + let mut values: Vec<&Felt252> = self.value_to_index.keys().collect(); + values.sort_by_key(|&v| self.value_to_index[v]); + values + } +} + +struct CompressionSet { + buckets: Vec, + sorted_buckets: Vec<(usize, UniqueValueBucket)>, + repeating_value_locations: Vec<(usize, usize)>, + bucket_index_per_elm: Vec, + finalized: bool, +} + +impl CompressionSet { + fn new(n_bits_per_bucket: Vec) -> Self { + let buckets: Vec = + n_bits_per_bucket.iter().map(|&n_bits| UniqueValueBucket::new(n_bits)).collect(); + + let mut indexed_buckets: Vec<(usize, UniqueValueBucket)> = Vec::new(); + for (index, bucket) in buckets.iter().enumerate() { + indexed_buckets.push((index, bucket.clone())); + } + indexed_buckets.sort_by(|a, b| a.1.n_bits.cmp(&b.1.n_bits)); + + CompressionSet { + buckets, + sorted_buckets: indexed_buckets, + repeating_value_locations: Vec::new(), + bucket_index_per_elm: Vec::new(), + finalized: false, + } + } + + fn update(&mut self, values: Vec) { + assert!(!self.finalized, "Cannot add values after finalizing."); + let buckets_len = self.buckets.len(); + for value in values.iter() { + for (bucket_index, bucket) in self.sorted_buckets.iter_mut() { + if Felt252::from(value.bits()) <= bucket.n_bits { + if bucket.value_to_index.contains_key(value) { + // Repeated value; add the location of the first added copy. + if let Some(index) = bucket.get_index(value) { + self.repeating_value_locations.push((*bucket_index, index)); + self.bucket_index_per_elm.push(buckets_len); + } + } else { + // First appearance of this value. + bucket.add(value); + self.bucket_index_per_elm.push(*bucket_index); + } + } + } + } + } + + fn finalize(&mut self) { + self.finalized = true; + } + pub fn get_bucket_index_per_elm(&self) -> Vec { + assert!(self.finalized, "Cannot get bucket_index_per_elm before finalizing."); + self.bucket_index_per_elm.clone() + } + + pub fn get_unique_value_bucket_lengths(&self) -> Vec { + self.sorted_buckets.iter().map(|elem| elem.1.value_to_index.len()).collect() + } + + pub fn get_repeating_value_bucket_length(&self) -> usize { + self.repeating_value_locations.len() + } + + pub fn pack_unique_values(&self) -> Vec { + assert!(self.finalized, "Cannot pack before finalizing."); + // Chain the packed felts from each bucket into a single vector. + self.buckets.iter().flat_map(|bucket| bucket.pack_in_felts()).cloned().collect() + } + + /// Returns a list of pointers corresponding to the repeating values. + /// The pointers point to the chained unique value buckets. + pub fn get_repeating_value_pointers(&self) -> Vec { + assert!(self.finalized, "Cannot get pointers before finalizing."); + + let unique_value_bucket_lengths = self.get_unique_value_bucket_lengths(); + let bucket_offsets = get_bucket_offsets(unique_value_bucket_lengths); + + let mut pointers = Vec::new(); + for (bucket_index, index_in_bucket) in self.repeating_value_locations.iter() { + pointers.push(bucket_offsets[*bucket_index] + index_in_bucket); + } + + pointers + } +} + +fn pack_in_felt(elms: Vec, elm_bound: usize) -> Felt252 { + let mut res = Felt252::ZERO; + for (i, &elm) in elms.iter().enumerate() { + res += Felt252::from(elm * elm_bound.pow(i as u32)); + } + assert!(res.to_biguint() < Felt252::prime(), "Out of bound packing."); + res +} + +fn pack_in_felts(elms: Vec, elm_bound: usize) -> Vec { + assert!(elms.iter().all(|&elm| elm < elm_bound), "Element out of bound."); + + elms.chunks(get_n_elms_per_felt(elm_bound)).map(|chunk| pack_in_felt(chunk.to_vec(), elm_bound)).collect() +} + +fn get_bucket_offsets(bucket_lengths: Vec) -> Vec { + let mut offsets = Vec::new(); + let mut sum = 0; + for length in bucket_lengths { + offsets.push(sum); + sum += length; + } + offsets +} + +fn log2_ceil(x: usize) -> usize { + assert!(x > 0); + (x - 1).count_ones() as usize +} + +fn get_n_elms_per_felt(elm_bound: usize) -> usize { + if elm_bound <= 1 { + return MAX_N_BITS; + } + if elm_bound > 2_usize.pow(MAX_N_BITS as u32) { + return 1; + } + + MAX_N_BITS / log2_ceil(elm_bound) +} + +fn compression( + data: Vec, + data_size: usize, + constants: &HashMap, +) -> Result, HintError> { + let n_bits_per_bucket = vec![ + Felt252::from(252), + Felt252::from(125), + Felt252::from(83), + Felt252::from(62), + Felt252::from(31), + Felt252::from(15), + ]; + let header_elm_n_bits = felt_to_usize(get_constant(vars::constants::HEADER_ELM_N_BITS, constants)?)?; + let header_elm_bound = 1usize << header_elm_n_bits; + + assert!(data_size < header_elm_bound, "Data length exceeds the header element bound"); + + let mut compression_set = CompressionSet::new(n_bits_per_bucket); + compression_set.update(data); + compression_set.finalize(); + + let bucket_index_per_elm = compression_set.get_bucket_index_per_elm(); + + let unique_value_bucket_lengths = compression_set.get_unique_value_bucket_lengths(); + let n_unique_values = unique_value_bucket_lengths.iter().sum::(); + + let mut header = vec![COMPRESSION_VERSION as usize, data_size]; + header.extend(unique_value_bucket_lengths.iter().cloned()); + header.push(compression_set.get_repeating_value_bucket_length()); + + let packed_header = vec![pack_in_felt(header, header_elm_bound)]; + + let packed_repeating_value_pointers = + pack_in_felts(compression_set.get_repeating_value_pointers(), n_unique_values); + + let packed_bucket_index_per_elm = pack_in_felts(bucket_index_per_elm, TOTAL_N_BUCKETS); + + let compressed_data = packed_header + .into_iter() + .chain(compression_set.pack_unique_values().into_iter()) + .chain(packed_repeating_value_pointers.into_iter()) + .chain(packed_bucket_index_per_elm.into_iter()) + .collect::>(); + + Ok(compressed_data) +} + +pub const COMPRESS: &str = indoc! {r#"from starkware.starknet.core.os.data_availability.compression import compress + data = memory.get_range_as_ints(addr=ids.data_start, size=ids.data_end - ids.data_start) + segments.write_arg(ids.compressed_dst, compress(data))"#}; + +pub fn compress( + vm: &mut VirtualMachine, + _exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + constants: &HashMap, +) -> Result<(), HintError> { + let data_start = get_ptr_from_var_name(vars::ids::DATA_START, vm, ids_data, ap_tracking)?; + let data_end = get_ptr_from_var_name(vars::ids::DATA_END, vm, ids_data, ap_tracking)?; + let data_size = (data_end - data_start)?; + + let compressed_dst = get_ptr_from_var_name(vars::ids::COMPRESSED_DST, vm, ids_data, ap_tracking)?; + + let data: Vec = vm.get_integer_range(data_start, data_size)?.into_iter().map(|s| *s).collect(); + let compress_result = compression(data, data_size, constants)? + .into_iter() + .map(MaybeRelocatable::Int) + .collect::>(); + + vm.write_arg(compressed_dst, &compress_result)?; + + Ok(()) +} diff --git a/crates/starknet-os/src/hints/mod.rs b/crates/starknet-os/src/hints/mod.rs index 5452e0c1..4ceedcdb 100644 --- a/crates/starknet-os/src/hints/mod.rs +++ b/crates/starknet-os/src/hints/mod.rs @@ -35,6 +35,7 @@ mod bls_field; mod bls_utils; pub mod builtins; mod compiled_class; +mod compression; mod deprecated_compiled_class; mod execute_transactions; pub mod execution; @@ -254,6 +255,7 @@ fn hints() -> HashMap where hints.insert(compiled_class::SET_AP_TO_SEGMENT_HASH.into(), compiled_class::set_ap_to_segment_hash); hints.insert(secp::READ_EC_POINT_ADDRESS.into(), secp::read_ec_point_from_address); hints.insert(execute_transactions::SHA2_FINALIZE.into(), execute_transactions::sha2_finalize); + hints.insert(compression::COMPRESS.into(), compression::compress); hints } diff --git a/crates/starknet-os/src/hints/output.rs b/crates/starknet-os/src/hints/output.rs index c1054db8..a81a7014 100644 --- a/crates/starknet-os/src/hints/output.rs +++ b/crates/starknet-os/src/hints/output.rs @@ -163,8 +163,7 @@ pub fn set_state_updates_start( Ok(()) } -pub const SET_COMPRESSED_START: &str = indoc! {r#"use_kzg_da = ids.use_kzg_da -if use_kzg_da: +pub const SET_COMPRESSED_START: &str = indoc! {r#"if use_kzg_da: ids.compressed_start = segments.add() else: # Assign a temporary segment, to be relocated into the output segment. diff --git a/crates/starknet-os/src/hints/vars.rs b/crates/starknet-os/src/hints/vars.rs index 3d11d908..e36550be 100644 --- a/crates/starknet-os/src/hints/vars.rs +++ b/crates/starknet-os/src/hints/vars.rs @@ -163,6 +163,9 @@ pub mod ids { pub const N_UPDATES_SMALL_PACKING_BOUND: &str = "starkware.starknet.core.os.state.output.N_UPDATES_SMALL_PACKING_BOUND"; pub const FULL_OUTPUT: &str = "full_output"; + pub const COMPRESSED_DST: &str = "compressed_dst"; + pub const DATA_START: &str = "data_start"; + pub const DATA_END: &str = "data_end"; } pub mod constants { @@ -171,4 +174,5 @@ pub mod constants { pub const MERKLE_HEIGHT: &str = "starkware.starknet.core.os.state.commitment.MERKLE_HEIGHT"; pub const STORED_BLOCK_HASH_BUFFER: &str = "starkware.starknet.core.os.constants.STORED_BLOCK_HASH_BUFFER"; pub const VALIDATED: &str = "starkware.starknet.core.os.constants.VALIDATED"; + pub const HEADER_ELM_N_BITS: &str = "starkware.starknet.core.os.data_availability.compression.HEADER_ELM_N_BITS"; }