zerotrie/byte_phf/
mod.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5#![allow(rustdoc::private_intra_doc_links)] // doc(hidden) module
6
7//! # Byte Perfect Hash Function Internals
8//!
9//! This module contains a perfect hash function (PHF) designed for a fast, compact perfect
10//! hash over 1 to 256 nodes (bytes).
11//!
12//! The PHF uses the following variables:
13//!
//! 1. A single parameter `p`, which is 0 in about 98% of cases.
//! 2. A list of `N` parameters `q_t`, one per _bucket_ `t`.
//! 3. The `N` keys `k_i`, in an arbitrary order determined by the PHF.
17//!
18//! Reading a `key` from the PHF uses the following algorithm:
19//!
//! 1. Let `t`, the bucket index, be `f1(key, p, N)`.
//! 2. Let `i`, the key index, be `f2(key, q_t, N)`.
//! 3. If `key == k_i`, return `Some(i)`; else return `None`.
23//!
24//! The functions [`f1`] and [`f2`] are internal to the PHF but should remain stable across
25//! serialization versions of `ZeroTrie`. They are very fast, constant-time operations as long
26//! as `p` <= [`P_FAST_MAX`] and `q` <= [`Q_FAST_MAX`]. In practice, nearly 100% of parameter
27//! values are in the fast range.
28//!
29//! ```
30//! use zerotrie::_internal::PerfectByteHashMap;
31//!
32//! let phf_example_bytes = [
//!     // `p` parameter
//!     1,
//!     // `q` parameters, one for each of the N buckets
//!     0, 0, 1, 1,
//!     // Exact keys to be compared with the input
//!     b'e', b'a', b'c', b'g',
37//! ];
38//!
39//! let phf = PerfectByteHashMap::from_bytes(&phf_example_bytes);
40//!
41//! // The PHF returns the index of the key or `None` if not found.
42//! assert_eq!(phf.get(b'a'), Some(1));
43//! assert_eq!(phf.get(b'b'), None);
44//! assert_eq!(phf.get(b'c'), Some(2));
45//! assert_eq!(phf.get(b'd'), None);
46//! assert_eq!(phf.get(b'e'), Some(0));
47//! assert_eq!(phf.get(b'f'), None);
48//! assert_eq!(phf.get(b'g'), Some(3));
49//! ```
50
51use crate::helpers::*;
52
53#[cfg(feature = "alloc")]
54mod builder;
55#[cfg(feature = "alloc")]
56mod cached_owned;
57
58#[cfg(feature = "alloc")]
59pub use cached_owned::PerfectByteHashMapCacheOwned;
60
// Limits on the `p` parameter, which is consumed by `f1`.

/// The largest `p` considered to be on the fast path of [`f1`].
///
/// [`f1`] currently has a single constant-time formula, so every `p` it is given is fast.
#[cfg(feature = "alloc")] // used in the builder code
const P_FAST_MAX: u8 = 95;

/// The maximum allowable value of `p`; it could be raised if that is ever found necessary.
///
/// It is currently pinned to [`P_FAST_MAX`]. Allowing larger values would likely go hand in
/// hand with a different `p` formula in [`f1`].
#[cfg(feature = "alloc")] // used in the builder code
const P_REAL_MAX: u8 = P_FAST_MAX;

// Limits on the `q` parameter, which is consumed by `f2`.

/// The largest `q` for which [`f2`] runs in constant time.
///
/// For `q` above this value, [`f2`] applies `q - Q_FAST_MAX` extra mixing rounds.
const Q_FAST_MAX: u8 = 95;

/// The maximum allowable value of `q`; it could be raised if that is ever found necessary.
///
/// With the current values this caps [`f2`] at `127 - 95 = 32` extra mixing rounds.
#[cfg(feature = "alloc")] // used in the builder code
const Q_REAL_MAX: u8 = 127;
76
/// Computes the bucket-selection function `f1` of the PHF. For the exact formula, please read
/// the code.
///
/// When `p == 0`, this is simply `byte % n`. Otherwise the byte is first mixed as
/// `byte ^ p ^ (byte >> (p % 8))`, because `u8::wrapping_shr` masks the shift amount to the
/// type's bit width. The mixed value is then reduced modulo `n`.
///
/// The result lies in `[0, n)`, except when `n == 0`: then `byte` is returned unchanged,
/// regardless of `p`. This presumably lets a 256-entry map, whose size truncates to `0` as a
/// `u8`, map every byte to itself.
///
/// # Examples
///
/// ```
/// use zerotrie::_internal::f1;
/// const N: u8 = 10;
///
/// // With p = 0:
/// assert_eq!(0, f1(0, 0, N));
/// assert_eq!(1, f1(1, 0, N));
/// assert_eq!(2, f1(2, 0, N));
/// assert_eq!(9, f1(9, 0, N));
/// assert_eq!(0, f1(10, 0, N));
/// assert_eq!(1, f1(11, 0, N));
/// assert_eq!(2, f1(12, 0, N));
/// assert_eq!(9, f1(19, 0, N));
///
/// // With p = 1:
/// assert_eq!(1, f1(0, 1, N));
/// assert_eq!(0, f1(1, 1, N));
/// assert_eq!(2, f1(2, 1, N));
/// assert_eq!(2, f1(9, 1, N));
/// assert_eq!(4, f1(10, 1, N));
/// assert_eq!(5, f1(11, 1, N));
/// assert_eq!(1, f1(12, 1, N));
/// assert_eq!(7, f1(19, 1, N));
/// ```
#[inline]
pub fn f1(byte: u8, p: u8, n: u8) -> u8 {
    match (n, p) {
        // No modulus to take: hand the byte back unchanged.
        (0, _) => byte,
        // The plain modulus; per the module docs this is by far the most common case.
        (_, 0) => byte % n,
        // Every non-zero `p` uses this constant-time mixing step. A different formula
        // could be introduced for `p > P_FAST_MAX` if hard cases ever require it.
        _ => {
            let mixed = byte ^ p ^ byte.wrapping_shr(u32::from(p));
            mixed % n
        }
    }
}
124
125/// Calculates the function `f2` for the PHF. For the exact formula, please read the code.
126///
127/// When `q == 0`, the operation is a simple modulus.
128///
129/// The argument `n` is used only for taking the modulus so that the return value is
130/// in the range `[0, n)`.
131///
132/// # Examples
133///
134/// ```
135/// use zerotrie::_internal::f2;
136/// const N: u8 = 10;
137///
138/// // With q = 0:
139/// assert_eq!(0, f2(0, 0, N));
140/// assert_eq!(1, f2(1, 0, N));
141/// assert_eq!(2, f2(2, 0, N));
142/// assert_eq!(9, f2(9, 0, N));
143/// assert_eq!(0, f2(10, 0, N));
144/// assert_eq!(1, f2(11, 0, N));
145/// assert_eq!(2, f2(12, 0, N));
146/// assert_eq!(9, f2(19, 0, N));
147///
148/// // With q = 1:
149/// assert_eq!(1, f2(0, 1, N));
150/// assert_eq!(0, f2(1, 1, N));
151/// assert_eq!(3, f2(2, 1, N));
152/// assert_eq!(8, f2(9, 1, N));
153/// assert_eq!(1, f2(10, 1, N));
154/// assert_eq!(0, f2(11, 1, N));
155/// assert_eq!(3, f2(12, 1, N));
156/// assert_eq!(8, f2(19, 1, N));
157/// ```
158#[inline]
159pub fn f2(byte: u8, q: u8, n: u8) -> u8 {
160    if n == 0 {
161        return byte;
162    }
163    let mut result = byte ^ q;
164    // In almost all cases, the PHF works with the above constant-time operation.
165    // However, to crack a few difficult cases, we fall back to the linear-time
166    // operation shown below.
167    for _ in Q_FAST_MAX..q {
168        result = result ^ (result << 1) ^ (result >> 1);
169    }
170    result % n
171}
172
/// A read-only, constant-time map from bytes to unique indices.
///
/// Lookups go through a perfect hash function (see the module-level documentation).
///
/// Byte layout of the store: the parameter `p`, then `N` bytes of `q` parameters (one per
/// bucket), then the `N` expected keys.
#[repr(transparent)]
#[derive(Debug, PartialEq, Eq)]
pub struct PerfectByteHashMap<Store: ?Sized>(Store);

impl<Store> PerfectByteHashMap<Store> {
    /// Wraps a pre-existing store as-is, without inspecting its contents.
    /// See [`Self::as_bytes`] for the reverse direction.
    #[inline]
    pub fn from_store(store: Store) -> Self {
        PerfectByteHashMap(store)
    }
}
189
190impl<Store> PerfectByteHashMap<Store>
191where
192    Store: AsRef<[u8]> + ?Sized,
193{
194    /// Gets the usize for the given byte, or `None` if it is not in the map.
195    pub fn get(&self, key: u8) -> Option<usize> {
196        let (p, buffer) = self.0.as_ref().split_first()?;
197        // Note: there are N buckets followed by N keys
198        let n_usize = buffer.len() / 2;
199        if n_usize == 0 {
200            return None;
201        }
202        let n = n_usize as u8;
203        let (qq, eks) = buffer.debug_split_at(n_usize);
204        debug_assert_eq!(qq.len(), eks.len());
205        let l1 = f1(key, *p, n) as usize;
206        let q = debug_unwrap!(qq.get(l1), return None);
207        let l2 = f2(key, *q, n) as usize;
208        let ek = debug_unwrap!(eks.get(l2), return None);
209        if *ek == key {
210            Some(l2)
211        } else {
212            None
213        }
214    }
215    /// This is called `num_items` because `len` is ambiguous: it could refer
216    /// to the number of items or the number of bytes.
217    pub fn num_items(&self) -> usize {
218        self.0.as_ref().len() / 2
219    }
220    /// Get an iterator over the keys in the order in which they are stored in the map.
221    pub fn keys(&self) -> &[u8] {
222        let n = self.num_items();
223        self.0.as_ref().debug_split_at(1 + n).1
224    }
225    /// Diagnostic function that returns `p` and the maximum value of `q`
226    #[cfg(test)]
227    pub fn p_qmax(&self) -> Option<(u8, u8)> {
228        let (p, buffer) = self.0.as_ref().split_first()?;
229        let n = buffer.len() / 2;
230        if n == 0 {
231            return None;
232        }
233        let (qq, _) = buffer.debug_split_at(n);
234        Some((*p, *qq.iter().max().unwrap()))
235    }
236    /// Returns the map as bytes. The map can be recovered with [`Self::from_store`]
237    /// or [`Self::from_bytes`].
238    pub fn as_bytes(&self) -> &[u8] {
239        self.0.as_ref()
240    }
241
242    #[cfg(all(feature = "alloc", test))]
243    pub(crate) fn check(&self) -> Result<(), (&'static str, u8)> {
244        use alloc::vec;
245        let len = self.num_items();
246        let mut seen = vec![false; len];
247        for b in 0..=255u8 {
248            let get_result = self.get(b);
249            if self.keys().contains(&b) {
250                let i = get_result.ok_or(("expected to find", b))?;
251                if seen[i] {
252                    return Err(("seen", b));
253                }
254                seen[i] = true;
255            } else if get_result.is_some() {
256                return Err(("did not expect to find", b));
257            }
258        }
259        Ok(())
260    }
261}
262
263impl PerfectByteHashMap<[u8]> {
264    /// Creates an instance from pre-existing bytes. See [`Self::as_bytes`].
265    #[inline]
266    pub fn from_bytes(bytes: &[u8]) -> &Self {
267        // Safety: Self is repr(transparent) over [u8]
268        unsafe { core::mem::transmute(bytes) }
269    }
270}
271
272impl<Store> PerfectByteHashMap<Store>
273where
274    Store: AsRef<[u8]> + ?Sized,
275{
276    /// Converts from `PerfectByteHashMap<AsRef<[u8]>>` to `&PerfectByteHashMap<[u8]>`
277    #[inline]
278    pub fn as_borrowed(&self) -> &PerfectByteHashMap<[u8]> {
279        PerfectByteHashMap::from_bytes(self.0.as_ref())
280    }
281}
282
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    extern crate std;

    /// Returns `len` distinct ASCII alphanumeric bytes, chosen and ordered
    /// deterministically from `seed`.
    fn random_alphanums(seed: u64, len: usize) -> Vec<u8> {
        use rand::seq::SliceRandom;
        use rand::SeedableRng;

        let mut bytes: Vec<u8> =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789".into();
        let mut rng = rand_pcg::Lcg64Xsh32::seed_from_u64(seed);
        bytes.partial_shuffle(&mut rng, len).0.into()
    }

    /// Builds and validates a PHF for every (length, seed) pair drawn from `lens` and
    /// `seeds`. Prints the distributions of `p` and of the largest `q`, then asserts that
    /// at least 99% of the maps only use the fast path of [`f2`] (`q <= Q_FAST_MAX`).
    fn assert_mostly_fast(
        label: &str,
        lens: core::ops::Range<usize>,
        seeds: core::ops::Range<u64>,
    ) {
        let mut count_by_p = [0usize; 256];
        let mut count_by_qmax = [0usize; 256];
        for len in lens {
            for seed in seeds.clone() {
                let keys = random_alphanums(seed, len);
                let keys_str = core::str::from_utf8(&keys).unwrap();
                let computed = PerfectByteHashMap::try_new(&keys).expect(keys_str);
                // Report both the failing key set and the reason, instead of dropping the error.
                computed
                    .check()
                    .unwrap_or_else(|err| panic!("{}: {:?}", keys_str, err));
                let (p, qmax) = computed.p_qmax().unwrap();
                count_by_p[usize::from(p)] += 1;
                count_by_qmax[usize::from(qmax)] += 1;
            }
        }
        std::println!("count_by_p ({label}): {count_by_p:?}");
        std::println!("count_by_qmax ({label}): {count_by_qmax:?}");
        let (fast, slow) = count_by_qmax.split_at(usize::from(Q_FAST_MAX) + 1);
        let count_fastq: usize = fast.iter().sum();
        let count_slowq: usize = slow.iter().sum();
        std::println!("fastq/slowq: {count_fastq}/{count_slowq}");
        // Assert that 99% of cases resolve to the fast hash
        assert!(
            count_fastq >= count_slowq * 100,
            "too many slow-q maps ({}): {}/{}",
            label,
            count_fastq,
            count_slowq
        );
    }

    #[test]
    fn test_smaller() {
        assert_mostly_fast("smaller", 1..16, 0..150);
    }

    #[test]
    fn test_larger() {
        assert_mostly_fast("larger", 16..60, 0..75);
    }

    #[test]
    fn test_hard_cases() {
        // Every byte in 0..=126 plus two high bytes: a dense key set that needs a non-zero `p`.
        let keys: Vec<u8> = (0u8..=126).chain([195, 196]).collect();

        let computed = PerfectByteHashMap::try_new(&keys).unwrap();
        let (p, qmax) = computed.p_qmax().unwrap();
        assert_eq!(p, 69);
        assert_eq!(qmax, 67);
    }

    #[test]
    fn test_build_read_small() {
        #[derive(Debug)]
        struct TestCase<'a> {
            keys: &'a str,
            expected: &'a [u8],
            reordered_keys: &'a str,
        }
        let cases = [
            TestCase {
                keys: "ab",
                expected: &[0, 0, 0, b'b', b'a'],
                reordered_keys: "ba",
            },
            TestCase {
                keys: "abc",
                expected: &[0, 0, 0, 0, b'c', b'a', b'b'],
                reordered_keys: "cab",
            },
            TestCase {
                // Note: splitting "a" and "c" into different buckets requires the heavier hash
                // function because the difference between "a" and "c" is the period (2).
                keys: "ac",
                expected: &[1, 0, 1, b'c', b'a'],
                reordered_keys: "ca",
            },
            TestCase {
                keys: "aceg",
                expected: &[1, 0, 0, 1, 1, b'e', b'a', b'c', b'g'],
                reordered_keys: "eacg",
            },
            TestCase {
                keys: "abd",
                expected: &[0, 0, 1, 3, b'a', b'b', b'd'],
                reordered_keys: "abd",
            },
            TestCase {
                keys: "def",
                expected: &[0, 0, 0, 0, b'f', b'd', b'e'],
                reordered_keys: "fde",
            },
            TestCase {
                keys: "fi",
                expected: &[0, 0, 0, b'f', b'i'],
                reordered_keys: "fi",
            },
            TestCase {
                keys: "gh",
                expected: &[0, 0, 0, b'h', b'g'],
                reordered_keys: "hg",
            },
            TestCase {
                keys: "lm",
                expected: &[0, 0, 0, b'l', b'm'],
                reordered_keys: "lm",
            },
            TestCase {
                // Note: "a" and "q" (0x61 and 0x71) are very hard to split: `p` values 0
                // through 3 all put them in the same bucket, so the smallest working `p` is 4.
                keys: "aq",
                expected: &[4, 0, 1, b'a', b'q'],
                reordered_keys: "aq",
            },
            TestCase {
                keys: "xy",
                expected: &[0, 0, 0, b'x', b'y'],
                reordered_keys: "xy",
            },
            TestCase {
                keys: "xyz",
                expected: &[0, 0, 0, 0, b'x', b'y', b'z'],
                reordered_keys: "xyz",
            },
            TestCase {
                keys: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
                expected: &[
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 10, 12, 16, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 16,
                    16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
                    2, 0, 7, 104, 105, 106, 107, 108, 109, 110, 111, 112, 117, 118, 119, 68, 69,
                    70, 113, 114, 65, 66, 67, 120, 121, 122, 115, 72, 73, 74, 71, 80, 81, 82, 83,
                    84, 85, 86, 87, 88, 89, 90, 75, 76, 77, 78, 79, 103, 97, 98, 99, 116, 100, 102,
                    101,
                ],
                reordered_keys: "hijklmnopuvwDEFqrABCxyzsHIJGPQRSTUVWXYZKLMNOgabctdfe",
            },
            TestCase {
                keys: "abcdefghij",
                expected: &[
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 100, 101, 102, 103, 104, 105, 106, 97, 98, 99,
                ],
                reordered_keys: "defghijabc",
            },
            TestCase {
                // This is a small case that resolves to the slow hasher (q > Q_FAST_MAX)
                keys: "Jbej",
                expected: &[2, 0, 0, 102, 0, b'j', b'e', b'b', b'J'],
                reordered_keys: "jebJ",
            },
            TestCase {
                // This is another small case that resolves to the slow hasher (q > Q_FAST_MAX)
                keys: "JFNv",
                expected: &[1, 98, 0, 2, 0, b'J', b'F', b'N', b'v'],
                reordered_keys: "JFNv",
            },
        ];
        for cas in cases {
            let computed = PerfectByteHashMap::try_new(cas.keys.as_bytes()).expect(cas.keys);
            assert_eq!(computed.as_bytes(), cas.expected, "{:?}", cas);
            assert_eq!(computed.keys(), cas.reordered_keys.as_bytes(), "{:?}", cas);
            computed.check().expect(cas.keys);
        }
    }
}