From eb6f0088d7060a2a82f5a938294b91d1af72ebfa Mon Sep 17 00:00:00 2001 From: Rohit Narurkar Date: Wed, 1 May 2024 02:27:48 +0100 Subject: [PATCH] witgen (fse::reconstruct) with prob=-1 case covered --- aggregator/src/aggregation/decoder/tables.rs | 4 +- .../decoder/tables/rom_fse_order.rs | 2 +- aggregator/src/aggregation/decoder/witgen.rs | 23 +- .../src/aggregation/decoder/witgen/types.rs | 230 ++++++++++++------ 4 files changed, 181 insertions(+), 78 deletions(-) diff --git a/aggregator/src/aggregation/decoder/tables.rs b/aggregator/src/aggregation/decoder/tables.rs index c4667582d0..dee0f9141d 100644 --- a/aggregator/src/aggregation/decoder/tables.rs +++ b/aggregator/src/aggregation/decoder/tables.rs @@ -18,9 +18,7 @@ pub use rom_fse_order::{FseTableKind, RomFseOrderTable, RomSequencesDataInterlea /// The fixed code to Baseline/NumBits for Literal Length. mod rom_sequence_codes; -pub use rom_sequence_codes::{ - LiteralLengthCodes, MatchLengthCodes, MatchOffsetCodes, RomSequenceCodes, -}; +pub use rom_sequence_codes::RomSequenceCodes; /// Validate the following tag given the tag currently being processed. mod rom_tag; diff --git a/aggregator/src/aggregation/decoder/tables/rom_fse_order.rs b/aggregator/src/aggregation/decoder/tables/rom_fse_order.rs index 7ff8e17dee..01654f2025 100644 --- a/aggregator/src/aggregation/decoder/tables/rom_fse_order.rs +++ b/aggregator/src/aggregation/decoder/tables/rom_fse_order.rs @@ -15,7 +15,7 @@ use crate::aggregation::decoder::witgen::ZstdTag::{ }; /// FSE table variants that we observe in the sequences section. -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] #[allow(clippy::upper_case_acronyms)] pub enum FseTableKind { /// Literal length FSE table. diff --git a/aggregator/src/aggregation/decoder/witgen.rs b/aggregator/src/aggregation/decoder/witgen.rs index 8485954dc6..376456a4da 100644 --- a/aggregator/src/aggregation/decoder/witgen.rs +++ b/aggregator/src/aggregation/decoder/witgen.rs @@ -12,6 +12,8 @@ pub use types::{ZstdTag::*, *}; pub mod util; use util::{be_bits_to_value, increment_idx, le_bits_to_value, value_bits_le}; +use crate::aggregation::decoder::tables::FseTableKind; + const TAG_MAX_LEN: [(ZstdTag, u64); 13] = [ (FrameHeaderDescriptor, 1), (FrameContentSize, 8), @@ -558,9 +560,11 @@ fn process_block_raw( ); let fse_aux_table = FseAuxiliaryTableData { - byte_offset: 0, + block_idx: 0, + table_kind: FseTableKind::LLT, table_size: 0, sym_to_states: BTreeMap::default(), + sym_to_sorted_states: BTreeMap::default(), }; let huffman_weights = HuffmanCodesData { byte_offset: 0, @@ -603,9 +607,11 @@ fn process_block_rle( ); let fse_aux_table = FseAuxiliaryTableData { - byte_offset: 0, + block_idx: 0, + table_kind: FseTableKind::LLT, table_size: 0, sym_to_states: BTreeMap::default(), + sym_to_sorted_states: BTreeMap::default(), }; let huffman_weights = HuffmanCodesData { byte_offset: 0, @@ -660,9 +666,11 @@ fn process_block_zstd( witness_rows.extend_from_slice(&rows); let mut fse_aux_table = FseAuxiliaryTableData { - byte_offset: 0, + block_idx: 0, + table_kind: FseTableKind::LLT, table_size: 0, sym_to_states: BTreeMap::default(), + sym_to_sorted_states: BTreeMap::default(), }; let mut huffman_weights = HuffmanCodesData { byte_offset: 0, @@ -1035,8 +1043,10 @@ fn process_block_zstd_huffman_code( }; // Recover the FSE table for generating Huffman weights + // TODO(ray): this part is redundant however to compile, we have added the required args to the + // ``reconstruct`` method. let (n_fse_bytes, bit_boundaries, table) = - FseAuxiliaryTableData::reconstruct(src, byte_offset + 1) + FseAuxiliaryTableData::reconstruct(src, 1, FseTableKind::LLT, byte_offset + 1) .expect("Reconstructing FSE table should not fail."); // Witness generation @@ -1179,7 +1189,7 @@ fn process_block_zstd_huffman_code( baseline: 0, num_bits: 0, num_emitted: 0, - n_acc: row.9 as u64, + is_state_skipped: false, }, }); } @@ -1386,7 +1396,8 @@ fn process_block_zstd_huffman_code( baseline: fse_row.1, num_bits: fse_row.2, num_emitted: num_emitted as u64, - n_acc: 0, + // TODO(ray): pls check where to get this field from. + is_state_skipped: false, }, huffman_data: HuffmanData::default(), decoded_data: decoded_data.clone(), diff --git a/aggregator/src/aggregation/decoder/witgen/types.rs b/aggregator/src/aggregation/decoder/witgen/types.rs index c3bb20a0d1..83dabe6a0b 100644 --- a/aggregator/src/aggregation/decoder/witgen/types.rs +++ b/aggregator/src/aggregation/decoder/witgen/types.rs @@ -10,6 +10,8 @@ use halo2_proofs::{circuit::Value, plonk::Expression}; use itertools::Itertools; use strum_macros::EnumIter; +use crate::aggregation::decoder::tables::FseTableKind; + use super::{ params::N_BITS_PER_BYTE, util::{bit_length, read_variable_bit_packing, smaller_powers_of_two, value_bits_le}, @@ -522,8 +524,10 @@ pub struct FseTableRow { pub symbol: u64, /// During FSE table decoding, keep track of the number of symbol emitted pub num_emitted: u64, - /// During FSE table decoding, keep track of accumulated states assigned - pub n_acc: u64, + /// A boolean marker to indicate that as per the state transition rules of FSE codes, this + /// state was reached for this symbol, however it was already pre-allocated to a prior symbol, + /// this can happen in case we have symbols with prob=-1. + pub is_state_skipped: bool, } // Used for tracking bit markers for non-byte-aligned bitstream decoding @@ -553,15 +557,19 @@ pub struct FseTableData { /// Auxiliary data accompanying the FSE table's witness values. #[derive(Clone, Debug)] pub struct FseAuxiliaryTableData { - /// The byte offset in the frame at which the FSE table is described. - pub byte_offset: u64, + /// The block index in which this FSE table appears. + pub block_idx: u64, + /// The FSE table kind, variants are: LLT=1, MOT=2, MLT=3. + pub table_kind: FseTableKind, /// The FSE table's size, i.e. 1 << AL (accuracy log). pub table_size: u64, /// A map from FseSymbol (weight) to states, also including fields for that state, for /// instance, the baseline and the number of bits to read from the FSE bitstream. /// - /// For each symbol, the states are in strictly increasing order. + /// For each symbol, the states as per the state transition rule. pub sym_to_states: BTreeMap>, + /// Similar map, but where the states for each symbol are in increasing order (sorted). + pub sym_to_sorted_states: BTreeMap>, } /// Another form of Fse table that has state as key instead of the FseSymbol. @@ -580,7 +588,12 @@ impl FseAuxiliaryTableData { /// with the reconstructed FSE table. After processing the entire bitstream to reconstruct the /// FSE table, if the read bitstream was not byte aligned, then we discard the 1..8 bits from /// the last byte that we read from. - pub fn reconstruct(src: &[u8], byte_offset: usize) -> std::io::Result { + pub fn reconstruct( + src: &[u8], + block_idx: u64, + table_kind: FseTableKind, + byte_offset: usize, + ) -> std::io::Result { // construct little-endian bit-reader. let data = src.iter().skip(byte_offset).cloned().collect::>(); let mut reader = BitReader::endian(Cursor::new(&data), LittleEndian); @@ -596,9 +609,11 @@ impl FseAuxiliaryTableData { bit_boundaries.push((offset, accuracy_log as u64 - 5)); let table_size = 1 << accuracy_log; - let mut sym_to_states = BTreeMap::new(); + //////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////// Parse Normalised Probabilities //////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////// + let mut normalised_probs = BTreeMap::new(); let mut R = table_size; - let mut state = 0x00; let mut symbol = 0; while R > 0 { // number of bits and value read from the variable bit-packed data. @@ -608,30 +623,38 @@ impl FseAuxiliaryTableData { offset += n_bits_read; bit_boundaries.push((offset, value)); + // Number of states allocated to this symbol. + // - prob=-1 => 1 + // - prob=0 => 0 + // - prob>=1 => prob let N = match value { 0 => 1, _ => value - 1, }; + // When a symbol has a value==0, it signifies a case of prob=-1 (or probability "less + // than 1"), where such symbols are allocated states from the end and retreating. In + // such cases, we reset the FSE state, i.e. read accuracy_log number of bits from the + // bitstream with a baseline==0x00. if value == 0 { - unimplemented!("value=0 => prob=-1: scenario unimplemented"); + normalised_probs.insert(symbol, -1); + symbol += 1; } - // When a symbol has a probability of zero, it is followed by a 2-bits repeat flag. This + // When a symbol has a value==1 (prob==0), it is followed by a 2-bits repeat flag. This // repeat flag tells how many probabilities of zeroes follow the current one. It // provides a number ranging from 0 to 3. If it is a 3, another 2-bits repeat flag // follows, and so on. if value == 1 { - sym_to_states.insert(symbol, vec![]); + normalised_probs.insert(symbol, 0); symbol += 1; - loop { let repeat_bits = reader.read::(2)?; offset += 2; bit_boundaries.push((offset, repeat_bits as u64)); for k in 0..repeat_bits { - sym_to_states.insert(symbol + (k as u64), vec![]); + normalised_probs.insert(symbol + (k as u64), 0); } symbol += repeat_bits as u64; @@ -641,56 +664,11 @@ impl FseAuxiliaryTableData { } } - if value >= 2 { - let states = std::iter::once(state) - .chain((1..N).map(|_| { - state += (table_size >> 1) + (table_size >> 3) + 3; - state &= table_size - 1; - state - })) - .sorted() - .collect::>(); - let (smallest_spot_idx, nbs) = smaller_powers_of_two(table_size, N); - let baselines = if N == 1 { - vec![0x00] - } else { - let mut rotated_nbs = nbs.clone(); - rotated_nbs.rotate_left(smallest_spot_idx); - - let mut baselines = std::iter::once(0x00) - .chain(rotated_nbs.iter().scan(0x00, |baseline, nb| { - *baseline += 1 << nb; - Some(*baseline) - })) - .take(N as usize) - .collect::>(); - - baselines.rotate_right(smallest_spot_idx); - baselines - }; - sym_to_states.insert( - symbol, - states - .iter() - .zip(nbs.iter()) - .zip(baselines.iter()) - .map(|((&state, &nb), &baseline)| FseTableRow { - state, - num_bits: nb, - baseline, - symbol, - num_emitted: 0, - n_acc: 0, - }) - .collect(), - ); - - // increment symbol. + // When a symbol has a value>1 (prob>=1), it is allocated that many number of states in + // the FSE table. + if value > 1 { + normalised_probs.insert(symbol, N as i32); symbol += 1; - - // update state. - state += (table_size >> 1) + (table_size >> 3) + 3; - state &= table_size - 1; } // remove N slots from a total of R. @@ -709,13 +687,127 @@ impl FseAuxiliaryTableData { )); } + //////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////// Allocate States to Symbols /////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////// + let mut sym_to_states = BTreeMap::new(); + let mut sym_to_sorted_states = BTreeMap::new(); + let mut state = 0; + let mut retreating_state = table_size - 1; + let mut allocated_states = HashMap::::new(); + + // We start with the symbols that have prob=-1. + for (&symbol, _prob) in normalised_probs + .iter() + .filter(|(_symbol, &prob)| prob == -1) + { + allocated_states.insert(symbol, true); + let fse_table_row = FseTableRow { + state: retreating_state, + num_bits: accuracy_log as u64, + baseline: 0, + symbol, + is_state_skipped: false, + num_emitted: 0, + }; + sym_to_states.insert(symbol, vec![fse_table_row.clone()]); + sym_to_sorted_states.insert(symbol, vec![fse_table_row]); + retreating_state -= 1; + } + + // We now move to the symbols with prob>=1. + for (&symbol, &prob) in normalised_probs + .iter() + .filter(|(_symbol, &prob)| prob.is_positive()) + { + let N = prob as usize; + let mut count = 0; + let mut states_with_skipped: Vec<(u64, bool)> = Vec::with_capacity(N); + while count < N { + if allocated_states.get(&state).is_some() { + // if state has been pre-allocated to some symbol with prob=-1. + states_with_skipped.push((state, true)); + } else { + // if state is not yet allocated, i.e. available for this symbol. + states_with_skipped.push((state, false)); + count += 1; + } + + // update state. + state += (table_size >> 1) + (table_size >> 3) + 3; + state &= table_size - 1; + } + let sorted_states = states_with_skipped + .iter() + .filter(|&(_s, is_state_skipped)| !is_state_skipped) + .map(|&(s, _)| s) + .sorted() + .collect::>(); + let (smallest_spot_idx, nbs) = smaller_powers_of_two(table_size, N as u64); + let baselines = if N == 1 { + vec![0x00] + } else { + let mut rotated_nbs = nbs.clone(); + rotated_nbs.rotate_left(smallest_spot_idx); + + let mut baselines = std::iter::once(0x00) + .chain(rotated_nbs.iter().scan(0x00, |baseline, nb| { + *baseline += 1 << nb; + Some(*baseline) + })) + .take(N) + .collect::>(); + + baselines.rotate_right(smallest_spot_idx); + baselines + }; + sym_to_states.insert( + symbol, + states_with_skipped + .iter() + .map(|&(s, is_state_skipped)| { + let (baseline, nb) = match sorted_states.iter().position(|&ss| ss == s) { + None => (0, 0), + Some(sorted_idx) => (baselines[sorted_idx], nbs[sorted_idx]), + }; + FseTableRow { + state: s, + num_bits: nb, + baseline, + symbol, + num_emitted: 0, + is_state_skipped, + } + }) + .collect(), + ); + sym_to_sorted_states.insert( + symbol, + sorted_states + .iter() + .zip(nbs.iter()) + .zip(baselines.iter()) + .map(|((&s, &nb), &baseline)| FseTableRow { + state: s, + num_bits: nb, + baseline, + symbol, + num_emitted: 0, + is_state_skipped: false, + }) + .collect(), + ); + } + Ok(( t, bit_boundaries, Self { - byte_offset: byte_offset as u64, + block_idx, + table_kind, table_size, sym_to_states, + sym_to_sorted_states, }, )) } @@ -785,14 +877,15 @@ mod tests { // sure FSE reconstruction ignores them. let src = vec![0xff, 0xff, 0xff, 0x30, 0x6f, 0x9b, 0x03, 0xff, 0xff, 0xff]; - let (n_bytes, _bit_boundaries, table) = FseAuxiliaryTableData::reconstruct(&src, 3)?; + let (n_bytes, _bit_boundaries, table) = + FseAuxiliaryTableData::reconstruct(&src, 1, FseTableKind::LLT, 3)?; // TODO: assert equality for the entire table. // for now only comparing state/baseline/nb for S1, i.e. weight == 1. assert_eq!(n_bytes, 4); assert_eq!( - table.sym_to_states.get(&1).cloned().unwrap(), + table.sym_to_sorted_states.get(&1).cloned().unwrap(), [ (0x03, 0x10, 3), (0x0c, 0x18, 3), @@ -803,13 +896,13 @@ mod tests { ] .iter() .enumerate() - .map(|(i, &(state, baseline, num_bits))| FseTableRow { + .map(|(_i, &(state, baseline, num_bits))| FseTableRow { state, symbol: 1, baseline, num_bits, num_emitted: 0, - n_acc: 0, + is_state_skipped: false, }) .collect::>(), ); @@ -823,7 +916,8 @@ mod tests { 0x21, 0x9d, 0x51, 0xcc, 0x18, 0x42, 0x44, 0x81, 0x8c, 0x94, 0xb4, 0x50, 0x1e, ]; - let (n_bytes, _bit_boundaries, table) = FseAuxiliaryTableData::reconstruct(&src, 0)?; + let (n_bytes, _bit_boundaries, table) = + FseAuxiliaryTableData::reconstruct(&src, 1, FseTableKind::LLT, 0)?; let parsed_state_map = table.parse_state_table(); // TODO: assertions