diff --git a/packages/compiler/src/bin/compiler.rs b/packages/compiler/src/bin/compiler.rs index 34966f3..ba53749 100644 --- a/packages/compiler/src/bin/compiler.rs +++ b/packages/compiler/src/bin/compiler.rs @@ -1,6 +1,52 @@ -use clap::{Parser, Subcommand}; +//! ZK Regex Compiler CLI +//! +//! This binary provides a command-line interface for the ZK Regex Compiler. +//! It supports two main commands: `Decomposed` for working with decomposed regex files, +//! and `Raw` for working with raw regex strings. +//! +//! # Usage +//! +//! ## Decomposed Command +//! Process a decomposed regex file: +//! +//! ``` +//! zk-regex decomposed --decomposed-regex-path [OPTIONS] +//! ``` +//! +//! Options: +//! - `-d, --decomposed-regex-path `: Path to the decomposed regex JSON file (required) +//! - `-h, --halo2-dir-path `: Directory path for Halo2 output +//! - `-c, --circom-file-path `: File path for Circom output +//! - `-t, --template-name `: Template name +//! - `-g, --gen-substrs`: Generate substrings +//! +//! Example: +//! ``` +//! zk-regex decomposed -d regex.json -h ./halo2_output -c ./circom_output.circom -t MyTemplate -g true +//! ``` +//! +//! ## Raw Command +//! Process a raw regex string: +//! +//! ``` +//! zk-regex raw --raw-regex [OPTIONS] +//! ``` +//! +//! Options: +//! - `-r, --raw-regex `: Raw regex string (required) +//! - `-s, --substrs-json-path `: Path to substrings JSON file +//! - `-h, --halo2-dir-path `: Directory path for Halo2 output +//! - `-c, --circom-file-path `: File path for Circom output +//! - `-t, --template-name `: Template name +//! - `-g, --gen-substrs`: Generate substrings +//! +//! Example: +//! ``` +//! zk-regex raw -r "a*b+c?" -s substrings.json -h ./halo2_output -c ./circom_output.circom -t MyTemplate -g true +//! ``` -use zk_regex_compiler::*; +use clap::{Parser, Subcommand}; +use zk_regex_compiler::{gen_from_decomposed, gen_from_raw}; #[derive(Parser, Debug, Clone)] #[command(author, version, about, long_about = None)] @@ -42,37 +88,53 @@ enum Commands { fn main() { let cli = Cli::parse(); match cli.command { - Commands::Decomposed { - decomposed_regex_path, - halo2_dir_path, - circom_file_path, - template_name, + Commands::Decomposed { .. } => process_decomposed(cli), + Commands::Raw { .. } => process_raw(cli), + } +} + +fn process_decomposed(cli: Cli) { + if let Commands::Decomposed { + decomposed_regex_path, + halo2_dir_path, + circom_file_path, + template_name, + gen_substrs, + } = cli.command + { + if let Err(e) = gen_from_decomposed( + &decomposed_regex_path, + halo2_dir_path.as_deref(), + circom_file_path.as_deref(), + template_name.as_deref(), gen_substrs, - } => { - gen_from_decomposed( - &decomposed_regex_path, - halo2_dir_path.as_ref().map(|s| s.as_str()), - circom_file_path.as_ref().map(|s| s.as_str()), - template_name.as_ref().map(|s| s.as_str()), - gen_substrs, - ); + ) { + eprintln!("Error: {}", e); + std::process::exit(1); } - Commands::Raw { - raw_regex, - substrs_json_path, - halo2_dir_path, - circom_file_path, - template_name, + } +} + +fn process_raw(cli: Cli) { + if let Commands::Raw { + raw_regex, + substrs_json_path, + halo2_dir_path, + circom_file_path, + template_name, + gen_substrs, + } = cli.command + { + if let Err(e) = gen_from_raw( + &raw_regex, + substrs_json_path.as_deref(), + halo2_dir_path.as_deref(), + circom_file_path.as_deref(), + template_name.as_deref(), gen_substrs, - } => { - gen_from_raw( - &raw_regex, - substrs_json_path.as_ref().map(|s| s.as_str()), - halo2_dir_path.as_ref().map(|s| s.as_str()), - circom_file_path.as_ref().map(|s| s.as_str()), - template_name.as_ref().map(|s| s.as_str()), - gen_substrs, - ); + ) { + eprintln!("Error: {}", e); + std::process::exit(1); } } } diff --git a/packages/compiler/src/circom.rs b/packages/compiler/src/circom.rs index 3328f8f..67b288a 100644 --- a/packages/compiler/src/circom.rs +++ b/packages/compiler/src/circom.rs @@ -1,359 +1,392 @@ -use itertools::Itertools; - -use super::CompilerError; -use crate::get_accepted_state; -use crate::DFAGraph; -use crate::RegexAndDFA; -use std::collections::{BTreeMap, BTreeSet}; - -use std::fs::File; -use std::io::Write; -use std::path::PathBuf; - -fn gen_circom_allstr( +use crate::{ + errors::CompilerError, + regex::get_accepted_state, + structs::{DFAGraph, RegexAndDFA}, +}; +use std::{ + collections::{BTreeMap, BTreeSet}, + fs::File, + io::Write, + path::Path, +}; + +/// Builds a reverse graph from a DFA graph and collects accept nodes. +/// +/// This function creates a reverse graph where the direction of edges is inverted, +/// and collects all accepting states. +/// +/// # Arguments +/// +/// * `state_len` - The number of states in the DFA. +/// * `dfa_graph` - A reference to the original DFA graph. +/// +/// # Returns +/// +/// A tuple containing: +/// * The reverse graph as a `BTreeMap>>`. +/// * A `BTreeSet` of accepting state IDs. +/// +/// # Errors +/// +/// Returns a `CompilerError::NoAcceptedState` if no accepting states are found. +fn build_reverse_graph( + state_len: usize, dfa_graph: &DFAGraph, - template_name: &str, - regex_str: &str, - end_anchor: bool, -) -> String { - let n = dfa_graph.states.len(); +) -> (BTreeMap>>, BTreeSet) { let mut rev_graph = BTreeMap::>>::new(); - let mut to_init_graph = vec![]; - // let mut init_going_state: Option = None; + let mut accept_nodes = BTreeSet::::new(); - for i in 0..n { + for i in 0..state_len { rev_graph.insert(i, BTreeMap::new()); - to_init_graph.push(vec![]); } - let mut accept_nodes = BTreeSet::::new(); - - for i in 0..n { - let node = &dfa_graph.states[i]; - for (k, v) in &node.edges { + for (i, node) in dfa_graph.states.iter().enumerate() { + for (k, v) in &node.transitions { let chars: Vec = v.iter().cloned().collect(); - rev_graph.get_mut(k).unwrap().insert(i, chars.clone()); - - if i == 0 { - // if let Some(index) = chars.iter().position(|&x| x == 94) { - // init_going_state = Some(*k); - // rev_graph.get_mut(&k).unwrap().get_mut(&i).unwrap()[index] = 255; - // } - - for j in rev_graph.get(&k).unwrap().get(&i).unwrap() { - if *j == 255 { - continue; - } - to_init_graph[*k].push(*j); - } - } + rev_graph.get_mut(k).unwrap().insert(i, chars); } - if node.r#type == "accept" { + if node.state_type == "accept" { accept_nodes.insert(i); } } - // if let Some(init_going_state) = init_going_state { - // for (going_state, chars) in to_init_graph.iter().enumerate() { - // if chars.is_empty() { - // continue; - // } - - // if rev_graph - // .get_mut(&(going_state as usize)) - // .unwrap() - // .get_mut(&init_going_state) - // .is_none() - // { - // rev_graph - // .get_mut(&(going_state as usize)) - // .unwrap() - // .insert(init_going_state, vec![]); - // } - - // rev_graph - // .get_mut(&(going_state as usize)) - // .unwrap() - // .get_mut(&init_going_state) - // .unwrap() - // .extend_from_slice(chars); - // } - // } - if accept_nodes.is_empty() { panic!("Accept node must exist"); } - let accept_nodes_array: Vec = accept_nodes.into_iter().collect(); + (rev_graph, accept_nodes) +} - if accept_nodes_array.len() != 1 { - panic!("The size of accept nodes must be one"); +/// Optimizes character ranges by grouping consecutive characters and identifying individual characters. +/// +/// This function takes a slice of u8 values (representing ASCII characters) and groups them into +/// ranges where possible, while also identifying individual characters that don't fit into ranges. +/// +/// # Arguments +/// +/// * `k` - A slice of u8 values representing ASCII characters. +/// +/// # Returns +/// +/// A tuple containing: +/// * A Vec of (u8, u8) tuples representing optimized character ranges (min, max). +/// * A BTreeSet of u8 values representing individual characters not included in ranges. +/// +/// # Note +/// +/// Ranges are only created for sequences of 16 or more consecutive characters. +fn optimize_char_ranges(k: &[u8]) -> (Vec<(u8, u8)>, BTreeSet) { + let mut min_maxes = vec![]; + let mut vals = k.iter().cloned().collect::>(); + + if k.is_empty() { + return (min_maxes, vals); } - let mut eq_i = 0; - let mut lt_i = 0; - let mut and_i = 0; - let mut multi_or_i = 0; + let mut cur_min = k[0]; + let mut cur_max = k[0]; - let mut range_checks = vec![vec![None; 256]; 256]; - let mut eq_checks = vec![None; 256]; - let mut multi_or_checks1 = BTreeMap::::new(); - let mut multi_or_checks2 = BTreeMap::::new(); - let mut zero_starting_states = vec![]; - let mut zero_starting_and_idxes = BTreeMap::>::new(); - - let mut lines = vec![]; - // let mut zero_starting_lines = vec![]; - - lines.push("\tfor (var i = 0; i < num_bytes; i++) {".to_string()); - lines.push(format!("\t\tstate_changed[i] = MultiOR({});", n - 1)); - lines.push(format!("\t\tstates[i][0] <== 1;")); - assert!( - rev_graph.get(&0).unwrap().len() == 0, - "state transition to the 0-th state is not allowed" - ); - if end_anchor { - lines.push(format!( - "\t\tpadding_start[i+1] <== IsNotZeroAcc()(padding_start[i], in[i]);" - )); + for &val in &k[1..] { + if cur_max == val { + continue; + } else if cur_max + 1 == val { + cur_max = val; + } else { + if cur_max - cur_min >= 16 { + min_maxes.push((cur_min, cur_max)); + } + cur_min = val; + cur_max = val; + } } - for i in 1..n { - let mut outputs = vec![]; - zero_starting_and_idxes.insert(i, vec![]); - // let mut state_change_lines = vec![]; - for (prev_i, k) in rev_graph.get(&(i as usize)).unwrap().iter() { - let prev_i_num = *prev_i; - if prev_i_num == 0 { - zero_starting_states.push(i); - } - let mut k = k.clone(); - k.retain(|&x| x != 0); - k.sort(); + if cur_max - cur_min >= 16 { + min_maxes.push((cur_min, cur_max)); + } - let mut eq_outputs = vec![]; - let mut vals = k.clone().into_iter().collect::>(); + for (min, max) in &min_maxes { + for code in *min..=*max { + vals.remove(&code); + } + } - if vals.is_empty() { - continue; - } + (min_maxes, vals) +} - let mut min_maxes = vec![]; - let mut cur_min = k[0]; - let mut cur_max = k[0]; - - for idx in 1..k.len() { - if cur_max == k[idx] { - continue; - } else if cur_max + 1 == k[idx] { - cur_max += 1; - } else { - if cur_max - cur_min >= 16 { - min_maxes.push((cur_min, cur_max)); - } - cur_min = k[idx]; - cur_max = k[idx]; - } - } +/// Adds a range check for character comparisons in the Circom circuit. +/// +/// This function either reuses an existing range check or creates a new one, +/// adding the necessary Circom code lines and updating the relevant counters. +/// +/// # Arguments +/// +/// * `lines` - A mutable reference to a Vec of Strings containing Circom code lines. +/// * `range_checks` - A mutable reference to a 2D Vec storing existing range checks. +/// * `eq_outputs` - A mutable reference to a Vec storing equality check outputs. +/// * `min` - The minimum value of the range. +/// * `max` - The maximum value of the range. +/// * `lt_i` - A mutable reference to the current LessThan component index. +/// * `and_i` - A mutable reference to the current AND component index. +fn add_range_check( + lines: &mut Vec, + range_checks: &mut Vec>>, + eq_outputs: &mut Vec<(&str, usize)>, + min: u8, + max: u8, + lt_i: &mut usize, + and_i: &mut usize, +) { + if let Some((_, and_i)) = range_checks[min as usize][max as usize] { + eq_outputs.push(("and", and_i)); + } else { + lines.push(format!("\t\tlt[{}][i] = LessEqThan(8);", *lt_i)); + lines.push(format!("\t\tlt[{}][i].in[0] <== {};", *lt_i, min)); + lines.push(format!("\t\tlt[{}][i].in[1] <== in[i];", *lt_i)); + lines.push(format!("\t\tlt[{}][i] = LessEqThan(8);", *lt_i + 1)); + lines.push(format!("\t\tlt[{}][i].in[0] <== in[i];", *lt_i + 1)); + lines.push(format!("\t\tlt[{}][i].in[1] <== {};", *lt_i + 1, max)); + lines.push(format!("\t\tand[{}][i] = AND();", *and_i)); + lines.push(format!( + "\t\tand[{}][i].a <== lt[{}][i].out;", + *and_i, *lt_i + )); + lines.push(format!( + "\t\tand[{}][i].b <== lt[{}][i].out;", + *and_i, + *lt_i + 1 + )); - if cur_max - cur_min >= 16 { - min_maxes.push((cur_min, cur_max)); - } + eq_outputs.push(("and", *and_i)); + range_checks[min as usize][max as usize] = Some((*lt_i, *and_i)); + *lt_i += 2; + *and_i += 1; + } +} - for min_max in &min_maxes { - for code in min_max.0..=min_max.1 { - vals.remove(&code); - } - } +/// Adds an equality check for a specific character code in the Circom circuit. +/// +/// This function either reuses an existing equality check or creates a new one, +/// adding the necessary Circom code lines and updating the relevant counter. +/// +/// # Arguments +/// +/// * `lines` - A mutable reference to a Vec of Strings containing Circom code lines. +/// * `eq_checks` - A mutable reference to a Vec storing existing equality checks. +/// * `code` - The ASCII code of the character to check for equality. +/// * `eq_i` - A mutable reference to the current equality component index. +/// +/// # Returns +/// +/// The index of the equality check component used or created. +fn add_eq_check( + lines: &mut Vec, + eq_checks: &mut Vec>, + code: u8, + eq_i: &mut usize, +) -> usize { + if let Some(index) = eq_checks[code as usize] { + index + } else { + lines.push(format!("\t\teq[{}][i] = IsEqual();", *eq_i)); + lines.push(format!("\t\teq[{}][i].in[0] <== in[i];", *eq_i)); + lines.push(format!("\t\teq[{}][i].in[1] <== {};", *eq_i, code)); + eq_checks[code as usize] = Some(*eq_i); + let result = *eq_i; + *eq_i += 1; + result + } +} - for min_max in &min_maxes { - let min = min_max.0; - let max = min_max.1; - - if range_checks[min as usize][max as usize].is_none() { - lines.push(format!("\t\tlt[{}][i] = LessEqThan(8);", lt_i)); - lines.push(format!("\t\tlt[{}][i].in[0] <== {};", lt_i, min)); - lines.push(format!("\t\tlt[{}][i].in[1] <== in[i];", lt_i)); - lines.push(format!("\t\tlt[{}][i] = LessEqThan(8);", lt_i + 1)); - lines.push(format!("\t\tlt[{}][i].in[0] <== in[i];", lt_i + 1)); - lines.push(format!("\t\tlt[{}][i].in[1] <== {};", lt_i + 1, max)); - lines.push(format!("\t\tand[{}][i] = AND();", and_i)); - lines.push(format!("\t\tand[{}][i].a <== lt[{}][i].out;", and_i, lt_i)); - lines.push(format!( - "\t\tand[{}][i].b <== lt[{}][i].out;", - and_i, - lt_i + 1 - )); - eq_outputs.push(("and", and_i)); - range_checks[min as usize][max as usize] = Some((lt_i, and_i)); - lt_i += 2; - and_i += 1; - } else { - if let Some((_, and_i)) = range_checks[min as usize][max as usize] { - eq_outputs.push(("and", and_i)); - } - } - } +/// Adds a state transition to the Circom circuit. +/// +/// This function creates an AND gate for the state transition and handles the +/// equality outputs, potentially creating a MultiOR gate if necessary. +/// +/// # Arguments +/// +/// * `lines` - A mutable reference to a Vec of Strings containing Circom code lines. +/// * `zero_starting_and_idxes` - A mutable reference to a BTreeMap storing AND indices for zero-starting states. +/// * `i` - The current state index. +/// * `prev_i` - The previous state index. +/// * `eq_outputs` - A Vec of tuples containing equality output types and indices. +/// * `and_i` - A mutable reference to the current AND gate index. +/// * `multi_or_checks1` - A mutable reference to a BTreeMap storing MultiOR checks. +/// * `multi_or_i` - A mutable reference to the current MultiOR gate index. +fn add_state_transition( + lines: &mut Vec, + zero_starting_and_idxes: &mut BTreeMap>, + i: usize, + prev_i: usize, + eq_outputs: Vec<(&str, usize)>, + and_i: &mut usize, + multi_or_checks1: &mut BTreeMap, + multi_or_i: &mut usize, +) { + lines.push(format!("\t\tand[{}][i] = AND();", and_i)); + lines.push(format!( + "\t\tand[{}][i].a <== states[i][{}];", + and_i, prev_i + )); - for code in &vals { - if eq_checks[*code as usize].is_none() { - lines.push(format!("\t\teq[{}][i] = IsEqual();", eq_i)); - lines.push(format!("\t\teq[{}][i].in[0] <== in[i];", eq_i)); - lines.push(format!("\t\teq[{}][i].in[1] <== {};", eq_i, code)); - eq_outputs.push(("eq", eq_i)); - eq_checks[*code as usize] = Some(eq_i); - eq_i += 1; - } else { - if let Some(eq_i) = eq_checks[*code as usize] { - eq_outputs.push(("eq", eq_i)); - } - } - } - lines.push(format!("\t\tand[{}][i] = AND();", and_i)); + if eq_outputs.len() == 1 { + lines.push(format!( + "\t\tand[{}][i].b <== {}[{}][i].out;", + and_i, eq_outputs[0].0, eq_outputs[0].1 + )); + if prev_i == 0 { + zero_starting_and_idxes.get_mut(&i).unwrap().push(*and_i); + } + } else if eq_outputs.len() > 1 { + let eq_outputs_key = serde_json::to_string(&eq_outputs).unwrap(); + if let Some(&multi_or_index) = multi_or_checks1.get(&eq_outputs_key) { lines.push(format!( - "\t\tand[{}][i].a <== states[i][{}];", - and_i, prev_i_num + "\t\tand[{}][i].b <== multi_or[{}][i].out;", + and_i, multi_or_index )); - - if eq_outputs.len() == 1 { + } else { + lines.push(format!( + "\t\tmulti_or[{}][i] = MultiOR({});", + *multi_or_i, + eq_outputs.len() + )); + for (output_i, (eq_type, eq_i)) in eq_outputs.iter().enumerate() { lines.push(format!( - "\t\tand[{}][i].b <== {}[{}][i].out;", - and_i, eq_outputs[0].0, eq_outputs[0].1 + "\t\tmulti_or[{}][i].in[{}] <== {}[{}][i].out;", + *multi_or_i, output_i, eq_type, eq_i )); - if prev_i_num == 0 { - zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i); - } - } else if eq_outputs.len() > 1 { - let eq_outputs_key = serde_json::to_string(&eq_outputs).unwrap(); - - if multi_or_checks1.get(&eq_outputs_key).is_none() { - lines.push(format!( - "\t\tmulti_or[{}][i] = MultiOR({});", - multi_or_i, - eq_outputs.len() - )); - - for (output_i, (eq_type, eq_i)) in eq_outputs.iter().enumerate() { - lines.push(format!( - "\t\tmulti_or[{}][i].in[{}] <== {}[{}][i].out;", - multi_or_i, output_i, eq_type, eq_i - )); - } - - lines.push(format!( - "\t\tand[{}][i].b <== multi_or[{}][i].out;", - and_i, multi_or_i - )); - if prev_i_num == 0 { - zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i); - } - multi_or_checks1.insert(eq_outputs_key, multi_or_i); - multi_or_i += 1; - } else { - if let Some(multi_or_i) = multi_or_checks1.get(&eq_outputs_key) { - lines.push(format!( - "\t\tand[{}][i].b <== multi_or[{}][i].out;", - and_i, multi_or_i - )); - if prev_i_num == 0 { - zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i); - } - } - } - } - if prev_i_num != 0 { - outputs.push(and_i); } - and_i += 1; + lines.push(format!( + "\t\tand[{}][i].b <== multi_or[{}][i].out;", + *and_i, *multi_or_i + )); + multi_or_checks1.insert(eq_outputs_key, *multi_or_i); + *multi_or_i += 1; } - if outputs.len() == 1 { - if zero_starting_states.contains(&i) { - lines.push(format!( - "\t\tstates_tmp[i+1][{}] <== and[{}][i].out;", - i, outputs[0] - )); - } else { - lines.push(format!( - "\t\tstates[i+1][{}] <== and[{}][i].out;", - i, outputs[0] - )); - } - } else if outputs.len() > 1 { - let outputs_key = serde_json::to_string(&outputs).unwrap(); + if prev_i == 0 { + zero_starting_and_idxes.get_mut(&i).unwrap().push(*and_i); + } + } - if multi_or_checks2.get(&outputs_key).is_none() { + *and_i += 1; +} + +/// Helper function to add a MultiOR gate to the Circom circuit. +fn add_multi_or_gate( + lines: &mut Vec, + outputs: &[usize], + multi_or_i: &mut usize, + i: usize, + state_var: &str, +) { + lines.push(format!( + "\t\tmulti_or[{multi_or_i}][i] = MultiOR({});", + outputs.len() + )); + for (output_i, and_i) in outputs.iter().enumerate() { + lines.push(format!( + "\t\tmulti_or[{multi_or_i}][i].in[{output_i}] <== and[{and_i}][i].out;" + )); + } + lines.push(format!( + "\t\t{state_var}[i+1][{i}] <== multi_or[{multi_or_i}][i].out;" + )); +} + +/// Adds a state update to the Circom circuit. +/// +/// This function handles the update of state variables, potentially creating +/// a MultiOR gate if there are multiple outputs to combine. +/// +/// # Arguments +/// +/// * `lines` - A mutable reference to a Vec of Strings containing Circom code lines. +/// * `i` - The current state index. +/// * `outputs` - A Vec of output indices to be combined. +/// * `zero_starting_states` - A mutable reference to a Vec of zero-starting state indices. +/// * `multi_or_checks2` - A mutable reference to a BTreeMap storing MultiOR checks. +/// * `multi_or_i` - A mutable reference to the current MultiOR gate index. +fn add_state_update( + lines: &mut Vec, + i: usize, + outputs: Vec, + zero_starting_states: &[usize], + multi_or_checks2: &mut BTreeMap, + multi_or_i: &mut usize, +) { + let is_zero_starting = zero_starting_states.contains(&i); + let state_var = if is_zero_starting { + "states_tmp" + } else { + "states" + }; + + match outputs.len() { + 0 => lines.push(format!("\t\t{state_var}[i+1][{i}] <== 0;")), + 1 => lines.push(format!( + "\t\t{state_var}[i+1][{i}] <== and[{}][i].out;", + outputs[0] + )), + _ => { + let outputs_key = serde_json::to_string(&outputs).expect("Failed to serialize outputs"); + if let Some(&multi_or_index) = multi_or_checks2.get(&outputs_key) { lines.push(format!( - "\t\tmulti_or[{}][i] = MultiOR({});", - multi_or_i, - outputs.len() + "\t\t{state_var}[i+1][{i}] <== multi_or[{multi_or_index}][i].out;" )); - - for (output_i, and_i) in outputs.iter().enumerate() { - lines.push(format!( - "\t\tmulti_or[{}][i].in[{}] <== and[{}][i].out;", - multi_or_i, output_i, and_i - )); - } - if zero_starting_states.contains(&i) { - lines.push(format!( - "\t\tstates_tmp[i+1][{}] <== multi_or[{}][i].out;", - i, multi_or_i - )); - } else { - lines.push(format!( - "\t\tstates[i+1][{}] <== multi_or[{}][i].out;", - i, multi_or_i - )); - } - multi_or_checks2.insert(outputs_key, multi_or_i); - multi_or_i += 1; } else { - if let Some(multi_or_i) = multi_or_checks2.get(&outputs_key) { - if zero_starting_states.contains(&i) { - lines.push(format!( - "\t\tstates_tmp[i+1][{}] <== multi_or[{}][i].out;", - i, multi_or_i - )); - } else { - lines.push(format!( - "\t\tstates[i+1][{}] <== multi_or[{}][i].out;", - i, multi_or_i - )); - } - } - } - } else { - if zero_starting_states.contains(&i) { - lines.push(format!("\t\tstates_tmp[i+1][{}] <== 0;", i)); - } else { - lines.push(format!("\t\tstates[i+1][{}] <== 0;", i)); + add_multi_or_gate(lines, &outputs, multi_or_i, i, state_var); + multi_or_checks2.insert(outputs_key, *multi_or_i); + *multi_or_i += 1; } } - - // if zero_starting_states.contains(&i) { - // zero_starting_lines.append(&mut state_change_lines); - // } else { - // lines.append(&mut state_change_lines); - // } } - // let not_zero_starting_states = (1..n) - // .filter(|i| !zero_starting_states.contains(&i)) - // .collect_vec(); +} + +/// Adds the 'from_zero_enabled' logic to the Circom circuit. +/// +/// This function creates a MultiNOR gate that checks if all non-zero states are inactive, +/// which indicates that the current state is the initial (zero) state. +/// +/// # Arguments +/// +/// * `lines` - A mutable reference to a Vec of Strings containing Circom code lines. +/// * `state_len` - The total number of states in the DFA. +/// * `zero_starting_states` - A reference to a Vec of indices of zero-starting states. +fn add_from_zero_enabled( + lines: &mut Vec, + state_len: usize, + zero_starting_states: &Vec, +) { lines.push(format!( "\t\tfrom_zero_enabled[i] <== MultiNOR({})([{}]);", - n - 1, - (1..n) - .map(|i| if zero_starting_states.contains(&i) { + state_len - 1, + (1..state_len) + .map(|i| (if zero_starting_states.contains(&i) { format!("states_tmp[i+1][{}]", i) } else { format!("states[i+1][{}]", i) - }) + })) .collect::>() .join(", ") )); - for (i, vec) in zero_starting_and_idxes.iter() { - if vec.len() == 0 { +} + +/// Adds updates for zero-starting states to the Circom circuit. +/// +/// This function creates MultiOR gates for each zero-starting state, +/// combining the temporary state with the AND outputs of transitions +/// from the zero state, gated by the 'from_zero_enabled' signal. +/// +/// # Arguments +/// +/// * `lines` - A mutable reference to a Vec of Strings containing Circom code lines. +/// * `zero_starting_and_idxes` - A reference to a BTreeMap mapping state indices to their corresponding AND gate indices. +fn add_zero_starting_state_updates( + lines: &mut Vec, + zero_starting_and_idxes: &BTreeMap>, +) { + for (i, vec) in zero_starting_and_idxes { + if vec.is_empty() { continue; } lines.push(format!( @@ -367,33 +400,194 @@ fn gen_circom_allstr( .join(", ") )); } - for i in 1..n { +} + +/// Adds state change detection logic to the Circom circuit. +/// +/// This function creates inputs for the state_changed component, +/// which detects changes in non-zero states between consecutive steps. +/// +/// # Arguments +/// +/// * `lines` - A mutable reference to a Vec of Strings containing Circom code lines. +/// * `state_len` - The total number of states in the DFA. +fn add_state_changed_updates(lines: &mut Vec, state_len: usize) { + for i in 1..state_len { lines.push(format!( "\t\tstate_changed[i].in[{}] <== states[i+1][{}];", i - 1, i )); } +} - // lines.push("\t\tstates[i+1][0] <== 1 - state_changed[i].out;".to_string()); +/// Generates the state transition logic for the Circom circuit. +/// +/// This function creates the core logic for state transitions in the DFA, +/// including range checks, equality checks, and multi-OR operations. +/// +/// # Arguments +/// +/// * `rev_graph` - A reference to the reverse graph of the DFA. +/// * `state_len` - The total number of states in the DFA. +/// * `end_anchor` - A boolean indicating whether an end anchor is present. +/// +/// # Returns +/// +/// A tuple containing: +/// * The number of equality checks used. +/// * The number of less-than checks used. +/// * The number of AND gates used. +/// * The number of multi-OR gates used. +/// * A Vec of Strings containing the generated Circom code lines. +fn generate_state_transition_logic( + rev_graph: &BTreeMap>>, + state_len: usize, + end_anchor: bool, +) -> (usize, usize, usize, usize, Vec) { + let mut eq_i = 0; + let mut lt_i = 0; + let mut and_i = 0; + let mut multi_or_i = 0; - let mut declarations = vec![]; - declarations.push("pragma circom 2.1.5;\n".to_string()); - declarations - .push("include \"@zk-email/zk-regex-circom/circuits/regex_helpers.circom\";\n".to_string()); - declarations.push(format!( - "// regex: {}", - regex_str.replace("\n", "\\n").replace("\r", "\\r"), + let mut range_checks = vec![vec![None; 256]; 256]; + let mut eq_checks = vec![None; 256]; + let mut multi_or_checks1 = BTreeMap::::new(); + let mut multi_or_checks2 = BTreeMap::::new(); + let mut zero_starting_states = vec![]; + let mut zero_starting_and_idxes = BTreeMap::>::new(); + + let mut lines = vec![]; + + lines.push("\tfor (var i = 0; i < num_bytes; i++) {".to_string()); + lines.push(format!( + "\t\tstate_changed[i] = MultiOR({});", + state_len - 1 )); - declarations.push(format!("template {}(msg_bytes) {{", template_name)); - declarations.push("\tsignal input msg[msg_bytes];".to_string()); - declarations.push("\tsignal output out;\n".to_string()); - declarations.push("\tvar num_bytes = msg_bytes+1;".to_string()); - declarations.push("\tsignal in[num_bytes];".to_string()); - declarations.push("\tin[0]<==255;".to_string()); - declarations.push("\tfor (var i = 0; i < msg_bytes; i++) {".to_string()); - declarations.push("\t\tin[i+1] <== msg[i];".to_string()); - declarations.push("\t}\n".to_string()); + lines.push("\t\tstates[i][0] <== 1;".to_string()); + + if end_anchor { + lines.push( + "\t\tpadding_start[i+1] <== IsNotZeroAcc()(padding_start[i], in[i]);".to_string(), + ); + } + + for i in 1..state_len { + let mut outputs = vec![]; + zero_starting_and_idxes.insert(i, vec![]); + + for (prev_i, chars) in rev_graph.get(&i).unwrap_or(&BTreeMap::new()) { + if *prev_i == 0 { + zero_starting_states.push(i); + } + let mut k = chars.clone(); + k.retain(|&x| x != 0); + k.sort(); + + let mut eq_outputs = vec![]; + + let (min_maxes, individual_chars) = optimize_char_ranges(&k); + + for (min, max) in min_maxes { + add_range_check( + &mut lines, + &mut range_checks, + &mut eq_outputs, + min, + max, + &mut lt_i, + &mut and_i, + ); + } + + for &code in &individual_chars { + let eq_index = add_eq_check(&mut lines, &mut eq_checks, code, &mut eq_i); + eq_outputs.push(("eq", eq_index)); + } + + add_state_transition( + &mut lines, + &mut zero_starting_and_idxes, + i, + *prev_i, + eq_outputs, + &mut and_i, + &mut multi_or_checks1, + &mut multi_or_i, + ); + + if *prev_i != 0 { + outputs.push(and_i - 1); + } + } + + add_state_update( + &mut lines, + i, + outputs, + &mut zero_starting_states, + &mut multi_or_checks2, + &mut multi_or_i, + ); + } + + add_from_zero_enabled(&mut lines, state_len, &zero_starting_states); + add_zero_starting_state_updates(&mut lines, &zero_starting_and_idxes); + add_state_changed_updates(&mut lines, state_len); + + lines.push("\t}".to_string()); + + (eq_i, lt_i, and_i, multi_or_i, lines) +} + +/// Generates the declarations for the Circom circuit. +/// +/// This function creates the initial declarations and setup for the Circom template, +/// including pragma, includes, input/output signals, and component declarations. +/// +/// # Arguments +/// +/// * `template_name` - The name of the Circom template. +/// * `regex_str` - The regular expression string. +/// * `state_len` - The total number of states in the DFA. +/// * `eq_i` - The number of equality components. +/// * `lt_i` - The number of less-than components. +/// * `and_i` - The number of AND components. +/// * `multi_or_i` - The number of multi-OR components. +/// * `end_anchor` - A boolean indicating whether an end anchor is present. +/// +/// # Returns +/// +/// A Vec of Strings containing the generated Circom declarations. +fn generate_declarations( + template_name: &str, + regex_str: &str, + state_len: usize, + eq_i: usize, + lt_i: usize, + and_i: usize, + multi_or_i: usize, + end_anchor: bool, +) -> Vec { + let mut declarations = vec![ + "pragma circom 2.1.5;\n".to_string(), + "include \"@zk-email/zk-regex-circom/circuits/regex_helpers.circom\";\n".to_string(), + format!( + "// regex: {}", + regex_str.replace('\n', "\\n").replace('\r', "\\r") + ), + format!("template {}(msg_bytes) {{", template_name), + "\tsignal input msg[msg_bytes];".to_string(), + "\tsignal output out;".to_string(), + "".to_string(), + "\tvar num_bytes = msg_bytes+1;".to_string(), + "\tsignal in[num_bytes];".to_string(), + "\tin[0]<==255;".to_string(), + "\tfor (var i = 0; i < msg_bytes; i++) {".to_string(), + "\t\tin[i+1] <== msg[i];".to_string(), + "\t}".to_string(), + "".to_string(), + ]; if eq_i > 0 { declarations.push(format!("\tcomponent eq[{}][num_bytes];", eq_i)); @@ -411,32 +605,75 @@ fn gen_circom_allstr( declarations.push(format!("\tcomponent multi_or[{}][num_bytes];", multi_or_i)); } - declarations.push(format!("\tsignal states[num_bytes+1][{}];", n)); - declarations.push(format!("\tsignal states_tmp[num_bytes+1][{}];", n)); - declarations.push(format!("\tsignal from_zero_enabled[num_bytes+1];")); - declarations.push(format!("\tfrom_zero_enabled[num_bytes] <== 0;")); - declarations.push("\tcomponent state_changed[num_bytes];\n".to_string()); + declarations.extend([ + format!("\tsignal states[num_bytes+1][{state_len}];"), + format!("\tsignal states_tmp[num_bytes+1][{state_len}];"), + "\tsignal from_zero_enabled[num_bytes+1];".to_string(), + "\tfrom_zero_enabled[num_bytes] <== 0;".to_string(), + "\tcomponent state_changed[num_bytes];".to_string(), + "".to_string(), + ]); + if end_anchor { - declarations.push("\tsignal padding_start[num_bytes+1];".to_string()); - declarations.push("\tpadding_start[0] <== 0;".to_string()); + declarations.extend([ + "\tsignal padding_start[num_bytes+1];".to_string(), + "\tpadding_start[0] <== 0;".to_string(), + ]); } - let mut init_code = vec![]; - // init_code.push("\tstates[0][0] <== 1;".to_string()); - init_code.push(format!("\tfor (var i = 1; i < {}; i++) {{", n)); - init_code.push("\t\tstates[0][i] <== 0;".to_string()); - init_code.push("\t}\n".to_string()); + declarations +} - let mut final_code = declarations - .into_iter() - .chain(init_code) - .chain(lines) - .collect::>(); - final_code.push("\t}".to_string()); +/// Generates the initialization code for the Circom circuit. +/// +/// This function creates the code to initialize all states except the first one to 0. +/// +/// # Arguments +/// +/// * `state_len` - The total number of states in the DFA. +/// +/// # Returns +/// +/// A Vec of Strings containing the generated initialization code. +fn generate_init_code(state_len: usize) -> Vec { + vec![ + format!("\tfor (var i = 1; i < {state_len}; i++) {{"), + "\t\tstates[0][i] <== 0;".to_string(), + "\t}".to_string(), + "".to_string(), + ] +} - let accept_node = accept_nodes_array[0]; +/// Generates the acceptance logic for the Circom circuit. +/// +/// This function creates the code to check if the DFA has reached an accepting state, +/// and handles the end anchor logic if present. +/// +/// # Arguments +/// +/// * `accept_nodes` - A BTreeSet of accepting state indices. +/// * `end_anchor` - A boolean indicating whether an end anchor is present. +/// +/// # Returns +/// +/// A Vec of Strings containing the generated acceptance logic code. +/// +/// # Panics +/// +/// Panics if there are no accept nodes or if there is more than one accept node. +fn generate_accept_logic(accept_nodes: BTreeSet, end_anchor: bool) -> Vec { let mut accept_lines = vec![]; + if accept_nodes.is_empty() { + panic!("Accept node must exist"); + } + + if accept_nodes.len() != 1 { + panic!("The size of accept nodes must be one"); + } + + let accept_node = *accept_nodes.iter().next().unwrap(); + accept_lines.push("".to_string()); accept_lines.push("\tcomponent is_accepted = MultiOR(num_bytes+1);".to_string()); accept_lines.push("\tfor (var i = 0; i <= num_bytes; i++) {".to_string()); @@ -445,17 +682,18 @@ fn gen_circom_allstr( accept_node )); accept_lines.push("\t}".to_string()); + if end_anchor { accept_lines.push("\tsignal end_anchor_check[num_bytes+1][2];".to_string()); accept_lines.push("\tend_anchor_check[0][1] <== 0;".to_string()); accept_lines.push("\tfor (var i = 0; i < num_bytes; i++) {".to_string()); - accept_lines.push(format!( - "\t\tend_anchor_check[i+1][0] <== IsEqual()([i, padding_start[num_bytes]]);", - )); - accept_lines.push(format!( - "\t\tend_anchor_check[i+1][1] <== end_anchor_check[i][1] + states[i][{}] * end_anchor_check[i+1][0];", - accept_node - )); + accept_lines.push( + "\t\tend_anchor_check[i+1][0] <== IsEqual()([i, padding_start[num_bytes]]);" + .to_string(), + ); + accept_lines.push( + format!("\t\tend_anchor_check[i+1][1] <== end_anchor_check[i][1] + states[i][{}] * end_anchor_check[i+1][0];", accept_node) + ); accept_lines.push("\t}".to_string()); accept_lines .push("\tout <== is_accepted.out * end_anchor_check[num_bytes][1];".to_string()); @@ -463,133 +701,304 @@ fn gen_circom_allstr( accept_lines.push("\tout <== is_accepted.out;".to_string()); } - final_code.extend(accept_lines); + accept_lines +} + +/// Generates the complete Circom circuit as a string. +/// +/// This function orchestrates the generation of all parts of the Circom circuit, +/// including declarations, initialization code, state transition logic, and acceptance logic. +/// +/// # Arguments +/// +/// * `dfa_graph` - A reference to the DFA graph. +/// * `template_name` - The name of the Circom template. +/// * `regex_str` - The regular expression string. +/// * `end_anchor` - A boolean indicating whether an end anchor is present. +/// +/// # Returns +/// +/// A String containing the complete Circom circuit code. +fn gen_circom_allstr( + dfa_graph: &DFAGraph, + template_name: &str, + regex_str: &str, + end_anchor: bool, +) -> String { + let state_len = dfa_graph.states.len(); + + let (rev_graph, accept_nodes) = build_reverse_graph(state_len, dfa_graph); + + let (eq_i, lt_i, and_i, multi_or_i, lines) = + generate_state_transition_logic(&rev_graph, state_len, end_anchor); + + let declarations = generate_declarations( + template_name, + regex_str, + state_len, + eq_i, + lt_i, + and_i, + multi_or_i, + end_anchor, + ); + + let init_code = generate_init_code(state_len); + + let accept_lines = generate_accept_logic(accept_nodes, end_anchor); + + let final_code = [declarations, init_code, lines, accept_lines].concat(); final_code.join("\n") } -impl RegexAndDFA { - pub fn gen_circom( - &self, - circom_path: &PathBuf, - template_name: &str, - gen_substrs: bool, - ) -> Result<(), CompilerError> { - let circom = gen_circom_allstr( - &self.dfa_val, - template_name, - &self.regex_str, - self.end_anchor, - ); - let mut circom_file = File::create(circom_path)?; - write!(circom_file, "{}", circom)?; - if gen_substrs { - let substrs = self.add_substrs_constraints()?; - write!(circom_file, "{}", substrs)?; +/// Writes the consecutive logic for the Circom circuit. +/// +/// This function generates the logic to check for consecutive accepted states. +/// +/// # Arguments +/// +/// * `accepted_state` - The index of the accepted state. +/// +/// # Returns +/// +/// A String containing the generated Circom code for consecutive logic. +fn write_consecutive_logic(accepted_state: usize) -> String { + let mut logic = String::new(); + logic += "\n"; + logic += "\tsignal is_consecutive[msg_bytes+1][3];\n"; + logic += "\tis_consecutive[msg_bytes][2] <== 0;\n"; + logic += "\tfor (var i = 0; i < msg_bytes; i++) {\n"; + logic += &format!( + "\t\tis_consecutive[msg_bytes-1-i][0] <== states[num_bytes-i][{accepted_state}] * (1 - is_consecutive[msg_bytes-i][2]) + is_consecutive[msg_bytes-i][2];\n" + ); + logic += + "\t\tis_consecutive[msg_bytes-1-i][1] <== state_changed[msg_bytes-i].out * is_consecutive[msg_bytes-1-i][0];\n"; + logic += &format!( + "\t\tis_consecutive[msg_bytes-1-i][2] <== ORAnd()([(1 - from_zero_enabled[msg_bytes-i+1]), states[num_bytes-i][{accepted_state}], is_consecutive[msg_bytes-1-i][1]]);\n" + ); + logic += "\t}\n"; + logic +} + +/// Writes the previous states logic for the Circom circuit. +/// +/// This function generates the logic to compute previous states based on transitions. +/// +/// # Arguments +/// +/// * `idx` - The index of the current substring. +/// * `ranges` - A slice of references to tuples representing state transitions. +/// +/// # Returns +/// +/// A String containing the generated Circom code for previous states. +fn write_prev_states(idx: usize, ranges: &[&(usize, usize)]) -> String { + let mut prev_states = String::new(); + for (trans_idx, &(cur, _)) in ranges.iter().enumerate() { + if *cur == 0 { + prev_states += &format!( + "\t\tprev_states{idx}[{trans_idx}][i] <== from_zero_enabled[i+1] * states[i+1][{cur}];\n" + ); + } else { + prev_states += &format!( + "\t\tprev_states{idx}[{trans_idx}][i] <== (1 - from_zero_enabled[i+1]) * states[i+1][{cur}];\n" + ); } - circom_file.flush()?; - Ok(()) } + prev_states +} - pub fn gen_circom_str(&self, template_name: &str) -> Result { - let circom = gen_circom_allstr( - &self.dfa_val, - template_name, - &self.regex_str, - self.end_anchor, - ); - let substrs = self.add_substrs_constraints()?; - let result = circom + &substrs; - Ok(result) +/// Writes the substring logic for the Circom circuit. +/// +/// This function generates the logic to compute if a substring is present. +/// +/// # Arguments +/// +/// * `idx` - The index of the current substring. +/// * `ranges` - A slice of references to tuples representing state transitions. +/// +/// # Returns +/// +/// A String containing the generated Circom code for substring logic. +fn write_is_substr(idx: usize, ranges: &[&(usize, usize)]) -> String { + let multi_or_inputs = ranges + .iter() + .enumerate() + .map(|(trans_idx, (_, next))| { + format!("prev_states{idx}[{trans_idx}][i] * states[i+2][{next}]") + }) + .collect::>() + .join(", "); + + format!( + "\t\tis_substr{idx}[i] <== MultiOR({})([{multi_or_inputs}]);\n", + ranges.len() + ) +} + +/// Writes the reveal logic for the Circom circuit. +/// +/// This function generates the logic to reveal a substring if it's present and consecutive. +/// +/// # Arguments +/// +/// * `idx` - The index of the current substring. +/// +/// # Returns +/// +/// A String containing the generated Circom code for reveal logic. +fn write_is_reveal_and_reveal(idx: usize) -> String { + let mut reveal = String::new(); + reveal += &format!( + "\t\tis_reveal{idx}[i] <== MultiAND(3)([out, is_substr{idx}[i], is_consecutive[i][2]]);\n" + ); + reveal += &format!("\t\treveal{idx}[i] <== in[i+1] * is_reveal{idx}[i];\n"); + reveal +} + +/// Writes the complete substring logic for the Circom circuit. +/// +/// This function combines all substring-related logic into a single block. +/// +/// # Arguments +/// +/// * `idx` - The index of the current substring. +/// * `ranges` - A slice of tuples representing state transitions. +/// +/// # Returns +/// +/// A String containing the generated Circom code for the complete substring logic. +fn write_substr_logic(idx: usize, ranges: &[(usize, usize)]) -> String { + let mut logic = String::new(); + logic += &format!("\tsignal prev_states{idx}[{}][msg_bytes];\n", ranges.len()); + logic += &format!("\tsignal is_substr{idx}[msg_bytes];\n"); + logic += &format!("\tsignal is_reveal{idx}[msg_bytes];\n"); + logic += &format!("\tsignal output reveal{idx}[msg_bytes];\n"); + logic += "\tfor (var i = 0; i < msg_bytes; i++) {\n"; + + let sorted_ranges = sort_ranges(ranges); + logic += &format!( + "\t\t // the {idx}-th substring transitions: {:?}\n", + sorted_ranges + ); + + logic += &write_prev_states(idx, &sorted_ranges); + logic += &write_is_substr(idx, &sorted_ranges); + logic += &write_is_reveal_and_reveal(idx); + + logic += "\t}\n"; + logic +} + +/// Sorts the ranges of state transitions. +/// +/// # Arguments +/// +/// * `ranges` - A slice of tuples representing state transitions. +/// +/// # Returns +/// +/// A Vec of references to the sorted ranges. +fn sort_ranges(ranges: &[(usize, usize)]) -> Vec<&(usize, usize)> { + let mut sorted = ranges.iter().collect::>(); + sorted.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1))); + sorted +} + +/// Adds substring constraints to the Circom circuit. +/// +/// This function generates the logic for substring matching and consecutive state tracking. +/// +/// # Arguments +/// +/// * `regex_dfa` - A reference to the RegexAndDFA struct containing the DFA and substring information. +/// +/// # Returns +/// +/// A Result containing the generated Circom code as a String, or a CompilerError. +fn add_substrs_constraints(regex_dfa: &RegexAndDFA) -> Result { + let accepted_state = + get_accepted_state(®ex_dfa.dfa).ok_or(CompilerError::NoAcceptedState)?; + let mut circom = String::new(); + + circom += &write_consecutive_logic(accepted_state); + + circom += &format!( + "\t// substrings calculated: {:?}\n", + regex_dfa.substrings.substring_ranges + ); + + for (idx, ranges) in regex_dfa.substrings.substring_ranges.iter().enumerate() { + circom += &write_substr_logic(idx, &ranges.iter().copied().collect::>()); } - pub fn add_substrs_constraints(&self) -> Result { - let accepted_state = get_accepted_state(&self.dfa_val).unwrap(); - let mut circom: String = "".to_string(); - circom += "\n"; - circom += "\tsignal is_consecutive[msg_bytes+1][3];\n"; - circom += "\tis_consecutive[msg_bytes][2] <== 0;\n"; - circom += "\tfor (var i = 0; i < msg_bytes; i++) {\n"; - circom += &format!("\t\tis_consecutive[msg_bytes-1-i][0] <== states[num_bytes-i][{}] * (1 - is_consecutive[msg_bytes-i][2]) + is_consecutive[msg_bytes-i][2];\n", accepted_state); - circom += "\t\tis_consecutive[msg_bytes-1-i][1] <== state_changed[msg_bytes-i].out * is_consecutive[msg_bytes-1-i][0];\n"; - circom += &format!("\t\tis_consecutive[msg_bytes-1-i][2] <== ORAnd()([(1 - from_zero_enabled[msg_bytes-i+1]), states[num_bytes-i][{}], is_consecutive[msg_bytes-1-i][1]]);\n", accepted_state); - circom += "\t}\n"; - - let substr_defs_array = &self.substrs_defs.substr_defs_array; - circom += &format!( - "\t// substrings calculated: {:?}\n", - &self.substrs_defs.substr_defs_array - ); - for (idx, defs) in substr_defs_array.into_iter().enumerate() { - let num_defs = defs.len(); - circom += &format!("\tsignal prev_states{}[{}][msg_bytes];\n", idx, defs.len()); - circom += &format!("\tsignal is_substr{}[msg_bytes];\n", idx); - circom += &format!("\tsignal is_reveal{}[msg_bytes];\n", idx); - circom += &format!("\tsignal output reveal{}[msg_bytes];\n", idx); - circom += "\tfor (var i = 0; i < msg_bytes; i++) {\n"; - // circom += &format!("\t\tis_substr{}[i][0] <== 0;\n", idx); - let mut defs = defs.iter().collect::>(); - defs.sort_by(|a, b| { - let start_cmp = a.0.cmp(&b.0); - let end_cmp = a.1.cmp(&b.1); - if start_cmp == std::cmp::Ordering::Equal { - end_cmp - } else { - start_cmp - } - }); - circom += &format!("\t\t // the {}-th substring transitions: {:?}\n", idx, defs); - for (trans_idx, (cur, _)) in defs.iter().enumerate() { - if *cur == 0 { - circom += &format!( - "\t\tprev_states{}[{}][i] <== from_zero_enabled[i+1] * states[i+1][{}];\n", - idx, trans_idx, cur - ); - } else { - circom += &format!( - "\t\tprev_states{}[{}][i] <== (1 - from_zero_enabled[i+1]) * states[i+1][{}];\n", - idx, - trans_idx, - cur - ); - } - } - circom += &format!( - "\t\tis_substr{}[i] <== MultiOR({})([{}]);\n", - idx, - num_defs, - defs.iter() - .enumerate() - .map(|(trans_idx, (_, next))| format!( - "prev_states{}[{}][i] * states[i+2][{}]", - idx, trans_idx, next - )) - .collect::>() - .join(", ") - ); - // for (j, (cur, next)) in defs.iter().enumerate() { - // circom += &format!( - // "\t\tis_substr{}[i][{}] <== is_substr{}[i][{}] + ", - // idx, - // j + 1, - // idx, - // j - // ); - // circom += &format!("states[i+1][{}] * states[i+2][{}];\n", cur, next); - // // if j != defs.len() - 1 { - // // circom += " + "; - // // } else { - // // circom += ";\n"; - // // } - // } - circom += &format!( - "\t\tis_reveal{}[i] <== MultiAND(3)([out, is_substr{}[i], is_consecutive[i][2]]);\n", - idx, idx - ); - circom += &format!("\t\treveal{}[i] <== in[i+1] * is_reveal{}[i];\n", idx, idx); - circom += "\t}\n"; - } - circom += "}"; - Ok(circom) + circom += "}"; + Ok(circom) +} + +/// Generates a Circom template file for the given regex and DFA. +/// +/// This function creates a Circom file containing the circuit logic for the regex matcher. +/// +/// # Arguments +/// +/// * `regex_and_dfa` - A reference to the RegexAndDFA struct containing the regex and DFA information. +/// * `circom_path` - The path where the generated Circom file should be saved. +/// * `template_name` - The name of the Circom template. +/// * `gen_substrs` - A boolean indicating whether to generate substring constraints. +/// +/// # Returns +/// +/// A Result indicating success or a CompilerError. +pub(crate) fn gen_circom_template( + regex_and_dfa: &RegexAndDFA, + circom_path: &Path, + template_name: &str, + gen_substrs: bool, +) -> Result<(), CompilerError> { + let circom = gen_circom_allstr( + ®ex_and_dfa.dfa, + template_name, + ®ex_and_dfa.regex_pattern, + regex_and_dfa.has_end_anchor, + ); + + let mut file = File::create(circom_path)?; + file.write_all(circom.as_bytes())?; + + if gen_substrs { + let substrs = add_substrs_constraints(regex_and_dfa)?; + file.write_all(substrs.as_bytes())?; } + + file.flush()?; + Ok(()) +} + +/// Generates a Circom circuit as a string for the given regex and DFA. +/// +/// This function creates a string containing the Circom circuit logic for the regex matcher. +/// +/// # Arguments +/// +/// * `regex_and_dfa` - A reference to the RegexAndDFA struct containing the regex and DFA information. +/// * `template_name` - The name of the Circom template. +/// +/// # Returns +/// +/// A Result containing the generated Circom code as a String, or a CompilerError. +pub(crate) fn gen_circom_string( + regex_and_dfa: &RegexAndDFA, + template_name: &str, +) -> Result { + let circom = gen_circom_allstr( + ®ex_and_dfa.dfa, + template_name, + ®ex_and_dfa.regex_pattern, + regex_and_dfa.has_end_anchor, + ); + let substrs = add_substrs_constraints(regex_and_dfa)?; + let result = circom + &substrs; + Ok(result) } diff --git a/packages/compiler/src/dfa_tests.json b/packages/compiler/src/dfa_tests.json new file mode 100644 index 0000000..056ebc7 --- /dev/null +++ b/packages/compiler/src/dfa_tests.json @@ -0,0 +1,41 @@ +[ + { + "regex": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$", + "pass": ["user@example.com", "john.doe123@sub.domain.co.uk"], + "fail": ["@example.com", "user@.com", "user@com", "user@example.c"] + }, + { + "regex": "^\\d{3}-\\d{3}-\\d{4}$", + "pass": ["123-456-7890", "000-000-0000"], + "fail": ["123-45-6789", "12-345-6789", "123-456-789", "abc-def-ghij"] + }, + { + "regex": "^(https?:\\/\\/)?([\\da-z\\.-]+)\\.([a-z\\.]{2,6})([\\/\\w \\.-]*)*\\/?$", + "pass": [ + "http://example.com", + "https://sub.domain.co.uk/page", + "www.example.com" + ], + "fail": ["htp://invalid", "http://.com", "https://example."] + }, + { + "regex": "^[0-9]{5}(-[0-9]{4})?$", + "pass": ["12345", "12345-6789"], + "fail": ["1234", "123456", "12345-", "12345-67890"] + }, + { + "regex": "^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$", + "pass": ["#123abc", "#FFF", "#000000"], + "fail": ["123abc", "#GGGGGG", "#FFG", "#F0F0F0F"] + }, + { + "regex": "^([01]?[0-9]|2[0-3]):[0-5][0-9]$", + "pass": ["00:00", "23:59", "1:23", "12:34"], + "fail": ["24:00", "12:60", "1:2", "00:0"] + }, + { + "regex": "^[a-zA-Z]{2,}\\s[a-zA-Z]{1,}'?-?[a-zA-Z]{2,}\\s?([a-zA-Z]{1,})?$", + "pass": ["John Doe", "Mary Jane", "Robert O'Neill", "Sarah Jane-Smith"], + "fail": ["J D", "John", "John Doe", "12John Doe"] + } +] diff --git a/packages/compiler/src/errors.rs b/packages/compiler/src/errors.rs new file mode 100644 index 0000000..cae67a3 --- /dev/null +++ b/packages/compiler/src/errors.rs @@ -0,0 +1,27 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum CompilerError { + #[error("Failed to open file: {0}")] + FileOpenError(#[from] std::io::Error), + #[error("Failed to parse JSON: {0}")] + JsonParseError(#[from] serde_json::Error), + #[error("{0}")] + GenericError(String), + #[error( + "Failed to build DFA for regex: \"{regex}\", please check your regex. Error: {source}" + )] + BuildError { + regex: String, + #[source] + source: regex_automata::dfa::dense::BuildError, + }, + #[error("Error in Regex: {0}")] + RegexError(#[from] regex::Error), + #[error("Parse Error: {0}")] + ParseError(String), + #[error("Graph Error: {0}")] + GraphError(String), + #[error("No accepted state found in DFA")] + NoAcceptedState, +} diff --git a/packages/compiler/src/halo2.rs b/packages/compiler/src/halo2.rs index 21954ba..d8e2048 100644 --- a/packages/compiler/src/halo2.rs +++ b/packages/compiler/src/halo2.rs @@ -1,97 +1,96 @@ -use std::fs::File; -use std::io::{BufWriter, Write}; -use std::path::PathBuf; +use crate::{ + errors::CompilerError, + regex::{get_accepted_state, get_max_state}, + structs::RegexAndDFA, +}; +use std::{ + fs::File, + io::{BufWriter, Write}, + path::PathBuf, +}; -use crate::{get_accepted_state, get_max_state, CompilerError, RegexAndDFA}; +/// Converts a RegexAndDFA structure to a text representation of the DFA. +/// +/// # Arguments +/// +/// * `regex_and_dfa` - A reference to the RegexAndDFA structure. +/// +/// # Returns +/// +/// A String containing the text representation of the DFA. +fn dfa_to_regex_def_text(regex_and_dfa: &RegexAndDFA) -> String { + let accepted_state = get_accepted_state(®ex_and_dfa.dfa).unwrap(); + let max_state = get_max_state(®ex_and_dfa.dfa); + let mut text = format!("0\n{}\n{}\n", accepted_state, max_state); -impl RegexAndDFA { - pub fn gen_halo2_tables( - &self, - allstr_file_path: &PathBuf, - substr_file_paths: &[PathBuf], - gen_substrs: bool, - ) -> Result<(), CompilerError> { - let regex_text = self.dfa_to_regex_def_text(); - let mut regex_file = File::create(allstr_file_path)?; - write!(regex_file, "{}", regex_text)?; - regex_file.flush()?; - - if !gen_substrs { - return Ok(()); - } - - for (idx, defs) in self.substrs_defs.substr_defs_array.iter().enumerate() { - let mut writer = BufWriter::new(File::create(&substr_file_paths[idx])?); - let (starts, ends) = &self.substrs_defs.substr_endpoints_array.as_ref().unwrap()[idx]; - let starts_str = starts - .iter() - .map(|s| s.to_string()) - .collect::>() - .join(" "); - writer.write_fmt(format_args!("{}\n", starts_str))?; - let ends_str = ends - .iter() - .map(|e| e.to_string()) - .collect::>() - .join(" "); - writer.write_fmt(format_args!("{}\n", ends_str))?; - - let mut defs = defs.iter().collect::>(); - defs.sort_by(|a, b| { - let start_cmp = a.0.cmp(&b.0); - if start_cmp == std::cmp::Ordering::Equal { - a.1.cmp(&b.1) - } else { - start_cmp - } - }); - - for (cur, next) in defs.iter() { - writer.write_fmt(format_args!("{} {}\n", cur, next))?; + for (i, state) in regex_and_dfa.dfa.states.iter().enumerate() { + for (next_state, chars) in state.transitions.iter() { + for &char in chars { + text += &format!("{} {} {}\n", i, next_state, char as u8); } } - Ok(()) } + text +} - pub fn dfa_to_regex_def_text(&self) -> String { - let accepted_state = get_accepted_state(&self.dfa_val).unwrap(); - let max_state = get_max_state(&self.dfa_val); - let mut text = format!("0\n{}\n{}\n", accepted_state, max_state); +/// Generates Halo2 tables from a RegexAndDFA structure. +/// +/// # Arguments +/// +/// * `regex_and_dfa` - A reference to the RegexAndDFA structure. +/// * `allstr_file_path` - The path where the main DFA definition will be written. +/// * `substr_file_paths` - A slice of paths where substring definitions will be written. +/// * `gen_substrs` - A boolean indicating whether to generate substring files. +/// +/// # Returns +/// +/// A Result indicating success or containing a CompilerError. +pub(crate) fn gen_halo2_tables( + regex_and_dfa: &RegexAndDFA, + allstr_file_path: &PathBuf, + substr_file_paths: &[PathBuf], + gen_substrs: bool, +) -> Result<(), CompilerError> { + let regex_text = dfa_to_regex_def_text(regex_and_dfa); + std::fs::write(allstr_file_path, regex_text)?; - for (i, state) in self.dfa_val.states.iter().enumerate() { - for (next_state, chars) in state.edges.iter() { - for &char in chars { - let char_u8 = char as u8; - text += &format!("{} {} {}\n", i, next_state, char_u8); - } - } - } - text + if !gen_substrs { + return Ok(()); } -} -#[cfg(test)] -mod tests { - use crate::{DecomposedRegexConfig, RegexPartConfig}; - use std::collections::VecDeque; + for (idx, defs) in regex_and_dfa.substrings.substring_ranges.iter().enumerate() { + let mut writer = BufWriter::new(File::create(&substr_file_paths[idx])?); + let (starts, ends) = ®ex_and_dfa + .substrings + .substring_boundaries + .as_ref() + .unwrap()[idx]; - #[test] - fn test_dfa_to_regex_def_text() { - let regex_part_config = RegexPartConfig { - is_public: false, - regex_def: "m[01]+-[ab];".to_string(), - }; - let mut decomposed_regex_config = DecomposedRegexConfig { - parts: VecDeque::from(vec![regex_part_config]), - }; + writeln!( + writer, + "{}", + starts + .iter() + .map(ToString::to_string) + .collect::>() + .join(" ") + )?; + writeln!( + writer, + "{}", + ends.iter() + .map(ToString::to_string) + .collect::>() + .join(" ") + )?; - let regex_and_dfa = decomposed_regex_config - .to_regex_and_dfa() - .expect("failed to convert the decomposed regex to dfa"); + let mut sorted_defs: Vec<_> = defs.iter().collect(); + sorted_defs.sort_unstable_by_key(|&(start, end)| (*start, *end)); - let regex_def_text = regex_and_dfa.dfa_to_regex_def_text(); - let expected_text = - "0\n5\n5\n0 1 109\n1 2 48\n1 2 49\n2 2 48\n2 2 49\n2 3 45\n3 4 97\n3 4 98\n4 5 59\n"; - assert_eq!(regex_def_text, expected_text); + for &(cur, next) in &sorted_defs { + writeln!(writer, "{} {}", cur, next)?; + } } + + Ok(()) } diff --git a/packages/compiler/src/lib.rs b/packages/compiler/src/lib.rs index df56921..3a7fa04 100644 --- a/packages/compiler/src/lib.rs +++ b/packages/compiler/src/lib.rs @@ -1,173 +1,152 @@ -use std::fs::File; -use std::iter::FromIterator; -pub mod circom; -pub mod halo2; -pub mod regex; - -#[cfg(target_arch = "wasm32")] +mod circom; +mod errors; +mod halo2; +mod regex; +mod structs; mod wasm; -#[cfg(target_arch = "wasm32")] -pub use crate::wasm::*; -// #[cfg(test)] -// mod tests; -use crate::regex::*; +use circom::gen_circom_template; +use errors::CompilerError; +use halo2::gen_halo2_tables; use itertools::Itertools; -use petgraph::prelude::*; -use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, BTreeSet, VecDeque}; -use std::path::PathBuf; -use thiserror::Error; - -/// Error definitions of the compiler. -#[derive(Error, Debug)] -pub enum CompilerError { - #[error("No edge from {:?} to {:?} in the graph",.0,.1)] - NoEdge(NodeIndex, NodeIndex), - #[error(transparent)] - IoError(#[from] std::io::Error), - #[error(transparent)] - RegexError(#[from] fancy_regex::Error), - #[error("Generic error: {0}")] - GenericError(String), -} - -/// A configuration of decomposed regexes. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DecomposedRegexConfig { - pub parts: VecDeque, -} - -/// Decomposed regex part. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RegexPartConfig { - /// A flag indicating whether the substring matching with `regex_def` should be exposed. - pub is_public: bool, - /// A regex string. - pub regex_def: String, - // Maximum byte size of the substring in this part. - // pub max_size: usize, - // (Optional) A solidity type of the substring in this part, e.g., "String", "Int", "Decimal". - // pub solidity: Option, -} -/// Solidity type of the substring. -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum SoldityType { - String, - Uint, - Decimal, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DFAState { - r#type: String, - state: usize, - edges: BTreeMap>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DFAGraph { - pub states: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RegexAndDFA { - // pub max_byte_size: usize, - // Original regex string, only here to be printed in generated file to make it more reproducible - pub regex_str: String, - pub dfa_val: DFAGraph, - pub end_anchor: bool, - pub substrs_defs: SubstrsDefs, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SubstrsDefs { - pub substr_defs_array: Vec>, - pub substr_endpoints_array: Option, BTreeSet)>>, - // pub max_bytes: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SubstrsDefsJson { - pub transitions: Vec>, -} - -impl DecomposedRegexConfig { - pub fn to_regex_and_dfa(&mut self) -> Result { - regex_and_dfa(self) - } -} - -impl RegexAndDFA { - pub fn from_regex_str_and_substr_defs( - // max_byte_size: usize, - regex_str: &str, - substrs_defs_json: SubstrsDefsJson, - ) -> Result { - let dfa_val = dfa_from_regex_str(regex_str); - let substr_defs_array = substrs_defs_json - .transitions - .into_iter() - .map(|transitions_array| BTreeSet::<(usize, usize)>::from_iter(transitions_array)) - .collect_vec(); - let substrs_defs = SubstrsDefs { - substr_defs_array, - substr_endpoints_array: None, - // max_bytes: None, - }; - - Ok(RegexAndDFA { - // max_byte_size, - regex_str: regex_str.to_string(), - dfa_val, - end_anchor: regex_str.ends_with('$'), - substrs_defs, - }) +use regex::{create_regex_and_dfa_from_str_and_defs, get_regex_and_dfa}; +use std::{fs::File, path::PathBuf}; +use structs::{DecomposedRegexConfig, RegexAndDFA, SubstringDefinitionsJson}; + +/// Loads substring definitions from a JSON file or creates a default one. +/// +/// # Arguments +/// +/// * `substrs_json_path` - An optional path to the JSON file containing substring definitions. +/// +/// # Returns +/// +/// A `Result` containing either the loaded `SubstringDefinitionsJson` or a `CompilerError`. +fn load_substring_definitions_json( + substrs_json_path: Option<&str>, +) -> Result { + match substrs_json_path { + Some(path) => { + let file = File::open(path)?; + serde_json::from_reader(file).map_err(CompilerError::JsonParseError) + } + None => Ok(SubstringDefinitionsJson { + transitions: vec![vec![]], + }), } } -pub fn gen_from_decomposed( - decomposed_regex_path: &str, +/// Generates output files for Halo2 and Circom based on the provided regex and DFA. +/// +/// # Arguments +/// +/// * `regex_and_dfa` - The `RegexAndDFA` struct containing the regex pattern and DFA. +/// * `halo2_dir_path` - An optional path to the directory for Halo2 output files. +/// * `circom_file_path` - An optional path to the Circom output file. +/// * `circom_template_name` - An optional name for the Circom template. +/// * `num_public_parts` - The number of public parts in the regex. +/// * `gen_substrs` - A boolean indicating whether to generate substrings. +/// +/// # Returns +/// +/// A `Result` indicating success or a `CompilerError`. +fn generate_outputs( + regex_and_dfa: &RegexAndDFA, halo2_dir_path: Option<&str>, circom_file_path: Option<&str>, circom_template_name: Option<&str>, - gen_substrs: Option, -) { - let mut decomposed_regex_config: DecomposedRegexConfig = - serde_json::from_reader(File::open(decomposed_regex_path).unwrap()).unwrap(); - let regex_and_dfa = decomposed_regex_config - .to_regex_and_dfa() - .expect("failed to convert the decomposed regex to dfa"); - let gen_substrs = gen_substrs.unwrap_or(true); - + num_public_parts: usize, + gen_substrs: bool, +) -> Result<(), CompilerError> { if let Some(halo2_dir_path) = halo2_dir_path { let halo2_dir_path = PathBuf::from(halo2_dir_path); let allstr_file_path = halo2_dir_path.join("allstr.txt"); - let mut num_public_parts = 0usize; - for part in decomposed_regex_config.parts.iter() { - if part.is_public { - num_public_parts += 1; - } - } let substr_file_paths = (0..num_public_parts) .map(|idx| halo2_dir_path.join(format!("substr_{}.txt", idx))) .collect_vec(); - regex_and_dfa - .gen_halo2_tables(&allstr_file_path, &substr_file_paths, gen_substrs) - .expect("failed to generate halo2 tables"); + + gen_halo2_tables( + regex_and_dfa, + &allstr_file_path, + &substr_file_paths, + gen_substrs, + )?; } if let Some(circom_file_path) = circom_file_path { let circom_file_path = PathBuf::from(circom_file_path); let circom_template_name = circom_template_name .expect("circom template name must be specified if circom file path is specified"); - regex_and_dfa - .gen_circom(&circom_file_path, &circom_template_name, gen_substrs) - .expect("failed to generate circom"); + + gen_circom_template( + regex_and_dfa, + &circom_file_path, + &circom_template_name, + gen_substrs, + )?; } -} + Ok(()) +} + +/// Generates outputs from a decomposed regex configuration file. +/// +/// # Arguments +/// +/// * `decomposed_regex_path` - The path to the decomposed regex configuration file. +/// * `halo2_dir_path` - An optional path to the directory for Halo2 output files. +/// * `circom_file_path` - An optional path to the Circom output file. +/// * `circom_template_name` - An optional name for the Circom template. +/// * `gen_substrs` - An optional boolean indicating whether to generate substrings. +/// +/// # Returns +/// +/// A `Result` indicating success or a `CompilerError`. +pub fn gen_from_decomposed( + decomposed_regex_path: &str, + halo2_dir_path: Option<&str>, + circom_file_path: Option<&str>, + circom_template_name: Option<&str>, + gen_substrs: Option, +) -> Result<(), CompilerError> { + let mut decomposed_regex_config: DecomposedRegexConfig = + serde_json::from_reader(File::open(decomposed_regex_path)?)?; + let gen_substrs = gen_substrs.unwrap_or(false); + + let regex_and_dfa = get_regex_and_dfa(&mut decomposed_regex_config)?; + + let num_public_parts = decomposed_regex_config + .parts + .iter() + .filter(|part| part.is_public) + .count(); + + generate_outputs( + ®ex_and_dfa, + halo2_dir_path, + circom_file_path, + circom_template_name, + num_public_parts, + gen_substrs, + )?; + + Ok(()) +} + +/// Generates outputs from a raw regex string and optional substring definitions. +/// +/// # Arguments +/// +/// * `raw_regex` - The raw regex string. +/// * `substrs_json_path` - An optional path to the JSON file containing substring definitions. +/// * `halo2_dir_path` - An optional path to the directory for Halo2 output files. +/// * `circom_file_path` - An optional path to the Circom output file. +/// * `template_name` - An optional name for the Circom template. +/// * `gen_substrs` - An optional boolean indicating whether to generate substrings. +/// +/// # Returns +/// +/// A `Result` indicating success or a `CompilerError`. pub fn gen_from_raw( raw_regex: &str, substrs_json_path: Option<&str>, @@ -175,81 +154,22 @@ pub fn gen_from_raw( circom_file_path: Option<&str>, template_name: Option<&str>, gen_substrs: Option, -) { - let substrs_defs_json = if let Some(substrs_json_path) = substrs_json_path { - let substrs_json_path = PathBuf::from(substrs_json_path); - let substrs_defs_json: SubstrsDefsJson = - serde_json::from_reader(File::open(substrs_json_path).unwrap()).unwrap(); - substrs_defs_json - } else { - SubstrsDefsJson { - transitions: vec![vec![]], - } - }; +) -> Result<(), CompilerError> { + let substrs_defs_json = load_substring_definitions_json(substrs_json_path)?; let num_public_parts = substrs_defs_json.transitions.len(); - let regex_and_dfa = RegexAndDFA::from_regex_str_and_substr_defs(raw_regex, substrs_defs_json) - .expect("failed to convert the raw regex and state transitions to dfa"); - let gen_substrs = gen_substrs.unwrap_or(true); - - if let Some(halo2_dir_path) = halo2_dir_path { - let halo2_dir_path = PathBuf::from(halo2_dir_path); - let allstr_file_path = halo2_dir_path.join("allstr.txt"); - let substr_file_paths = (0..num_public_parts) - .map(|idx| halo2_dir_path.join(format!("substr_{}.txt", idx))) - .collect_vec(); - regex_and_dfa - .gen_halo2_tables(&allstr_file_path, &substr_file_paths, gen_substrs) - .expect("failed to generate halo2 tables"); - } - if let Some(circom_file_path) = circom_file_path { - let circom_file_path = PathBuf::from(circom_file_path); - let template_name = template_name - .expect("circom template name must be specified if circom file path is specified"); - regex_and_dfa - .gen_circom(&circom_file_path, &template_name, gen_substrs) - .expect("failed to generate circom"); - } -} + let regex_and_dfa = create_regex_and_dfa_from_str_and_defs(raw_regex, substrs_defs_json)?; -pub(crate) fn get_accepted_state(dfa_val: &DFAGraph) -> Option { - for i in 0..dfa_val.states.len() { - if dfa_val.states[i].r#type == "accept" { - return Some(i as usize); - } - } - None -} + let gen_substrs = gen_substrs.unwrap_or(true); -pub(crate) fn get_max_state(dfa_val: &DFAGraph) -> usize { - let mut max_state = 0; - for (_i, val) in dfa_val.states.iter().enumerate() { - if val.state > max_state { - max_state = val.state; - } - } - max_state -} + generate_outputs( + ®ex_and_dfa, + halo2_dir_path, + circom_file_path, + template_name, + num_public_parts, + gen_substrs, + )?; -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_gen_from_decomposed() { - let project_root = PathBuf::new().join(env!("CARGO_MANIFEST_DIR")); - let decomposed_regex_path = project_root.join("../circom/circuits/common/subject_all.json"); - let circom_file_path = - project_root.join("../circom/circuits/common/subject_all_regex.circom"); - let circom_template_name = Some("SubjectAllRegex"); - let gen_substrs = Some(true); - - let _result = gen_from_decomposed( - decomposed_regex_path.to_str().unwrap(), - None, - Some(circom_file_path.to_str().unwrap()), - circom_template_name.map(|s| s), - gen_substrs, - ); - } + Ok(()) } diff --git a/packages/compiler/src/regex.rs b/packages/compiler/src/regex.rs index e3a9d28..dffde5e 100644 --- a/packages/compiler/src/regex.rs +++ b/packages/compiler/src/regex.rs @@ -1,491 +1,1186 @@ -use super::CompilerError; -use crate::{DFAGraph, DFAState, DecomposedRegexConfig, RegexAndDFA, RegexPartConfig, SubstrsDefs}; +use crate::{ + errors::CompilerError, + structs::{ + DFAGraph, DFAGraphInfo, DFAStateInfo, DFAStateNode, RegexAndDFA, RegexPartConfig, + SubstringDefinitions, SubstringDefinitionsJson, + }, + DecomposedRegexConfig, +}; use regex::Regex; -use regex_automata::dfa::{dense::DFA, StartKind}; -use std::collections::{BTreeMap, BTreeSet}; +use regex_automata::dfa::{ + dense::{Config, DFA}, + StartKind, +}; +use std::{ + collections::{BTreeMap, BTreeSet, VecDeque}, + num::ParseIntError, +}; -#[derive(Debug, Clone)] -struct DFAInfoState { - typ: String, - source: usize, - edges: BTreeMap, +/// Creates a DFA configuration with specific settings. +/// +/// # Returns +/// +/// A `Config` object with minimization, anchored start, no byte classes, and acceleration enabled. +fn create_dfa_config() -> Config { + DFA::config() + .minimize(true) + .start_kind(StartKind::Anchored) + .byte_classes(false) + .accelerate(true) } -#[derive(Debug)] -struct DFAGraphInfo { - states: Vec, +/// Finds the index of the first caret (^) in a regex string that is not inside parentheses. +/// +/// # Arguments +/// +/// * `regex` - A string slice containing the regex pattern. +/// +/// # Returns +/// +/// An `Option` containing the index of the caret if found, or `None` if not found. +fn find_caret_index(regex: &str) -> Option { + let regex_bytes = regex.as_bytes(); + let mut is_in_parenthesis = false; + let mut caret_found = false; + let mut idx = 0; + + while idx < regex_bytes.len() { + match regex_bytes[idx] { + b'\\' => { + idx += 2; + } + b'(' => { + is_in_parenthesis = true; + idx += 1; + } + b'[' => { + idx += 2; + } + b')' => { + debug_assert!(is_in_parenthesis, "Unmatched parenthesis"); + is_in_parenthesis = false; + idx += 1; + if caret_found { + break; + } + } + b'^' => { + caret_found = true; + idx += 1; + if !is_in_parenthesis { + break; + } + } + _ => { + idx += 1; + } + } + } + + if caret_found { + Some(idx) + } else { + None + } } -fn parse_dfa_output(output: &str) -> DFAGraphInfo { - let mut dfa_info = DFAGraphInfo { states: Vec::new() }; +/// Processes the caret (^) in a regex, splitting it into two parts if necessary. +/// +/// # Arguments +/// +/// * `decomposed_regex` - A mutable reference to a `DecomposedRegexConfig`. +/// +/// # Returns +/// +/// A `Result` containing an `Option` with the caret position, or a `CompilerError`. +fn process_caret_in_regex( + decomposed_regex: &mut DecomposedRegexConfig, +) -> Result, CompilerError> { + let caret_position = find_caret_index(&decomposed_regex.parts[0].regex_def); - let re = Regex::new(r"\*?(\d+): ((.+?) => (\d+),?)+").unwrap(); - for captures in re.captures_iter(output) { - let src = captures[1].parse::().unwrap(); - let mut state = DFAInfoState { + if let Some(index) = caret_position { + let caret_regex = decomposed_regex.parts[0].regex_def[0..index].to_string(); + decomposed_regex.parts.push_front(RegexPartConfig { + is_public: false, + regex_def: caret_regex, + }); + decomposed_regex.parts[1].regex_def = + decomposed_regex.parts[1].regex_def[index..].to_string(); + } + + Ok(caret_position) +} + +/// Validates the end anchor ($) in a regex part. +/// +/// # Arguments +/// +/// * `decomposed_regex` - A reference to a `DecomposedRegexConfig`. +/// * `idx` - The index of the current regex part. +/// * `regex` - A reference to the current `RegexPartConfig`. +/// +/// # Returns +/// +/// A `Result` containing a boolean indicating if the part has a valid end anchor, or a `CompilerError`. +fn validate_end_anchor( + decomposed_regex: &DecomposedRegexConfig, + idx: usize, + regex: &RegexPartConfig, +) -> Result { + let is_last_part = idx == decomposed_regex.parts.len() - 1; + let ends_with_dollar = regex.regex_def.ends_with('$'); + + if ends_with_dollar && !is_last_part { + return Err(CompilerError::GenericError( + "Invalid regex, $ can only be at the end of the regex".to_string(), + )); + } + + Ok(is_last_part && ends_with_dollar) +} + +/// Parses DFA states from a string output and populates a `DFAGraphInfo` structure. +/// +/// # Arguments +/// +/// * `output` - A string slice containing the DFA state information. +/// * `dfa_info` - A mutable reference to a `DFAGraphInfo` to be populated with parsed states. +/// +/// # Returns +/// +/// A `Result` containing `()` if parsing is successful, or a `CompilerError` if parsing fails. +/// +/// # Function Behavior +/// +/// - Uses regex to match state definitions and transitions in the input string. +/// - Iterates over state matches, creating `DFAStateInfo` objects for each state. +/// - Parses transitions for each state and adds them to the state's edges. +/// - Populates `dfa_info.states` with the parsed states. +fn parse_states(output: &str, dfa_info: &mut DFAGraphInfo) -> Result<(), CompilerError> { + let state_re = Regex::new(r"\*?(\d+): ((.+?) => (\d+),?)+")?; + let transition_re = Regex::new( + r"\s+[^=]+\s*=>\s*(\d+)+\s*|\s+=+\s*=>\s*(\d+)+|\s+=-[^=]+=>\s*\s*(\d+)+\s*|\s+[^=]+-=\s*=>\s*(\d+)+\s*", + )?; + + for captures in state_re.captures_iter(output) { + let src = captures[1] + .parse::() + .map_err(|_| CompilerError::ParseError("Failed to parse state ID".to_string()))?; + + let mut state = DFAStateInfo { source: src, - typ: String::new(), + typ: if captures[0].starts_with('*') { + "accept".to_string() + } else { + String::new() + }, edges: BTreeMap::new(), }; - if &captures[0][0..1] == "*" { - state.typ = String::from("accept"); - } - for transition in Regex::new( - r"\s+[^=]+\s*=>\s*(\d+)+\s*|\s+=+\s*=>\s*(\d+)+|\s+=-[^=]+=>\s*\s*(\d+)+\s*|\s+[^=]+-=\s*=>\s*(\d+)+\s*" - ) - .unwrap() - .captures_iter(&captures[0].to_string()) { - let trimmed_transition = transition[0].trim(); - let transition_vec = trimmed_transition.split("=>").collect::>(); - let mut transition_vec_iter = transition_vec.iter(); - let mut src = transition_vec_iter.next().unwrap().trim().to_string(); - if - src.len() > 2 && - src.chars().nth(2).unwrap() == '\\' && - !(src.chars().nth(3).unwrap() == 'x') - { - src = format!("{}{}", &src[0..2], &src[3..]); - } - let dst = transition_vec_iter.next().unwrap().trim(); - state.edges.insert(src, dst.parse::().unwrap()); + + for transition in transition_re.captures_iter(&captures[0]) { + parse_transition(&mut state, &transition[0])?; } + dfa_info.states.push(state); } - let mut eoi_pointing_states = BTreeSet::new(); + Ok(()) +} + +/// Parses a single transition from a string and adds it to the DFA state. +/// +/// # Arguments +/// +/// * `state` - A mutable reference to the `DFAStateInfo` to which the transition will be added. +/// * `transition` - A string slice containing the transition information. +/// +/// # Returns +/// +/// A `Result` containing `()` if parsing is successful, or a `CompilerError` if parsing fails. +/// +/// # Function Behavior +/// +/// - Splits the transition string into source and destination parts. +/// - Processes the source string to handle special character cases. +/// - Parses the destination as a usize. +/// - Adds the parsed transition to the state's edges. +fn parse_transition(state: &mut DFAStateInfo, transition: &str) -> Result<(), CompilerError> { + let parts: Vec<&str> = transition.split("=>").collect(); + if parts.len() != 2 { + return Err(CompilerError::ParseError( + "Invalid transition format".to_string(), + )); + } + + let mut src = parts[0].trim().to_string(); + if src.len() > 2 && src.chars().nth(2) == Some('\\') && src.chars().nth(3) != Some('x') { + src = format!("{}{}", &src[0..2], &src[3..]); + } + + let dst = parts[1] + .trim() + .parse::() + .map_err(|_| CompilerError::ParseError("Failed to parse destination state".to_string()))?; + + state.edges.insert(src, dst); + Ok(()) +} +/// Processes EOI (End of Input) transitions in the DFA graph. +/// +/// Removes EOI transitions and marks their source states as accept states. +fn handle_eoi_transitions(dfa_info: &mut DFAGraphInfo) { for state in &mut dfa_info.states { - if let Some(eoi_target) = state.edges.get("EOI").cloned() { - eoi_pointing_states.insert(eoi_target); + if let Some(_) = state.edges.get("EOI") { state.typ = String::from("accept"); state.edges.remove("EOI"); } } +} - let start_state_re = Regex::new(r"START-GROUP\(anchored\)[\s*\w*\=>]*Text => (\d+)").unwrap(); - let start_state = start_state_re.captures_iter(output).next().unwrap()[1] - .parse::() - .unwrap(); - - // Sort states by order of appearance and rename the sources - let mut sorted_states = DFAGraphInfo { states: Vec::new() }; - let mut sorted_states_set = BTreeSet::new(); - let mut new_states = BTreeSet::new(); - new_states.insert(start_state); - while !new_states.is_empty() { - let mut next_states = BTreeSet::new(); - for state in &new_states { - if let Some(state) = dfa_info.states.iter().find(|s| s.source == *state) { - sorted_states.states.push((*state).clone()); - sorted_states_set.insert(state.source); - for (_, dst) in &state.edges { - if !sorted_states_set.contains(dst) { - next_states.insert(*dst); - } - } +/// Finds the start state in the DFA output string. +/// +/// # Arguments +/// +/// * `output` - A string slice containing the DFA output. +/// +/// # Returns +/// +/// A `Result` containing the start state ID as `usize`, or a `CompilerError` if not found. +fn find_start_state(output: &str) -> Result { + let start_state_re = Regex::new(r"START-GROUP\(anchored\)[\s*\w*\=>]*Text => (\d+)")?; + start_state_re + .captures(output) + .and_then(|cap| cap[1].parse::().ok()) + .ok_or_else(|| CompilerError::ParseError("Failed to find start state".to_string())) +} + +/// Sorts and renames states in a DFA graph, starting from a given start state. +/// +/// # Arguments +/// +/// * `dfa_info` - A reference to the original `DFAGraphInfo`. +/// * `start_state` - The ID of the start state. +/// +/// # Returns +/// +/// A new `DFAGraphInfo` with sorted and renamed states. +/// +/// # Function Behavior +/// +/// 1. Performs a Breadth-First Search (BFS) to sort states, starting from the start state. +/// 2. Creates a mapping of old state IDs to new state IDs. +/// 3. Renames states and updates their edges according to the new mapping. +fn sort_and_rename_states(dfa_info: &DFAGraphInfo, start_state: usize) -> DFAGraphInfo { + let mut sorted_states = Vec::new(); + let mut visited = BTreeSet::new(); + let mut queue = VecDeque::from([start_state]); + + // BFS to sort states + while let Some(state_id) = queue.pop_front() { + if visited.insert(state_id) { + if let Some(state) = dfa_info.states.iter().find(|s| s.source == state_id) { + sorted_states.push(state.clone()); + queue.extend(state.edges.values().filter(|&dst| !visited.contains(dst))); } } - // Check if the next_states are already in the sorted_states_set - new_states.clear(); - for state in &next_states { - if !sorted_states_set.contains(state) { - new_states.insert(*state); + } + + // Create mapping of old state IDs to new state IDs + let state_map: BTreeMap<_, _> = sorted_states + .iter() + .enumerate() + .map(|(new_id, state)| (state.source, new_id)) + .collect(); + + // Rename states and update edges + let renamed_states = sorted_states + .into_iter() + .enumerate() + .map(|(new_id, mut state)| { + state.source = new_id; + for dst in state.edges.values_mut() { + *dst = *state_map.get(dst).unwrap_or(dst); } - } + state + }) + .collect(); + + DFAGraphInfo { + states: renamed_states, + } +} + +/// Creates a mapping of special character representations to their ASCII values. +/// +/// # Returns +/// +/// A `BTreeMap` where keys are string representations of special characters, +/// and values are their corresponding ASCII byte values. +fn create_special_char_mappings() -> BTreeMap<&'static str, u8> { + [ + ("\\n", 10), + ("\\r", 13), + ("\\t", 9), + ("\\v", 11), + ("\\f", 12), + ("\\0", 0), + ("\\\"", 34), + ("\\'", 39), + ("\\", 92), + ("' '", 32), + ] + .iter() + .cloned() + .collect() +} + +/// Processes a range edge in the DFA graph, adding all characters in the range to the edge set. +/// +/// # Arguments +/// +/// * `key` - The string representation of the range transition (e.g., "a-z"). +/// * `value` - The destination state ID. +/// * `edges` - A mutable reference to the map of edges. +/// * `special_char_mappings` - A reference to the special character mappings. +/// * `re` - A reference to the compiled Regex for parsing ranges. +/// +/// # Returns +/// +/// A `Result` containing `()` if successful, or a `CompilerError` if parsing fails. +/// +/// # Function Behavior +/// +/// - Extracts start and end characters of the range using the provided regex. +/// - Parses start and end characters to their byte values. +/// - Adds all characters in the range to the edge set for the given destination state. +fn process_range_edge( + key: &str, + value: usize, + edges: &mut BTreeMap>, + special_char_mappings: &BTreeMap<&str, u8>, + re: &Regex, +) -> Result<(), CompilerError> { + let capture = re + .captures(key) + .ok_or_else(|| CompilerError::ParseError("Failed to capture range".to_string()))?; + let start_index = parse_char(&capture[1], special_char_mappings)?; + let end_index = parse_char(&capture[2], special_char_mappings)?; + let char_range: Vec = (start_index..=end_index).collect(); + + edges + .entry(value) + .or_insert_with(BTreeSet::new) + .extend(char_range); + Ok(()) +} + +/// Processes a single character edge in the DFA graph. +/// +/// # Arguments +/// +/// * `key` - The string representation of the character. +/// * `value` - The destination state ID. +/// * `edges` - A mutable reference to the map of edges. +/// * `special_char_mappings` - A reference to the special character mappings. +/// +/// # Returns +/// +/// A `Result` containing `()` if successful, or a `CompilerError` if parsing fails. +/// +/// # Function Behavior +/// +/// - Parses the character to its byte value. +/// - Adds the byte to the edge set for the given destination state. +fn process_single_edge( + key: &str, + value: usize, + edges: &mut BTreeMap>, + special_char_mappings: &BTreeMap<&str, u8>, +) -> Result<(), CompilerError> { + let index = parse_char(key, special_char_mappings)?; + edges + .entry(value) + .or_insert_with(BTreeSet::new) + .insert(index); + Ok(()) +} + +/// Processes an edge in the DFA graph, handling both range and single character transitions. +/// +/// # Arguments +/// +/// * `key` - The string representation of the transition. +/// * `value` - The destination state ID. +/// * `edges` - A mutable reference to the map of edges. +/// * `special_char_mappings` - A reference to the special character mappings. +/// +/// # Returns +/// +/// A `Result` containing `()` if successful, or a `CompilerError` if parsing fails. +/// +/// # Function Behavior +/// +/// - Checks if the key represents a range (e.g., "a-z") or a single character. +/// - Delegates to `process_range_edge` or `process_single_edge` accordingly. +fn process_edge( + key: &str, + value: usize, + edges: &mut BTreeMap>, + special_char_mappings: &BTreeMap<&str, u8>, +) -> Result<(), CompilerError> { + let re = Regex::new(r"(.+)-(.+)")?; + if re.is_match(key) { + process_range_edge(key, value, edges, special_char_mappings, &re)?; + } else { + process_single_edge(key, value, edges, special_char_mappings)?; } + Ok(()) +} - // Rename the sources - let mut switch_states = BTreeMap::new(); - for (i, state) in sorted_states.states.iter_mut().enumerate() { - let temp = state.source; - state.source = i as usize; - switch_states.insert(temp, state.source); +/// Parses a character representation into its corresponding byte value. +/// +/// # Arguments +/// +/// * `s` - The string representation of the character. +/// * `special_char_mappings` - A reference to the special character mappings. +/// +/// # Returns +/// +/// A `Result` containing the parsed byte value, or a `CompilerError` if parsing fails. +/// +/// # Function Behavior +/// +/// - Handles hexadecimal representations (e.g., "\x41"). +/// - Looks up special characters in the provided mappings. +/// - Converts single-character strings to their byte value. +/// - Returns an error for invalid inputs. +fn parse_char(s: &str, special_char_mappings: &BTreeMap<&str, u8>) -> Result { + if s.starts_with("\\x") { + u8::from_str_radix(&s[2..], 16) + .map_err(|e: ParseIntError| CompilerError::ParseError(e.to_string())) + } else if let Some(&value) = special_char_mappings.get(s) { + Ok(value) + } else if s.len() == 1 { + Ok(s.as_bytes()[0]) + } else { + Err(CompilerError::ParseError(format!( + "Invalid character: {}", + s + ))) } +} - // Iterate over all edges of all states - for state in &mut sorted_states.states { - for (_, dst) in &mut state.edges { - *dst = switch_states.get(dst).unwrap().clone(); - } +/// Processes all edges for a state in the DFA graph. +/// +/// # Arguments +/// +/// * `state_edges` - A reference to a map of edge labels to destination state IDs. +/// +/// # Returns +/// +/// A `Result` containing a map of destination state IDs to sets of byte values, +/// or a `CompilerError` if processing fails. +/// +/// # Function Behavior +/// +/// - Creates special character mappings. +/// - Iterates over all edges, processing each one. +/// - Handles the special case of space character representation. +fn process_state_edges( + state_edges: &BTreeMap, +) -> Result>, CompilerError> { + let mut edges = BTreeMap::new(); + let special_char_mappings = create_special_char_mappings(); + + for (key, value) in state_edges { + let key = if key == "' '" { " " } else { key }; + process_edge(key, *value, &mut edges, &special_char_mappings)?; } - sorted_states + Ok(edges) } -fn dfa_to_graph(dfa_info: &DFAGraphInfo) -> DFAGraph { +/// Converts a DFA (Deterministic Finite Automaton) to a DFAGraph structure. +/// +/// # Arguments +/// +/// * `dfa` - The DFA to convert. +/// +/// # Returns +/// +/// A `Result` containing the converted `DFAGraph`, or a `CompilerError` if conversion fails. +/// +/// # Function Behavior +/// +/// 1. Converts the DFA to a string representation. +/// 2. Parses states from the string representation. +/// 3. Handles EOI (End of Input) transitions. +/// 4. Finds the start state and sorts/renames states accordingly. +/// 5. Processes edges for each state and constructs the final graph. +fn convert_dfa_to_graph(dfa: DFA>) -> Result { + let dfa_str = format!("{:?}", dfa); + + let mut dfa_info = DFAGraphInfo { states: Vec::new() }; + + parse_states(&dfa_str, &mut dfa_info)?; + + handle_eoi_transitions(&mut dfa_info); + + let start_state = find_start_state(&dfa_str)?; + dfa_info = sort_and_rename_states(&mut dfa_info, start_state); + let mut graph = DFAGraph { states: Vec::new() }; for state in &dfa_info.states { - let mut edges = BTreeMap::new(); - let key_mappings: BTreeMap<&str, u8> = [ - ("\\n", 10), - ("\\r", 13), - ("\\t", 9), - ("\\v", 11), - ("\\f", 12), - ("\\0", 0), - ("\\\"", 34), - ("\\'", 39), - ("\\", 92), - ("' '", 32), - ] - .into(); - for (key, value) in &state.edges { - let mut key: &str = key; - if key == "' '" { - key = " "; - } - let re = Regex::new(r"(.+)-(.+)").unwrap(); - if re.is_match(key) { - let capture = re.captures_iter(key).next().unwrap(); - let mut start = &capture[1]; - let start_index; - if start.starts_with("\\x") { - start = &start[2..]; - start_index = u8::from_str_radix(start, 16).unwrap(); - } else { - if key_mappings.contains_key(start) { - start_index = *key_mappings.get(start).unwrap(); - } else { - start_index = start.as_bytes()[0]; - } - } - let mut end = &capture[2]; - let end_index; - if end.starts_with("\\x") { - end = &end[2..]; - end_index = u8::from_str_radix(end, 16).unwrap(); - } else { - if key_mappings.contains_key(end) { - end_index = *key_mappings.get(end).unwrap(); - } else { - end_index = end.as_bytes()[0]; - } - } - let char_range: Vec = (start_index..=end_index).collect(); - if edges.contains_key(value) { - let edge: &mut BTreeSet = edges.get_mut(value).unwrap(); - for c in char_range { - edge.insert(c); - } - } else { - edges.insert(*value, char_range.into_iter().collect()); - } - } else { - let index; - if key.starts_with("\\x") { - key = &key[2..]; - index = u8::from_str_radix(key, 16).unwrap(); - } else { - if key_mappings.contains_key(key) { - index = *key_mappings.get(key).unwrap(); - } else { - index = key.as_bytes()[0]; - } - } - if edges.contains_key(value) { - let edge: &mut BTreeSet = edges.get_mut(value).unwrap(); - edge.insert(index); - } else { - edges.insert(*value, vec![index].into_iter().collect()); - } - } - } - - graph.states.push(DFAState { - r#type: state.typ.clone(), - edges: edges, - state: state.source, + let edges = process_state_edges(&state.edges)?; + graph.states.push(DFAStateNode { + state_type: state.typ.clone(), + state_id: state.source, + transitions: edges, }); } - graph + Ok(graph) } -fn rename_states(dfa_info: &DFAGraph, base: usize) -> DFAGraph { - let mut dfa_info = dfa_info.clone(); - // Rename the sources - let mut switch_states = BTreeMap::new(); - for (i, state) in dfa_info.states.iter_mut().enumerate() { - let temp = state.state; - state.state = i + base; - switch_states.insert(temp, state.state); +/// Modifies the DFA graph to handle the caret (^) anchor at the start of a regex. +/// +/// # Arguments +/// +/// * `graph` - A mutable reference to the DFAGraph to be modified. +/// +/// # Returns +/// +/// A `Result` containing `()` if successful, or a `CompilerError` if modification fails. +/// +/// # Function Behavior +/// +/// 1. Clears the state type of the start state. +/// 2. Finds the accept state in the graph. +/// 3. Adds a transition from the start state to the accept state with byte value 255. +fn modify_graph_for_caret(graph: &mut DFAGraph) -> Result<(), CompilerError> { + if let Some(start_state) = graph.states.get_mut(0) { + start_state.state_type.clear(); + } else { + return Err(CompilerError::GraphError( + "Start state not found".to_string(), + )); } - // Iterate over all edges of all states and rename the states - for state in &mut dfa_info.states { - let mut new_edges = BTreeMap::new(); - for (key, value) in &state.edges { - new_edges.insert(*switch_states.get(key).unwrap(), value.clone()); + let accepted_state = graph + .states + .iter() + .find(|state| state.state_type == "accept") + .ok_or_else(|| CompilerError::GraphError("Accept state not found".to_string()))? + .clone(); + + if let Some(start_state) = graph.states.get_mut(0) { + start_state + .transitions + .entry(accepted_state.state_id) + .or_insert_with(BTreeSet::new) + .insert(255u8); + } + + Ok(()) +} + +/// Creates a simple DFA graph for the caret (^) anchor. +/// +/// # Returns +/// +/// A `DFAGraph` with two states: +/// 1. Start state (id: 0) with a transition to the accept state on byte 255. +/// 2. Accept state (id: 1) with no outgoing transitions. +fn create_simple_caret_graph() -> DFAGraph { + DFAGraph { + states: vec![ + DFAStateNode { + state_type: String::new(), + state_id: 0, + transitions: BTreeMap::from([(1, BTreeSet::from([255u8]))]), + }, + DFAStateNode { + state_type: "accept".to_string(), + state_id: 1, + transitions: BTreeMap::new(), + }, + ], + } +} + +/// Handles the caret (^) anchor in a regex by modifying the DFA graph accordingly. +/// +/// # Arguments +/// +/// * `idx` - The index of the current regex part. +/// * `caret_position` - The position of the caret in the regex, if present. +/// * `regex` - The current regex part configuration. +/// * `graph` - The DFA graph to be modified. +/// +/// # Returns +/// +/// A `Result` containing `()` if successful, or a `CompilerError` if modification fails. +/// +/// # Function Behavior +/// +/// - If it's the first regex part and a caret is present: +/// - Creates a simple caret graph if the regex is just "^". +/// - Otherwise, modifies the existing graph to handle the caret. +fn handle_caret_regex( + idx: usize, + caret_position: Option, + regex: &RegexPartConfig, + graph: &mut DFAGraph, +) -> Result<(), CompilerError> { + if idx == 0 && caret_position.is_some() { + if regex.regex_def == "^" { + *graph = create_simple_caret_graph(); + } else { + modify_graph_for_caret(graph)?; } - state.edges = new_edges; } + Ok(()) +} + +/// Renames the states in a DFA graph, offsetting their IDs by a given base value. +/// +/// # Arguments +/// +/// * `dfa_graph` - The original DFA graph. +/// * `base` - The base offset for new state IDs. +/// +/// # Returns +/// +/// A new `DFAGraph` with renamed states. +/// +/// # Function Behavior +/// +/// 1. Creates a mapping of old state IDs to new state IDs. +/// 2. Constructs a new graph with updated state IDs and transitions. +/// 3. Preserves other properties of each state. +fn rename_states(dfa_graph: &DFAGraph, base: usize) -> DFAGraph { + let state_id_mapping: BTreeMap<_, _> = dfa_graph + .states + .iter() + .enumerate() + .map(|(i, state)| (state.state_id, i + base)) + .collect(); - dfa_info + DFAGraph { + states: dfa_graph + .states + .iter() + .enumerate() + .map(|(i, state)| DFAStateNode { + state_id: i + base, + transitions: state + .transitions + .iter() + .map(|(key, value)| { + ( + *state_id_mapping.get(key).expect("State not found"), + value.clone(), + ) + }) + .collect(), + ..state.clone() + }) + .collect(), + } } -fn add_dfa(net_dfa: &DFAGraph, graph: &DFAGraph) -> DFAGraph { - if net_dfa.states.is_empty() { - return graph.clone(); +/// Collects accepting states from a DFA graph and their state IDs. +/// +/// # Arguments +/// +/// * `net_dfa` - A reference to the DFA graph. +/// +/// # Returns +/// +/// A tuple containing: +/// 1. A vector of references to accepting DFAStateNodes. +/// 2. A BTreeSet of state IDs of the accepting states. +fn collect_accepting_states(net_dfa: &DFAGraph) -> (Vec<&DFAStateNode>, BTreeSet) { + let mut accepting_states = Vec::new(); + let mut substring_starts = BTreeSet::new(); + + for state in &net_dfa.states { + if state.state_type == "accept" { + accepting_states.push(state); + substring_starts.insert(state.state_id); + } + } + + (accepting_states, substring_starts) +} + +/// Collects all edges in the DFA graph. +/// +/// # Arguments +/// +/// * `graph` - A reference to the DFAGraph. +/// +/// # Returns +/// +/// A `BTreeSet` containing tuples of (from_state, to_state) representing all edges in the graph. +fn collect_public_edges(graph: &DFAGraph) -> BTreeSet<(usize, usize)> { + graph + .states + .iter() + .flat_map(|state| { + state + .transitions + .keys() + .map(move |&key| (state.state_id, key)) + }) + .collect() +} + +/// Collects the state IDs of all accepting states in the DFA graph. +/// +/// # Arguments +/// +/// * `graph` - A reference to the DFAGraph. +/// +/// # Returns +/// +/// A `BTreeSet` containing the state IDs of all accepting states. +fn collect_substr_ends(graph: &DFAGraph) -> BTreeSet { + graph + .states + .iter() + .filter(|state| state.state_type == "accept") + .map(|state| state.state_id) + .collect() +} + +/// Updates the public edges of a DFA graph when merging multiple DFAs. +/// +/// This function modifies the set of public edges by replacing edges connected +/// to the maximum state index with edges connected to accepting states. +/// +/// # Arguments +/// +/// * `public_edges` - A mutable reference to a BTreeSet of (from, to) state pairs representing public edges. +/// * `max_state_index` - The maximum state index in the current DFA before merging. +/// * `accepting_states` - A slice of references to DFAStateNode representing accepting states. +/// +/// # Notes +/// +/// This function assumes that `max_state_index` represents a boundary between +/// two DFAs being merged, and updates edges accordingly. +fn update_public_edges( + public_edges: &mut BTreeSet<(usize, usize)>, + max_state_index: usize, + accepting_states: &[&DFAStateNode], +) { + if max_state_index == 0 { + return; } - let mut net_dfa = net_dfa.clone(); - - let start_state = graph.states.iter().next().unwrap(); - - for state in &mut net_dfa.states { - if state.r#type == "accept" { - for (k, v) in &start_state.edges { - for edge_value in v { - for (_, v) in &mut state.edges { - if v.contains(edge_value) { - v.retain(|val| val != edge_value); - } - } + + let edges_to_update: Vec<_> = public_edges + .iter() + .filter(|&&(from, to)| (from == max_state_index || to == max_state_index)) + .cloned() + .collect(); + + for (from, to) in edges_to_update { + public_edges.remove(&(from, to)); + + if from == max_state_index && to == max_state_index { + for &accept_from in accepting_states { + for &accept_to in accepting_states { + public_edges.insert((accept_from.state_id, accept_to.state_id)); } - state.edges.insert(*k, v.clone()); } - state.r#type = "".to_string(); - if start_state.r#type == "accept" { - state.r#type = "accept".to_string(); + } else if from == max_state_index { + for &accept_state in accepting_states { + public_edges.insert((accept_state.state_id, to)); + } + } else if to == max_state_index { + for &accept_state in accepting_states { + public_edges.insert((from, accept_state.state_id)); } } } +} + +/// Processes a public regex part and updates the DFA graph accordingly. +/// +/// # Arguments +/// +/// * `regex` - A reference to the RegexPartConfig being processed. +/// * `net_dfa` - A reference to the cumulative DFAGraph built so far. +/// * `graph` - A reference to the DFAGraph for the current regex part. +/// * `previous_max_state_id` - The maximum state ID from the previous DFA. +/// +/// # Returns +/// +/// A tuple containing: +/// 1. A BTreeSet of public edges (as pairs of state IDs). +/// 2. A tuple of BTreeSets representing substring starts and ends. +fn process_public_regex( + regex: &RegexPartConfig, + net_dfa: &DFAGraph, + graph: &DFAGraph, + previous_max_state_id: usize, +) -> (BTreeSet<(usize, usize)>, (BTreeSet, BTreeSet)) { + if !regex.is_public { + return (BTreeSet::new(), (BTreeSet::new(), BTreeSet::new())); + } + + let (accepting_states, substring_starts) = collect_accepting_states(net_dfa); + let mut public_edges = collect_public_edges(graph); + let substring_ends = collect_substr_ends(graph); - for state in &graph.states { - if state.state != start_state.state { - net_dfa.states.push(state.clone()); + update_public_edges(&mut public_edges, previous_max_state_id, &accepting_states); + + (public_edges, (substring_starts, substring_ends)) +} + +/// Merges the edges from a source state into a target state, removing conflicting edges. +/// +/// # Arguments +/// +/// * `target_state` - A mutable reference to the DFAStateNode receiving the merged edges. +/// * `source_state` - A reference to the DFAStateNode providing the edges to be merged. +fn merge_edges(target_state: &mut DFAStateNode, source_state: &DFAStateNode) { + for (k, v) in &source_state.transitions { + for edge_value in v { + target_state.transitions.values_mut().for_each(|values| { + values.retain(|val| val != edge_value); + }); } + target_state.transitions.insert(*k, v.clone()); + } +} + +/// Updates the state type of a target state based on the source state. +/// +/// # Arguments +/// +/// * `target_state` - A mutable reference to the DFAStateNode being updated. +/// * `source_state` - A reference to the DFAStateNode providing the new state type. +fn update_state_type(target_state: &mut DFAStateNode, source_state: &DFAStateNode) { + target_state.state_type.clear(); + if source_state.state_type == "accept" { + target_state.state_type = "accept".to_string(); + } +} + +/// Processes an accept state by merging edges and updating its state type. +/// +/// # Arguments +/// +/// * `accept_state` - A mutable reference to the accepting DFAStateNode being processed. +/// * `start_state` - A reference to the start DFAStateNode of the graph being merged. +fn process_accept_state(accept_state: &mut DFAStateNode, start_state: &DFAStateNode) { + merge_edges(accept_state, start_state); + update_state_type(accept_state, start_state); +} + +/// Adds a new DFA graph to an existing net DFA graph. +/// +/// # Arguments +/// +/// * `net_dfa` - A reference to the existing DFAGraph. +/// * `graph` - A reference to the new DFAGraph being added. +/// +/// # Returns +/// +/// A new DFAGraph that combines the existing net DFA and the new graph. +/// +/// # Panics +/// +/// Panics if the new graph has no states. +fn add_dfa(net_dfa: &DFAGraph, graph: &DFAGraph) -> DFAGraph { + if net_dfa.states.is_empty() { + return graph.clone(); } - net_dfa + let mut new_dfa = net_dfa.clone(); + let start_state = graph.states.first().expect("Graph has no states"); + + new_dfa + .states + .iter_mut() + .filter(|state| state.state_type == "accept") + .for_each(|state| process_accept_state(state, start_state)); + + new_dfa.states.extend( + graph + .states + .iter() + .filter(|state| state.state_id != start_state.state_id) + .cloned(), + ); + + new_dfa } -pub fn regex_and_dfa( +/// Constructs a RegexAndDFA structure from a decomposed regex configuration. +/// +/// This function processes each part of the decomposed regex, builds individual DFAs, +/// and combines them into a single DFA graph. It also handles special cases like +/// caret (^) and end anchor ($) in the regex. +/// +/// # Arguments +/// +/// * `decomposed_regex` - A mutable reference to a DecomposedRegexConfig. +/// +/// # Returns +/// +/// A Result containing a RegexAndDFA structure if successful, or a CompilerError if an error occurs. +pub(crate) fn get_regex_and_dfa( decomposed_regex: &mut DecomposedRegexConfig, ) -> Result { - let mut config = DFA::config().minimize(true); - config = config.start_kind(StartKind::Anchored); - config = config.byte_classes(false); - config = config.accelerate(true); - - let mut net_dfa = DFAGraph { states: Vec::new() }; - let mut substr_endpoints_array = Vec::new(); - let mut substr_defs_array = Vec::new(); - - let caret_regex_index = { - let first_regex = decomposed_regex.parts[0].regex_def.as_bytes(); - let mut is_in_parenthesis = false; - let mut caret_found = false; - let mut idx = 0; - while idx < first_regex.len() { - let byte = first_regex[idx]; - if byte == b'\\' { - idx += 2; - } else if byte == b'(' { - is_in_parenthesis = true; - idx += 1; - } else if byte == b'[' { - idx += 2; - } else if byte == b')' { - debug_assert!(is_in_parenthesis, "Unmatched parenthesis"); - is_in_parenthesis = false; - idx += 1; - if caret_found { - break; - } - } else if byte == b'^' { - caret_found = true; - idx += 1; - if !is_in_parenthesis { - break; - } - } else { - idx += 1; - } - } + let mut net_dfa_graph = DFAGraph { states: Vec::new() }; + let mut substring_ranges_array = Vec::new(); + let mut substring_boundaries_array = Vec::new(); - if caret_found { - Some(idx) - } else { - None - } - }; - if let Some(index) = caret_regex_index { - let caret_regex = decomposed_regex.parts[0].regex_def[0..index].to_string(); - decomposed_regex.parts.push_front(RegexPartConfig { - is_public: false, - regex_def: caret_regex, - }); - decomposed_regex.parts[1].regex_def = - decomposed_regex.parts[1].regex_def[index..].to_string(); - } + let config = create_dfa_config(); + + let caret_position = process_caret_in_regex(decomposed_regex)?; let mut end_anchor = false; - for (idx, regex) in decomposed_regex.parts.iter().enumerate() { - end_anchor = match decomposed_regex.parts.len() { - 1 => regex.regex_def.ends_with("$"), - 2 => { - if idx == 0 && regex.regex_def.ends_with("$") { - return Err(CompilerError::GenericError( - "Invalid regex, $ can only be at the end of the regex".to_string(), - )); - } - idx == 1 && regex.regex_def.ends_with("$") - } - _ => match idx { - 0 | _ if idx == decomposed_regex.parts.len() - 1 => { - if regex.regex_def.ends_with("$") { - if idx == 0 { - return Err(CompilerError::GenericError( - "Invalid regex, $ can only be at the end of the regex".to_string(), - )); - } - true - } else { - false - } - } - _ => false, - }, - }; - let re = DFA::builder() - .configure(config.clone()) - .build(&format!(r"^({})$", regex.regex_def.as_str())); - if re.is_err() { - return Err(CompilerError::GenericError(format!( - "Failed to build DFA for regex: \"{}\", please check your regex", - regex.regex_def - ))); - } - let re_str = format!("{:?}", re.unwrap()); - // println!("{:?}", re_str); - let mut graph = dfa_to_graph(&parse_dfa_output(&re_str)); - if idx == 0 && caret_regex_index.is_some() { - if regex.regex_def.as_str() == "^" { - graph = DFAGraph { - states: vec![ - DFAState { - r#type: "".to_string(), - edges: BTreeMap::from([(1, BTreeSet::from([255u8]))]), - state: 0, - }, - DFAState { - r#type: "accept".to_string(), - edges: BTreeMap::new(), - state: 1, - }, - ], - }; - } else { - graph.states[0].r#type = "".to_string(); - let accepted_state = graph - .states - .iter() - .find(|state| state.r#type == "accept") - .unwrap() - .clone(); - if let Some(edge) = graph.states[0].edges.get_mut(&accepted_state.state) { - edge.insert(255u8); - } else { - graph.states[0] - .edges - .insert(accepted_state.state, BTreeSet::from([255u8])); - } - } - } - // println!("{:?}", graph); - // Find max state in net_dfa - let mut max_state_index = 0; - for state in net_dfa.states.iter() { - if state.state > max_state_index { - max_state_index = state.state; - } - } + for (i, regex) in decomposed_regex.parts.iter().enumerate() { + end_anchor = validate_end_anchor(decomposed_regex, i, regex)?; - graph = rename_states(&graph, max_state_index); + let dfa = DFA::builder() + .configure(config.clone()) + .build(&format!(r"^({})$", regex.regex_def.as_str())) + .map_err(|err| CompilerError::BuildError { + regex: regex.regex_def.clone(), + source: err, + })?; - if regex.is_public { - let mut substr_starts = BTreeSet::new(); - let mut substr_ends = BTreeSet::new(); - - let mut accepting_states = Vec::new(); - for state in &net_dfa.states { - if state.r#type == "accept" { - accepting_states.push(state); - substr_starts.insert(state.state); - } - } + let mut dfa_graph = convert_dfa_to_graph(dfa)?; - let mut public_edges = BTreeSet::new(); - for state in &graph.states { - for (key, _) in &state.edges { - public_edges.insert((state.state, *key)); - } - } + handle_caret_regex(i, caret_position, regex, &mut dfa_graph)?; - for state in &graph.states { - if state.r#type == "accept" { - substr_ends.insert(state.state); - } - } + let max_state_index = net_dfa_graph + .states + .iter() + .map(|state| state.state_id) + .max() + .unwrap_or(0); - if max_state_index != 0 { - for public_edge in &public_edges.clone() { - if public_edge.0 == max_state_index && public_edge.1 == max_state_index { - public_edges.remove(&(public_edge.0, public_edge.1)); - for accept_state in &accepting_states { - for accept_state_ in &accepting_states { - public_edges.insert((accept_state.state, accept_state_.state)); - } - } - } else if public_edge.0 == max_state_index { - public_edges.remove(&(public_edge.0, public_edge.1)); - for accept_state in &accepting_states { - public_edges.insert((accept_state.state, public_edge.1)); - } - } else if public_edge.1 == max_state_index { - public_edges.remove(&(public_edge.0, public_edge.1)); - for accept_state in &accepting_states { - public_edges.insert((public_edge.0, accept_state.state)); - } - } - } - } + dfa_graph = rename_states(&dfa_graph, max_state_index); - substr_defs_array.push(public_edges); - substr_endpoints_array.push((substr_starts, substr_ends)); + if regex.is_public { + let (public_edges, (substr_starts, substr_ends)) = + process_public_regex(regex, &net_dfa_graph, &dfa_graph, max_state_index); + substring_ranges_array.push(public_edges); + substring_boundaries_array.push((substr_starts, substr_ends)); } - net_dfa = add_dfa(&net_dfa, &graph); + net_dfa_graph = add_dfa(&net_dfa_graph, &dfa_graph); } - println!("{:?}", net_dfa); - let mut regex_str = String::new(); - for regex in decomposed_regex.parts.iter() { - regex_str += ®ex.regex_def; - } + let regex_str = decomposed_regex + .parts + .iter() + .map(|regex| regex.regex_def.as_str()) + .collect::(); Ok(RegexAndDFA { - regex_str: regex_str, - dfa_val: net_dfa, - end_anchor, - substrs_defs: SubstrsDefs { - substr_defs_array: substr_defs_array, - substr_endpoints_array: Some(substr_endpoints_array), + regex_pattern: regex_str, + dfa: net_dfa_graph, + has_end_anchor: end_anchor, + substrings: SubstringDefinitions { + substring_ranges: substring_ranges_array, + substring_boundaries: Some(substring_boundaries_array), }, }) } -pub fn dfa_from_regex_str(regex: &str) -> DFAGraph { - let mut config = DFA::config().minimize(true); - config = config.start_kind(StartKind::Anchored); - config = config.byte_classes(false); - config = config.accelerate(true); - let re = DFA::builder() +/// Creates a DFA graph from a regex string. +/// +/// # Arguments +/// +/// * `regex` - A string slice containing the regex pattern. +/// +/// # Returns +/// +/// A `Result` containing a `DFAGraph` or a `CompilerError`. +fn create_dfa_graph_from_regex(regex: &str) -> Result { + let config = DFA::config() + .minimize(true) + .start_kind(StartKind::Anchored) + .byte_classes(false) + .accelerate(true); + + let dfa = DFA::builder() .configure(config) .build(&format!(r"^{}$", regex)) - .unwrap(); - let re_str = format!("{:?}", re); - let graph = dfa_to_graph(&parse_dfa_output(&re_str)); - graph + .map_err(|e| CompilerError::BuildError { + regex: regex.to_string(), + source: e, + })?; + + convert_dfa_to_graph(dfa) +} + +/// Checks if a given string matches the regex pattern represented by the DFAGraph. +/// +/// # Arguments +/// +/// * `graph` - A reference to the DFAGraph obtained from the regex. +/// * `input` - The string to check against the regex pattern. +/// +/// # Returns +/// +/// A boolean indicating whether the input string matches the regex pattern. +fn match_string_with_dfa_graph(graph: &DFAGraph, input: &str) -> bool { + let mut current_state = 0; + + for &byte in input.as_bytes() { + let current_node = &graph.states[current_state]; + + let mut next_state = None; + for (&state, char_set) in ¤t_node.transitions { + if char_set.contains(&byte) { + next_state = Some(state); + break; + } + } + + match next_state { + Some(state) => { + current_state = state; + } + None => { + return false; + } // No valid transition found, input doesn't match + } + } + + // Check if the final state is an accepting state + graph.states[current_state].state_type == "accept" +} + +/// Creates a `RegexAndDFA` from a regex string and substring definitions. +/// +/// # Arguments +/// +/// * `regex_str` - A string slice containing the regex pattern. +/// * `substrs_defs_json` - A `SubstringDefinitionsJson` object. +/// +/// # Returns +/// +/// A `Result` containing a `RegexAndDFA` or a `CompilerError`. +pub(crate) fn create_regex_and_dfa_from_str_and_defs( + regex_str: &str, + substrs_defs_json: SubstringDefinitionsJson, +) -> Result { + let dfa = create_dfa_graph_from_regex(regex_str)?; + + let substring_ranges = substrs_defs_json + .transitions + .into_iter() + .map(|transitions| { + transitions + .into_iter() + .collect::>() + }) + .collect(); + + let substrings = SubstringDefinitions { + substring_ranges, + substring_boundaries: None, + }; + + Ok(RegexAndDFA { + regex_pattern: regex_str.to_string(), + dfa, + has_end_anchor: regex_str.ends_with('$'), + substrings, + }) +} + +/// Gets the index of the accepted state in a DFA graph. +/// +/// # Arguments +/// +/// * `dfa` - A reference to a `DFAGraph`. +/// +/// # Returns +/// +/// An `Option` containing the index of the accepted state, if found. +pub(crate) fn get_accepted_state(dfa: &DFAGraph) -> Option { + dfa.states + .iter() + .position(|state| state.state_type == "accept") +} + +/// Gets the maximum state ID in a DFA graph. +/// +/// # Arguments +/// +/// * `dfa` - A reference to a `DFAGraph`. +/// +/// # Returns +/// +/// A `usize` representing the maximum state ID. +pub(crate) fn get_max_state(dfa: &DFAGraph) -> usize { + dfa.states + .iter() + .map(|state| state.state_id) + .max() + .unwrap_or_default() +} + +mod dfa_test { + use crate::regex::{create_dfa_graph_from_regex, match_string_with_dfa_graph}; + use serde::{Deserialize, Serialize}; + use std::{env, fs::File, io::BufReader, path::PathBuf}; + + #[derive(Debug, Deserialize, Serialize)] + struct RegexTestCase { + pub regex: String, + pub pass: Vec, + pub fail: Vec, + } + + #[test] + fn test_dfa_graph() { + let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + path.push("src/dfa_tests.json"); + let file = File::open(path).expect("Failed to open test cases file"); + let reader = BufReader::new(file); + let test_cases: Vec = + serde_json::from_reader(reader).expect("Failed to parse JSON"); + + for case in test_cases { + let dfa_graph = match create_dfa_graph_from_regex(&case.regex) { + Ok(graph) => graph, + Err(e) => { + panic!( + "Failed to create DFA graph for regex '{}': {:?}", + case.regex, e + ); + } + }; + + for pass_case in case.pass { + assert!( + match_string_with_dfa_graph(&dfa_graph, &pass_case), + "Positive case failed for regex '{}': '{}'", + case.regex, + pass_case + ); + } + + for fail_case in case.fail { + assert!( + !match_string_with_dfa_graph(&dfa_graph, &fail_case), + "Negative case failed for regex '{}': '{}'", + case.regex, + fail_case + ); + } + } + } } diff --git a/packages/compiler/src/structs.rs b/packages/compiler/src/structs.rs new file mode 100644 index 0000000..2fcdc78 --- /dev/null +++ b/packages/compiler/src/structs.rs @@ -0,0 +1,56 @@ +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, BTreeSet, VecDeque}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegexPartConfig { + pub is_public: bool, + pub regex_def: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DecomposedRegexConfig { + pub parts: VecDeque, +} + +#[derive(Debug, Clone)] +pub struct DFAStateInfo { + pub typ: String, + pub source: usize, + pub edges: BTreeMap, +} + +#[derive(Debug)] +pub struct DFAGraphInfo { + pub states: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DFAStateNode { + pub state_type: String, + pub state_id: usize, + pub transitions: BTreeMap>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DFAGraph { + pub states: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SubstringDefinitions { + pub substring_ranges: Vec>, + pub substring_boundaries: Option, BTreeSet)>>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegexAndDFA { + pub regex_pattern: String, + pub dfa: DFAGraph, + pub has_end_anchor: bool, + pub substrings: SubstringDefinitions, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SubstringDefinitionsJson { + pub transitions: Vec>, +} diff --git a/packages/compiler/src/tests/mod.rs b/packages/compiler/src/tests/mod.rs deleted file mode 100644 index 0497cce..0000000 --- a/packages/compiler/src/tests/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod regex_to_dfa; \ No newline at end of file diff --git a/packages/compiler/src/tests/regex_to_dfa.rs b/packages/compiler/src/tests/regex_to_dfa.rs deleted file mode 100644 index b2b55ac..0000000 --- a/packages/compiler/src/tests/regex_to_dfa.rs +++ /dev/null @@ -1,94 +0,0 @@ -// use crate::js_caller::{regex_to_dfa, JsCallerError}; - -// #[cfg(test)] -// fn test_regex_to_dfa_case_1() { -// let regex = "[a-z]+"; -// let dfa = regex_to_dfa(regex).unwrap(); -// assert_eq!( -// serde_json::to_string_pretty(&dfa).unwrap(), -// r#"[ -// { -// "type": "", -// "edges": { -// "[\"a\",\"b\",\"c\",\"d\",\"e\",\"f\",\"g\",\"h\",\"i\",\"j\",\"k\",\"l\",\"m\",\"n\",\"o\",\"p\",\"q\",\"r\",\"s\",\"t\",\"u\",\"v\",\"w\",\"x\",\"y\",\"z\"]": 1 -// } -// }, -// { -// "type": "accept", -// "edges": { -// "[\"a\",\"b\",\"c\",\"d\",\"e\",\"f\",\"g\",\"h\",\"i\",\"j\",\"k\",\"l\",\"m\",\"n\",\"o\",\"p\",\"q\",\"r\",\"s\",\"t\",\"u\",\"v\",\"w\",\"x\",\"y\",\"z\"]": 1 -// } -// } -// ]"# -// ); -// } - -// #[test] -// fn test_regex_to_dfa_case_2() { -// let regex = "[a-z0-9]+"; -// let dfa = regex_to_dfa(regex).unwrap(); -// assert_eq!( -// serde_json::to_string_pretty(&dfa).unwrap(), -// r#"[ -// { -// "type": "", -// "edges": { -// "[\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\",\"8\",\"9\",\"a\",\"b\",\"c\",\"d\",\"e\",\"f\",\"g\",\"h\",\"i\",\"j\",\"k\",\"l\",\"m\",\"n\",\"o\",\"p\",\"q\",\"r\",\"s\",\"t\",\"u\",\"v\",\"w\",\"x\",\"y\",\"z\"]": 1 -// } -// }, -// { -// "type": "accept", -// "edges": { -// "[\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\",\"8\",\"9\",\"a\",\"b\",\"c\",\"d\",\"e\",\"f\",\"g\",\"h\",\"i\",\"j\",\"k\",\"l\",\"m\",\"n\",\"o\",\"p\",\"q\",\"r\",\"s\",\"t\",\"u\",\"v\",\"w\",\"x\",\"y\",\"z\"]": 1 -// } -// } -// ]"# -// ); -// } - -// #[test] -// fn test_regex_to_dfa_case_3() { -// let regex = "[a-z0-9]+@[a-z0-9]+\r\n"; -// let dfa = regex_to_dfa(regex).unwrap(); -// assert_eq!( -// serde_json::to_string_pretty(&dfa).unwrap(), -// r#"[ -// { -// "type": "", -// "edges": { -// "[\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\",\"8\",\"9\",\"a\",\"b\",\"c\",\"d\",\"e\",\"f\",\"g\",\"h\",\"i\",\"j\",\"k\",\"l\",\"m\",\"n\",\"o\",\"p\",\"q\",\"r\",\"s\",\"t\",\"u\",\"v\",\"w\",\"x\",\"y\",\"z\"]": 1 -// } -// }, -// { -// "type": "", -// "edges": { -// "[\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\",\"8\",\"9\",\"a\",\"b\",\"c\",\"d\",\"e\",\"f\",\"g\",\"h\",\"i\",\"j\",\"k\",\"l\",\"m\",\"n\",\"o\",\"p\",\"q\",\"r\",\"s\",\"t\",\"u\",\"v\",\"w\",\"x\",\"y\",\"z\"]": 1, -// "[\"@\"]": 2 -// } -// }, -// { -// "type": "", -// "edges": { -// "[\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\",\"8\",\"9\",\"a\",\"b\",\"c\",\"d\",\"e\",\"f\",\"g\",\"h\",\"i\",\"j\",\"k\",\"l\",\"m\",\"n\",\"o\",\"p\",\"q\",\"r\",\"s\",\"t\",\"u\",\"v\",\"w\",\"x\",\"y\",\"z\"]": 3 -// } -// }, -// { -// "type": "", -// "edges": { -// "[\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\",\"8\",\"9\",\"a\",\"b\",\"c\",\"d\",\"e\",\"f\",\"g\",\"h\",\"i\",\"j\",\"k\",\"l\",\"m\",\"n\",\"o\",\"p\",\"q\",\"r\",\"s\",\"t\",\"u\",\"v\",\"w\",\"x\",\"y\",\"z\"]": 3, -// "[\"\\r\"]": 4 -// } -// }, -// { -// "type": "", -// "edges": { -// "[\"\\n\"]": 5 -// } -// }, -// { -// "type": "accept", -// "edges": {} -// } -// ]"# -// ); -// } diff --git a/packages/compiler/src/wasm.rs b/packages/compiler/src/wasm.rs index dcc0634..742bbdc 100644 --- a/packages/compiler/src/wasm.rs +++ b/packages/compiler/src/wasm.rs @@ -2,47 +2,46 @@ use crate::*; use serde_wasm_bindgen::from_value; use wasm_bindgen::prelude::*; +use self::circom::gen_circom_string; + #[wasm_bindgen] #[allow(non_snake_case)] pub fn genFromDecomposed(decomposedRegexJson: &str, circomTemplateName: &str) -> String { let mut decomposed_regex_config: DecomposedRegexConfig = serde_json::from_str(decomposedRegexJson).expect("failed to parse decomposed_regex json"); - let regex_and_dfa = decomposed_regex_config - .to_regex_and_dfa() + let regex_and_dfa = get_regex_and_dfa(&mut decomposed_regex_config) .expect("failed to convert the decomposed regex to dfa"); - regex_and_dfa - .gen_circom_str(&circomTemplateName) - .expect("failed to generate circom") + gen_circom_string(®ex_and_dfa, circomTemplateName).expect("failed to generate circom") } #[wasm_bindgen] #[allow(non_snake_case)] pub fn genFromRaw(rawRegex: &str, substrsJson: &str, circomTemplateName: &str) -> String { - let substrs_defs_json: SubstrsDefsJson = + let substrs_defs_json: SubstringDefinitionsJson = serde_json::from_str(substrsJson).expect("failed to parse substrs json"); - let regex_and_dfa = RegexAndDFA::from_regex_str_and_substr_defs(rawRegex, substrs_defs_json) + let regex_and_dfa = create_regex_and_dfa_from_str_and_defs(rawRegex, substrs_defs_json) .expect("failed to convert the raw regex and state transitions to dfa"); - regex_and_dfa - .gen_circom_str(&circomTemplateName) - .expect("failed to generate circom") + gen_circom_string(®ex_and_dfa, circomTemplateName).expect("failed to generate circom") } #[wasm_bindgen] #[allow(non_snake_case)] pub fn genRegexAndDfa(decomposedRegex: JsValue) -> JsValue { - let mut decomposed_regex_config: DecomposedRegexConfig = from_value(decomposedRegex).unwrap(); - let regex_and_dfa = regex_and_dfa(&mut decomposed_regex_config).unwrap(); - let dfa_val_str = serde_json::to_string(®ex_and_dfa).unwrap(); + let mut decomposed_regex_config: DecomposedRegexConfig = + from_value(decomposedRegex).expect("failed to parse decomposed regex"); + let regex_and_dfa = get_regex_and_dfa(&mut decomposed_regex_config) + .expect("failed to convert the decomposed regex to dfa"); + let dfa_val_str = + serde_json::to_string(®ex_and_dfa).expect("failed to convert the dfa to json"); JsValue::from_str(&dfa_val_str) } #[wasm_bindgen] #[allow(non_snake_case)] pub fn genCircom(decomposedRegex: JsValue, circomTemplateName: &str) -> String { - let mut decomposed_regex_config: DecomposedRegexConfig = from_value(decomposedRegex).unwrap(); - let regex_and_dfa = regex_and_dfa(&mut decomposed_regex_config); - regex_and_dfa - .expect("failed to convert the decomposed regex to dfa") - .gen_circom_str(&circomTemplateName) - .expect("failed to generate circom") + let mut decomposed_regex_config: DecomposedRegexConfig = + from_value(decomposedRegex).expect("failed to parse decomposed regex"); + let regex_and_dfa = get_regex_and_dfa(&mut decomposed_regex_config) + .expect("failed to convert the decomposed regex to dfa"); + gen_circom_string(®ex_and_dfa, circomTemplateName).expect("failed to generate circom") }