abi_stable_derive/lifetimes/
lifetime_counters.rs1use super::LifetimeIndex;
2
3use std::fmt::{self, Debug};
4
/// A collection of small saturating counters, one per lifetime index.
///
/// Each counter occupies two bits, so four counters are packed into every byte
/// of the backing vector.
pub(crate) struct LifetimeCounters {
    /// Packed counter storage; the byte at position `i` holds the counters for
    /// lifetime indices `4 * i` through `4 * i + 3`, lowest index in the lowest bits.
    set: Vec<u8>,
}
9
/// Bit mask selecting one 2-bit counter after it has been shifted down to bit 0.
const MASK: u8 = 0b0000_0011;
/// The value at which a counter saturates; incrementing past it has no effect.
const MAX_VAL: u8 = 3;
12
13impl LifetimeCounters {
14 pub fn new() -> Self {
15 Self { set: Vec::new() }
16 }
17 pub fn increment(&mut self, lifetime: LifetimeIndex) -> u8 {
19 let (i, shift) = Self::get_index_shift(lifetime.bits);
20 if i >= self.set.len() {
21 self.set.resize(i + 1, 0);
22 }
23 let bits = &mut self.set[i];
24 let mask = MASK << shift;
25 if (*bits & mask) == mask {
26 MAX_VAL
27 } else {
28 *bits += 1 << shift;
29 (*bits >> shift) & MASK
30 }
31 }
32
33 pub fn get(&self, lifetime: LifetimeIndex) -> u8 {
34 let (i, shift) = Self::get_index_shift(lifetime.bits);
35 match self.set.get(i) {
36 Some(&bits) => (bits >> shift) & MASK,
37 None => 0,
38 }
39 }
40
41 fn get_index_shift(lt: usize) -> (usize, u8) {
42 (lt >> 2, ((lt & 3) << 1) as u8)
43 }
44}
45
46impl Debug for LifetimeCounters {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 f.debug_list()
49 .entries(self.set.iter().cloned().map(U8Wrapper))
50 .finish()
51 }
52}
53
/// Wrapper around a `u8` whose `Debug` output is the byte's binary representation.
#[repr(transparent)]
struct U8Wrapper(u8);

impl fmt::Debug for U8Wrapper {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to the `Binary` impl so that formatter flags (`#`, width, ...)
        // are honoured exactly as they would be for `{:b}`.
        let U8Wrapper(byte) = self;
        <u8 as fmt::Binary>::fmt(byte, f)
    }
}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Every counter starts at zero and reports each step up to three,
    /// regardless of which lifetime index it belongs to.
    #[test]
    fn test_counting() {
        let mut counters = LifetimeCounters::new();

        let lts = [
            LifetimeIndex::Param(0),
            LifetimeIndex::Param(1),
            LifetimeIndex::Param(2),
            LifetimeIndex::Param(3),
            LifetimeIndex::Param(4),
            LifetimeIndex::Param(5),
            LifetimeIndex::Param(6),
            LifetimeIndex::Param(7),
            LifetimeIndex::Param(8),
            LifetimeIndex::Param(9),
            LifetimeIndex::Param(999),
            LifetimeIndex::ANONYMOUS,
            LifetimeIndex::STATIC,
            LifetimeIndex::NONE,
        ];

        for &lt in lts.iter() {
            for expected in 1..=3 {
                let previous = expected - 1;
                assert_eq!(counters.get(lt), previous);
                assert_eq!(counters.increment(lt), expected);
                assert_eq!(counters.get(lt), expected);
            }
        }
    }
}