Skip to content

Commit

Permalink
ZeroAsciiIgnoreCaseTrie (#4549)
Browse files Browse the repository at this point in the history
  • Loading branch information
sffc authored Jan 25, 2024
1 parent 78130b7 commit b12ec9a
Show file tree
Hide file tree
Showing 11 changed files with 365 additions and 13 deletions.
73 changes: 69 additions & 4 deletions experimental/zerotrie/src/builder/nonconst/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use core::cmp::Ordering;

use super::super::branch_meta::BranchMeta;
use super::store::NonConstLengthsStack;
use super::store::TrieBuilderStore;
use crate::builder::bytestr::ByteStr;
use crate::byte_phf::PerfectByteHashMapCacheOwned;
use crate::error::Error;
use crate::varint;
use alloc::borrow::Cow;
use alloc::vec::Vec;

/// Whether to use the perfect hash function in the ZeroTrie.
Expand All @@ -35,10 +38,19 @@ pub enum CapacityMode {
Extended,
}

/// Whether to permit strings that have inconsistent ASCII case at a node, such as "abc" and "Abc".
///
/// With [`MixedCaseMode::Reject`], the builder orders keys by their ASCII-lowercased
/// bytes and fails with `Error::MixedCase` when it finds the same letter in both cases
/// at one node.
pub enum MixedCaseMode {
/// Allows strings regardless of case; keys are compared bytewise.
Allow,
/// Returns an error if a node exists with the same character in ambiguous case,
/// i.e. the same ASCII letter appearing in both upper and lower case.
Reject,
}

/// The set of options that configure how a ZeroTrie is built.
pub struct ZeroTrieBuilderOptions {
/// Whether branch nodes may be encoded with a perfect hash function.
pub phf_mode: PhfMode,
/// NOTE(review): presumably selects ASCII-only keys versus keys with binary spans;
/// its use is not visible here — confirm against `AsciiMode`.
pub ascii_mode: AsciiMode,
/// NOTE(review): presumably selects normal versus extended trie capacity;
/// its use is not visible here — confirm against `CapacityMode`.
pub capacity_mode: CapacityMode,
/// Whether keys that differ only in ASCII case at a node are allowed or rejected.
pub mixed_case_mode: MixedCaseMode,
}

/// A low-level builder for ZeroTrie. Supports all options.
Expand Down Expand Up @@ -129,13 +141,14 @@ impl<S: TrieBuilderStore> ZeroTrieBuilder<S> {
.iter()
.map(|(k, v)| (k.as_ref(), *v))
.collect::<Vec<(&[u8], usize)>>();
items.sort();
items.sort_by(|a, b| cmp_keys_values(&options, *a, *b));
let ascii_str_slice = items.as_slice();
let byte_str_slice = ByteStr::from_byte_slice_with_value(ascii_str_slice);
Self::from_sorted_tuple_slice(byte_str_slice, options)
}

/// Builds a ZeroTrie with the given items and options. Assumes that the items are sorted.
/// Builds a ZeroTrie with the given items and options. Assumes that the items are sorted,
/// except for a case-insensitive trie where the items are re-sorted.
///
/// # Panics
///
Expand All @@ -144,12 +157,27 @@ impl<S: TrieBuilderStore> ZeroTrieBuilder<S> {
items: &[(&ByteStr, usize)],
options: ZeroTrieBuilderOptions,
) -> Result<Self, Error> {
let mut items = Cow::Borrowed(items);
if matches!(options.mixed_case_mode, MixedCaseMode::Reject) {
// We need to re-sort the items with our custom comparator.
items.to_mut().sort_by(|a, b| {
cmp_keys_values(&options, (a.0.as_bytes(), a.1), (b.0.as_bytes(), b.1))
});
}
for ab in items.windows(2) {
debug_assert!(cmp_keys_values(
&options,
(ab[0].0.as_bytes(), ab[0].1),
(ab[1].0.as_bytes(), ab[1].1)
)
.is_lt());
}
let mut result = Self {
data: S::atbs_new_empty(),
phf_cache: PerfectByteHashMapCacheOwned::new_empty(),
options,
};
let total_size = result.create(items)?;
let total_size = result.create(&items)?;
debug_assert!(total_size == result.data.atbs_len());
Ok(result)
}
Expand Down Expand Up @@ -239,7 +267,15 @@ impl<S: TrieBuilderStore> ZeroTrieBuilder<S> {
if ascii_i == key_ascii && ascii_j == key_ascii {
let len = self.prepend_ascii(key_ascii)?;
current_len += len;
debug_assert!(i == new_i || i == new_i + 1);
if matches!(self.options.mixed_case_mode, MixedCaseMode::Reject) && i == new_i + 2 {
// This can happen if two strings were picked up, each with a different case
return Err(Error::MixedCase);
}
debug_assert!(
i == new_i || i == new_i + 1,
"only the exact prefix string can be picked up at this level: {}",
key_ascii
);
i = new_i;
debug_assert_eq!(j, new_j);
continue;
Expand Down Expand Up @@ -288,6 +324,20 @@ impl<S: TrieBuilderStore> ZeroTrieBuilder<S> {
};
let mut branch_metas = lengths_stack.pop_many_or_panic(total_count);
let original_keys = branch_metas.map_to_ascii_bytes();
if matches!(self.options.mixed_case_mode, MixedCaseMode::Reject) {
// Check to see if we have the same letter in two different cases
let mut seen_ascii_alpha = [false; 26];
for c in original_keys.as_const_slice().as_slice() {
if c.is_ascii_alphabetic() {
let i = (c.to_ascii_lowercase() - b'a') as usize;
if seen_ascii_alpha[i] {
return Err(Error::MixedCase);
} else {
seen_ascii_alpha[i] = true;
}
}
}
}
let use_phf = matches!(self.options.phf_mode, PhfMode::UsePhf);
let opt_phf_vec = if total_count > 15 && use_phf {
let phf_vec = self
Expand Down Expand Up @@ -379,3 +429,18 @@ impl<S: TrieBuilderStore> ZeroTrieBuilder<S> {
Ok(current_len)
}
}

/// Orders two `(key, value)` pairs the way the builder expects its items sorted.
///
/// Keys are compared bytewise when mixed case is allowed. Otherwise each byte is
/// ASCII-lowercased before comparison, so keys differing only in ASCII case
/// compare as equal. Pairs with equal keys are ordered by value.
fn cmp_keys_values(
    options: &ZeroTrieBuilderOptions,
    a: (&[u8], usize),
    b: (&[u8], usize),
) -> Ordering {
    let (lhs_key, lhs_value) = a;
    let (rhs_key, rhs_value) = b;
    let key_order = match options.mixed_case_mode {
        MixedCaseMode::Allow => lhs_key.cmp(rhs_key),
        MixedCaseMode::Reject => lhs_key
            .iter()
            .map(u8::to_ascii_lowercase)
            .cmp(rhs_key.iter().map(u8::to_ascii_lowercase)),
    };
    match key_order {
        // Same key (possibly modulo case): fall back to the value.
        Ordering::Equal => lhs_value.cmp(&rhs_value),
        decided => decided,
    }
}
12 changes: 12 additions & 0 deletions experimental/zerotrie/src/builder/nonconst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ impl<S: ?Sized> crate::ZeroTrieSimpleAscii<S> {
phf_mode: PhfMode::BinaryOnly,
ascii_mode: AsciiMode::AsciiOnly,
capacity_mode: CapacityMode::Normal,
mixed_case_mode: MixedCaseMode::Allow,
};
}

impl<S: ?Sized> crate::ZeroAsciiIgnoreCaseTrie<S> {
/// Builder options used for this trie type: binary-search-only branch nodes,
/// ASCII-only mode, normal capacity, and rejection of mixed-case keys.
pub(crate) const BUILDER_OPTIONS: ZeroTrieBuilderOptions = ZeroTrieBuilderOptions {
phf_mode: PhfMode::BinaryOnly,
ascii_mode: AsciiMode::AsciiOnly,
capacity_mode: CapacityMode::Normal,
// Keys that differ only in ASCII case must not coexist in a case-insensitive trie.
mixed_case_mode: MixedCaseMode::Reject,
};
}

Expand All @@ -21,6 +31,7 @@ impl<S: ?Sized> crate::ZeroTriePerfectHash<S> {
phf_mode: PhfMode::UsePhf,
ascii_mode: AsciiMode::BinarySpans,
capacity_mode: CapacityMode::Normal,
mixed_case_mode: MixedCaseMode::Allow,
};
}

Expand All @@ -29,5 +40,6 @@ impl<S: ?Sized> crate::ZeroTrieExtendedCapacity<S> {
phf_mode: PhfMode::UsePhf,
ascii_mode: AsciiMode::BinarySpans,
capacity_mode: CapacityMode::Extended,
mixed_case_mode: MixedCaseMode::Allow,
};
}
3 changes: 3 additions & 0 deletions experimental/zerotrie/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,7 @@ pub enum Error {
/// The builder could not solve the perfect hash function.
#[displaydoc("Failed to solve the perfect hash function. This is rare! Please report your case to the ICU4X team.")]
CouldNotSolvePerfectHash,
/// Mixed-case data was added to a case-insensitive trie.
#[displaydoc("Mixed-case data added to case-insensitive trie")]
MixedCase,
}
1 change: 1 addition & 0 deletions experimental/zerotrie/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ mod varint;
mod zerotrie;

pub use crate::cursor::ZeroTrieSimpleAsciiCursor;
pub use crate::zerotrie::ZeroAsciiIgnoreCaseTrie;
pub use crate::zerotrie::ZeroTrie;
pub use crate::zerotrie::ZeroTrieExtendedCapacity;
pub use crate::zerotrie::ZeroTriePerfectHash;
Expand Down
59 changes: 59 additions & 0 deletions experimental/zerotrie/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,65 @@ pub fn get_ascii_bsearch_only(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize
}
}

/// Looks up `ascii` in `trie`, matching bytes without regard to ASCII case.
///
/// Every branch node is searched with binary search. Returns the stored value
/// when the query is consumed exactly at a value node, and `None` when the trie
/// runs out, a byte does not match, or a span node is encountered.
pub fn get_ascii_bsearch_only_ignore_case(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
    loop {
        let (lead, after_lead) = trie.split_first()?;
        let node_type = byte_type(*lead);
        // Decode the node's metadata and step past its header bytes.
        let (meta, after_header) = match node_type {
            NodeType::Ascii => (0, after_lead),
            NodeType::Span => {
                debug_assert!(false, "Span node found in ASCII trie!");
                return None;
            }
            NodeType::Value => read_varint_meta3(*lead, after_lead),
            NodeType::Branch => read_varint_meta2(*lead, after_lead),
        };
        trie = after_header;

        let (query_byte, query_rest) = match ascii.split_first() {
            Some(parts) => parts,
            None => {
                // The query is exhausted: only a value node here is a hit.
                return match node_type {
                    NodeType::Value => Some(meta),
                    _ => None,
                };
            }
        };

        match node_type {
            NodeType::Ascii => {
                if lead.to_ascii_lowercase() != query_byte.to_ascii_lowercase() {
                    // Byte that doesn't match
                    return None;
                }
                // Matched a byte
                ascii = query_rest;
            }
            NodeType::Value => {
                // Value node in the middle of the query: step over it.
            }
            _ => {
                // Branch node: the low 8 bits hold the key count, the bits above hold the width.
                let (count, width) = if meta >= 256 {
                    (meta & 0xff, meta >> 8)
                } else {
                    (meta, 0)
                };
                // The width is expected to fit in two bits; it is masked below regardless.
                debug_assert!(width <= 3, "get: w > 3 but we assume w <= 3");
                let width = width & 0x3;
                // A stored count of zero stands for 256 keys.
                let count = if count == 0 { 256 } else { count };
                // Always use binary search, comparing lowercased bytes.
                let (keys, targets) = trie.debug_split_at(count);
                let index = keys
                    .binary_search_by_key(&query_byte.to_ascii_lowercase(), |key| {
                        key.to_ascii_lowercase()
                    })
                    .ok()?;
                trie = if width == 0 {
                    get_branch_w0(targets, index, count)
                } else {
                    get_branch(targets, index, count, width)
                };
                ascii = query_rest;
            }
        }
    }
}

/// Query the trie assuming branch nodes could be either binary search or PHF.
pub fn get_phf_limited(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
loop {
Expand Down
46 changes: 46 additions & 0 deletions experimental/zerotrie/src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use crate::builder::bytestr::ByteStr;
use crate::zerotrie::ZeroTrieFlavor;
use crate::ZeroAsciiIgnoreCaseTrie;
use crate::ZeroTrie;
use crate::ZeroTrieExtendedCapacity;
use crate::ZeroTriePerfectHash;
Expand Down Expand Up @@ -136,6 +137,51 @@ where
}
}

/// Deserializes a case-insensitive trie from a key/value map in human-readable
/// formats, or from the raw trie bytes otherwise.
impl<'de, 'data, Store> Deserialize<'de> for ZeroAsciiIgnoreCaseTrie<Store>
where
    'de: 'data,
    // DISCUSS: Several bounds would work here. Requiring Deserialize would need a
    // custom Deserializer for the map case; a dedicated trait instead of From
    // was also considered.
    Store: From<&'data [u8]> + From<Vec<u8>> + 'data,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        if !deserializer.is_human_readable() {
            // Note: `impl Deserialize for &[u8]` uses visit_borrowed_bytes
            let bytes = <&[u8]>::deserialize(deserializer)?;
            return Ok(ZeroAsciiIgnoreCaseTrie::from_store(bytes).convert_store());
        }
        // Human-readable formats carry the trie as a map of keys to values.
        let lm = LiteMap::<Box<ByteStr>, usize>::deserialize(deserializer)?;
        match ZeroAsciiIgnoreCaseTrie::try_from_serde_litemap(&lm) {
            Ok(trie) => Ok(trie.convert_store()),
            Err(build_error) => Err(D::Error::custom(build_error)),
        }
    }
}

/// Serializes the trie as a key/value map in human-readable formats, or as its
/// raw bytes otherwise.
impl<Store> Serialize for ZeroAsciiIgnoreCaseTrie<Store>
where
    Store: AsRef<[u8]>,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if !serializer.is_human_readable() {
            // Binary formats get the trie bytes as they are.
            return self.as_bytes().serialize(serializer);
        }
        self.to_litemap().serialize(serializer)
    }
}

impl<'de, 'data, Store> Deserialize<'de> for ZeroTriePerfectHash<Store>
where
'de: 'data,
Expand Down
Loading

0 comments on commit b12ec9a

Please sign in to comment.