1#![allow(clippy::missing_const_for_fn)]
2
3use std::{
4 borrow::Borrow,
5 cmp::{Eq, PartialEq},
6 collections::hash_map::{Entry, HashMap},
7 fmt::{self, Debug},
8 hash::Hash,
9 mem,
10 ops::{Index, IndexMut},
11 ptr,
12};
13
14use generational_arena::{Arena, Index as ArenaIndex};
15
16use core_extensions::SelfOps;
17
18#[derive(Clone)]
23pub struct MultiKeyMap<K, T> {
24 map: HashMap<K, MapIndex>,
25 arena: Arena<MapValue<K, T>>,
26}
27
/// An arena entry: a value plus every key currently associated with it.
#[derive(Clone, Debug, Eq, PartialEq)]
struct MapValue<K, T> {
    /// Keys that map to this entry.
    /// The map's methods keep this non-empty: an entry is created with one key,
    /// and is removed when its last key is taken away.
    keys: Vec<K>,
    /// The value shared by all of `keys`.
    value: T,
}
33
34#[repr(transparent)]
35#[derive(Debug, Copy, Clone, PartialEq, Eq)]
36pub struct MapIndex {
37 index: ArenaIndex,
38}
39
40#[derive(Debug, Copy, Clone, PartialEq, Eq)]
41pub struct IndexValue<T> {
42 pub index: MapIndex,
43 pub value: T,
44}
45
/// Whether a value was inserted by the current call, or was already in the map.
///
/// The variant order is significant for the derived `PartialOrd`/`Ord`:
/// any `Now(_)` compares less than any `Before(_)`.
#[must_use = "call `.into_inner()` to unwrap into the inner value."]
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum InsertionTime<T> {
    /// The value was inserted by this call.
    Now(T),
    /// The key was already present before this call.
    Before(T),
}
53
54impl<K, T> MultiKeyMap<K, T>
55where
56 K: Hash + Eq,
57{
58 #[allow(clippy::new_without_default)]
59 pub fn new() -> Self {
60 Self {
61 map: HashMap::default(),
62 arena: Arena::new(),
63 }
64 }
65
66 pub fn get<Q>(&self, key: &Q) -> Option<&T>
67 where
68 K: Borrow<Q>,
69 Q: Hash + Eq + ?Sized,
70 {
71 let &i = self.map.get(key)?;
72 self.get_with_index(i)
73 }
74
75 pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut T>
76 where
77 K: Borrow<Q>,
78 Q: Hash + Eq + ?Sized,
79 {
80 let &i = self.map.get(key)?;
81 self.get_mut_with_index(i)
82 }
83
84 #[allow(dead_code)]
85 pub fn get2_mut<Q>(&mut self, key0: &Q, key1: &Q) -> (Option<&mut T>, Option<&mut T>)
86 where
87 K: Borrow<Q>,
88 Q: Hash + Eq + ?Sized,
89 {
90 let i0 = self.map.get(key0).cloned();
91 let i1 = self.map.get(key1).cloned();
92
93 match (i0, i1) {
94 (None, None) => (None, None),
95 (Some(l), None) => (self.get_mut_with_index(l), None),
96 (None, Some(r)) => (None, self.get_mut_with_index(r)),
97 (Some(l), Some(r)) => self.get2_mut_with_index(l, r),
98 }
99 }
100
101 pub fn get_index<Q>(&self, key: &Q) -> Option<MapIndex>
102 where
103 K: Borrow<Q>,
104 Q: Hash + Eq + ?Sized,
105 {
106 self.map.get(key).cloned()
107 }
108
109 pub fn get_with_index(&self, i: MapIndex) -> Option<&T> {
110 self.arena.get(i.index).map(|x| &x.value)
111 }
112
113 pub fn get_mut_with_index(&mut self, i: MapIndex) -> Option<&mut T> {
114 self.arena.get_mut(i.index).map(|x| &mut x.value)
115 }
116
117 pub fn get2_mut_with_index(
118 &mut self,
119 i0: MapIndex,
120 i1: MapIndex,
121 ) -> (Option<&mut T>, Option<&mut T>) {
122 let (l, r) = self.arena.get2_mut(i0.index, i1.index);
123 fn mapper<K, T>(x: &mut MapValue<K, T>) -> &mut T {
124 &mut x.value
125 }
126 (l.map(mapper), r.map(mapper))
127 }
128
129 #[allow(dead_code)]
130 pub fn replace_index(&mut self, replace: MapIndex, with: T) -> Option<T> {
131 self.get_mut_with_index(replace)
132 .map(|x| mem::replace(x, with))
133 }
134
135 #[allow(dead_code)]
136 pub fn key_len(&self) -> usize {
138 self.map.len()
139 }
140
141 #[allow(dead_code)]
142 pub fn value_len(&self) -> usize {
144 self.arena.len()
145 }
146
147 pub fn replace_with_index(&mut self, replace: MapIndex, with: MapIndex) -> Option<T> {
162 if replace == with
163 || !self.arena.contains(replace.index)
164 || !self.arena.contains(with.index)
165 {
166 return None;
167 }
168 let with_ = self.arena.remove(with.index)?;
169 let replaced = self.arena.get_mut(replace.index)?;
170 for key in &with_.keys {
171 *self.map.get_mut(key).unwrap() = replace;
172 }
173 replaced.keys.extend(with_.keys);
174 Some(mem::replace(&mut replaced.value, with_.value))
175 }
176
177 pub fn get_or_insert(&mut self, key: K, value: T) -> InsertionTime<IndexValue<&mut T>>
178 where
179 K: Clone,
180 {
181 match self.map.entry(key.clone()) {
182 Entry::Occupied(entry) => {
183 let index = *entry.get();
184 InsertionTime::Before(IndexValue {
185 index,
186 value: &mut self.arena[index.index].value,
187 })
188 }
189 Entry::Vacant(entry) => {
190 let inserted = MapValue {
191 keys: vec![key],
192 value,
193 };
194 let index = MapIndex::new(self.arena.insert(inserted));
195 entry.insert(index);
196 InsertionTime::Now(IndexValue {
198 index,
199 value: &mut self.arena.get_mut(index.index).unwrap().value,
200 })
201 }
202 }
203 }
204
205 pub fn associate_key(&mut self, key: K, index: MapIndex)
212 where
213 K: Clone,
214 {
215 let value = match self.arena.get_mut(index.index) {
216 Some(x) => x,
217 None => panic!("Invalid index:{:?}", index),
218 };
219 match self.map.entry(key.clone()) {
220 Entry::Occupied(_) => {}
221 Entry::Vacant(entry) => {
222 entry.insert(index);
223 value.keys.push(key);
224 }
225 }
226 }
227
228 #[allow(dead_code)]
242 pub fn associate_key_forced(&mut self, key: K, index: MapIndex) -> Option<T>
243 where
244 K: Clone + ::std::fmt::Debug,
245 {
246 assert!(
247 self.arena.contains(index.index),
248 "Invalid index:{:?}",
249 index,
250 );
251 let ret = match self.map.entry(key.clone()) {
252 Entry::Occupied(mut entry) => {
253 let index_before = *entry.get();
254 entry.insert(index);
255 let slot = &mut self.arena[index_before.index];
256 let key_ind = slot.keys.iter().position(|x| *x == key).unwrap();
257 slot.keys.swap_remove(key_ind);
258 if slot.keys.is_empty() {
259 self.arena
260 .remove(index_before.index)
261 .unwrap()
262 .value
263 .piped(Some)
264 } else {
265 None
266 }
267 }
268 Entry::Vacant(_) => None,
269 };
270 let value = &mut self.arena[index.index];
271 self.map.entry(key.clone()).or_insert(index);
272 value.keys.push(key);
273 ret
274 }
275}
276
277impl<'a, K, Q: ?Sized, T> Index<&'a Q> for MultiKeyMap<K, T>
278where
279 K: Eq + Hash + Borrow<Q>,
280 Q: Eq + Hash,
281{
282 type Output = T;
283
284 fn index(&self, index: &'a Q) -> &T {
285 self.get(index).expect("no entry found for key")
286 }
287}
288
289impl<'a, K, Q: ?Sized, T> IndexMut<&'a Q> for MultiKeyMap<K, T>
290where
291 K: Eq + Hash + Borrow<Q>,
292 Q: Eq + Hash,
293{
294 fn index_mut(&mut self, index: &'a Q) -> &mut T {
295 self.get_mut(index).expect("no entry found for key")
296 }
297}
298
299impl<K, T> Debug for MultiKeyMap<K, T>
300where
301 K: Eq + Hash + Debug,
302 T: Debug,
303{
304 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305 f.debug_struct("MultiKeyMap")
306 .field("map", &self.map)
307 .field("arena", &self.arena)
308 .finish()
309 }
310}
311
312impl<K, T> Eq for MultiKeyMap<K, T>
313where
314 K: Eq + Hash,
315 T: Eq,
316{
317}
318
319impl<K, T> PartialEq for MultiKeyMap<K, T>
320where
321 K: Eq + Hash,
322 T: PartialEq,
323{
324 fn eq(&self, other: &Self) -> bool {
325 if self.arena.len() != other.arena.len() || self.map.len() != other.map.len() {
326 return false;
327 }
328 for (_, l_val) in self.arena.iter() {
329 let mut keys = l_val.keys.iter();
330
331 let r_val_index = match other.get_index(keys.next().unwrap()) {
332 Some(x) => x,
333 None => return false,
334 };
335
336 let r_val = &other.arena[r_val_index.index];
337
338 if l_val.value != r_val.value {
339 return false;
340 }
341
342 let all_map_to_r_val = keys.all(|key| match other.get_index(key) {
343 Some(r_ind) => ptr::eq(r_val, &other.arena[r_ind.index]),
344 None => false,
345 });
346
347 if !all_map_to_r_val {
348 return false;
349 }
350 }
351 true
352 }
353}
354
355impl MapIndex {
356 #[inline]
357 fn new(index: ArenaIndex) -> Self {
358 Self { index }
359 }
360}
361
362impl<T> InsertionTime<T> {
363 pub fn into_inner(self) -> T {
364 match self {
365 InsertionTime::Before(v) | InsertionTime::Now(v) => v,
366 }
367 }
368 #[allow(dead_code)]
369 pub fn split(self) -> (T, InsertionTime<()>) {
370 let discr = self.discriminant();
371 (self.into_inner(), discr)
372 }
373 #[allow(dead_code)]
374 pub fn map<F, U>(self, f: F) -> InsertionTime<U>
375 where
376 F: FnOnce(T) -> U,
377 {
378 match self {
379 InsertionTime::Before(v) => InsertionTime::Before(f(v)),
380 InsertionTime::Now(v) => InsertionTime::Now(f(v)),
381 }
382 }
383 #[allow(dead_code)]
384 pub fn discriminant(&self) -> InsertionTime<()> {
385 match self {
386 InsertionTime::Before { .. } => InsertionTime::Before(()),
387 InsertionTime::Now { .. } => InsertionTime::Now(()),
388 }
389 }
390}
391
#[cfg(all(test, not(feature = "only_new_tests")))]
mod tests {
    use super::*;

    use crate::test_utils::must_panic;

    #[test]
    fn equality() {
        /// Inserts `value` under `key`, then associates `key + 1` and `key + 2` with it.
        fn insert(map: &mut MultiKeyMap<u32, u32>, key: u32, value: u32) {
            let index = map.get_or_insert(key, value).into_inner().index;
            map.associate_key(key + 1, index);
            map.associate_key(key + 2, index);
        }

        // Two empty maps are equal.
        {
            let map_a = MultiKeyMap::<u32, u32>::new();

            let map_b = MultiKeyMap::<u32, u32>::new();

            assert_eq!(map_a, map_b);
        }
        // Maps built from the same insertions are equal.
        {
            let mut map_a = MultiKeyMap::<u32, u32>::new();
            insert(&mut map_a, 1000, 200);

            let mut map_b = MultiKeyMap::<u32, u32>::new();
            insert(&mut map_b, 1000, 200);

            assert_eq!(map_a, map_b);
        }
        {
            let mut map_a = MultiKeyMap::<u32, u32>::new();
            insert(&mut map_a, 1000, 200);
            insert(&mut map_a, 2000, 400);

            let mut map_b = MultiKeyMap::<u32, u32>::new();
            insert(&mut map_b, 1000, 200);
            insert(&mut map_b, 2000, 400);

            assert_eq!(map_a, map_b);
        }

        // An empty map is not equal to a non-empty one, in either order.
        {
            let map_a = MultiKeyMap::<u32, u32>::new();

            let mut map_b = MultiKeyMap::<u32, u32>::new();
            insert(&mut map_b, 1000, 200);

            assert_ne!(map_a, map_b);
        }
        {
            let mut map_a = MultiKeyMap::<u32, u32>::new();
            insert(&mut map_a, 1000, 200);

            let map_b = MultiKeyMap::<u32, u32>::new();

            assert_ne!(map_a, map_b);
        }
        // A single differing value makes the maps unequal.
        {
            let mut map_a = MultiKeyMap::<u32, u32>::new();
            insert(&mut map_a, 1000, 200);
            insert(&mut map_a, 2000, 401);

            let mut map_b = MultiKeyMap::<u32, u32>::new();
            insert(&mut map_b, 1000, 200);
            insert(&mut map_b, 2000, 400);

            assert_ne!(map_a, map_b);
        }
        // The same keys and values, with the keys grouped differently, are unequal.
        {
            let mut map_a = MultiKeyMap::<u32, u32>::new();
            let index_a = map_a.get_or_insert(1, 5).into_inner().index;
            map_a.associate_key(2, index_a);
            let _ = map_a.get_or_insert(3, 6);

            let mut map_b = MultiKeyMap::<u32, u32>::new();
            let _ = map_b.get_or_insert(1, 5);
            let index_b = map_b.get_or_insert(3, 6).into_inner().index;
            map_b.associate_key(2, index_b);

            assert_ne!(map_a, map_b);
        }
    }

    #[test]
    fn get_or_insert() {
        let mut map = MultiKeyMap::<u32, u32>::new();

        let (ret, ret_discr) = map.get_or_insert(10, 1).split();
        *ret.value = 1234;
        assert_matches!(ret_discr, InsertionTime::Now { .. });

        assert_matches!(
            map.get_or_insert(10, 2).map(|x| x.value).split(),
            (&mut 1234, InsertionTime::Before { .. })
        );
        assert_matches!(
            map.get_or_insert(10, 3).map(|x| x.value).split(),
            (&mut 1234, InsertionTime::Before { .. })
        );
    }

    #[test]
    fn get2_mut() {
        let mut map = MultiKeyMap::<u32, u32>::new();
        let index0 = map.get_or_insert(1, 10).into_inner().index;
        map.associate_key(2, index0);
        let _ = map.get_or_insert(3, 30);

        {
            let (a, b) = map.get2_mut(&1, &3);
            *a.unwrap() += 1;
            *b.unwrap() += 2;
        }
        assert_eq!(map[&2], 11);
        assert_eq!(map[&3], 32);

        {
            let (a, b) = map.get2_mut(&1, &100);
            assert_eq!(a.map(|x| *x), Some(11));
            assert!(b.is_none());
        }
        {
            let (a, b) = map.get2_mut(&100, &3);
            assert!(a.is_none());
            assert_eq!(b.map(|x| *x), Some(32));
        }
        {
            let (a, b) = map.get2_mut(&100, &200);
            assert!(a.is_none());
            assert!(b.is_none());
        }
    }

    #[test]
    fn lengths() {
        let mut map = MultiKeyMap::<u32, u32>::new();
        assert_eq!((map.key_len(), map.value_len()), (0, 0));

        let index = map.get_or_insert(1, 10).into_inner().index;
        map.associate_key(2, index);
        map.associate_key(3, index);
        assert_eq!((map.key_len(), map.value_len()), (3, 1));

        let _ = map.get_or_insert(4, 40);
        assert_eq!((map.key_len(), map.value_len()), (4, 2));
    }

    #[test]
    fn associate_key() {
        let mut map = MultiKeyMap::<u32, u32>::new();

        let (ret, ret_discr) = map.get_or_insert(100, 1).split();
        let index0 = ret.index;
        *ret.value = 1234;
        assert_matches!(ret_discr, InsertionTime::Now { .. });

        let index1 = map.get_or_insert(200, 200).into_inner().index;
        let index2 = map.get_or_insert(300, 300).into_inner().index;

        // Only the first association of a key takes effect.
        map.associate_key(20, index0);
        map.associate_key(20, index1);
        map.associate_key(20, index2);
        assert_eq!(map[&20], 1234);

        map.associate_key(30, index0);
        map.associate_key(30, index1);
        map.associate_key(30, index2);
        assert_eq!(map[&30], 1234);

        map.associate_key(50, index2);
        map.associate_key(50, index0);
        map.associate_key(50, index1);
        assert_eq!(map[&50], 300);

        map[&100] = 456;
        assert_eq!(map[&20], 456);
        assert_eq!(map[&30], 456);
    }

    #[test]
    fn associate_key_forced() {
        let mut map = MultiKeyMap::<u32, u32>::new();

        let index0 = map.get_or_insert(100, 1000).into_inner().index;
        let index1 = map.get_or_insert(200, 2000).into_inner().index;
        let index2 = map.get_or_insert(300, 3000).into_inner().index;

        // The last forced association of a key wins.
        assert_eq!(map.associate_key_forced(20, index2), None);
        assert_eq!(map.associate_key_forced(20, index1), None);
        assert_eq!(map.associate_key_forced(20, index0), None);
        assert_eq!(map[&20], 1000);

        assert_eq!(map.associate_key_forced(30, index2), None);
        assert_eq!(map.associate_key_forced(30, index0), None);
        assert_eq!(map.associate_key_forced(30, index1), None);
        assert_eq!(map[&30], 2000);

        assert_eq!(map.associate_key_forced(50, index1), None);
        assert_eq!(map.associate_key_forced(50, index0), None);
        assert_eq!(map.associate_key_forced(50, index2), None);
        assert_eq!(map[&50], 3000);

        // Moving away the last key of a value removes and returns that value.
        assert_eq!(map.associate_key_forced(100, index2), None);
        assert_eq!(map.associate_key_forced(20, index2), Some(1000));

        assert_eq!(map.associate_key_forced(200, index2), None);
        assert_eq!(map.associate_key_forced(30, index2), Some(2000));

        // `index0` no longer refers to a value.
        must_panic(|| map.associate_key_forced(100, index0)).unwrap();
        must_panic(|| map.associate_key_forced(200, index0)).unwrap();
        must_panic(|| map.associate_key_forced(20, index0)).unwrap();
        must_panic(|| map.associate_key_forced(30, index0)).unwrap();
    }

    #[test]
    fn replace_index() {
        let mut map = MultiKeyMap::<u32, u32>::new();

        let index0 = map.get_or_insert(1000, 200).into_inner().index;
        map.associate_key(1001, index0);
        map.associate_key(1002, index0);

        let index1 = map.get_or_insert(2000, 300).into_inner().index;
        map.associate_key(2001, index1);
        map.associate_key(2002, index1);

        let index2 = map.get_or_insert(3000, 400).into_inner().index;
        map.associate_key(3001, index2);
        map.associate_key(3002, index2);

        assert_eq!(map[&1000], 200);
        assert_eq!(map[&1001], 200);
        assert_eq!(map[&1002], 200);

        assert_eq!(map.replace_index(index0, 205), Some(200));
        assert_eq!(map[&1000], 205);
        assert_eq!(map[&1001], 205);
        assert_eq!(map[&1002], 205);
        assert_eq!(map[&2000], 300);

        assert_eq!(map.replace_index(index1, 305), Some(300));
        assert_eq!(map[&2000], 305);
        assert_eq!(map[&2001], 305);
        assert_eq!(map[&2002], 305);
        assert_eq!(map[&3000], 400);

        assert_eq!(map.replace_index(index2, 405), Some(400));
        assert_eq!(map[&3000], 405);
        assert_eq!(map[&3001], 405);
        assert_eq!(map[&3002], 405);
    }

    #[test]
    fn replace_with_index() {
        let mut map = MultiKeyMap::<u32, u32>::new();

        let index0 = map.get_or_insert(1000, 200).into_inner().index;
        map.associate_key(1001, index0);
        map.associate_key(1002, index0);

        let index1 = map.get_or_insert(2000, 300).into_inner().index;
        map.associate_key(2001, index1);
        map.associate_key(2002, index1);

        let index2 = map.get_or_insert(3000, 400).into_inner().index;
        map.associate_key(3001, index2);
        map.associate_key(3002, index2);

        assert_eq!(map.replace_with_index(index0, index2), Some(200));
        assert_eq!(map[&1000], 400);
        assert_eq!(map[&1001], 400);
        assert_eq!(map[&1002], 400);
        assert_eq!(map[&2000], 300);
        assert_eq!(map[&2001], 300);
        assert_eq!(map[&2002], 300);
        assert_eq!(map[&3000], 400);
        assert_eq!(map[&3001], 400);
        assert_eq!(map[&3002], 400);
        map[&1000] = 600;
        assert_eq!(map[&1000], 600);
        assert_eq!(map[&1001], 600);
        assert_eq!(map[&1002], 600);
        assert_eq!(map[&2000], 300);
        assert_eq!(map[&2001], 300);
        assert_eq!(map[&2002], 300);
        assert_eq!(map[&3000], 600);
        assert_eq!(map[&3001], 600);
        assert_eq!(map[&3002], 600);

        assert_eq!(map.replace_with_index(index1, index0), Some(300));
        map[&1000] = 800;
        assert_eq!(map[&1000], 800);
        assert_eq!(map[&1001], 800);
        assert_eq!(map[&1002], 800);
        assert_eq!(map[&2000], 800);
        assert_eq!(map[&2001], 800);
        assert_eq!(map[&2002], 800);
        assert_eq!(map[&3000], 800);
        assert_eq!(map[&3001], 800);
        assert_eq!(map[&3002], 800);

        // Invalid or identical indices leave the map untouched.
        assert_eq!(map.replace_with_index(index1, index1), None);
        assert_eq!(map.replace_with_index(index1, index0), None);
        assert_eq!((map.key_len(), map.value_len()), (9, 1));
    }

    #[test]
    fn indexing() {
        let mut map = MultiKeyMap::<u32, u32>::new();

        let (index0, it0) = map.get_or_insert(1000, 200).map(|x| x.index).split();
        let (index1, it1) = map.get_or_insert(2000, 300).map(|x| x.index).split();
        let (index2, it2) = map.get_or_insert(3000, 400).map(|x| x.index).split();

        assert_eq!(it0, InsertionTime::Now(()));
        assert_eq!(it1, InsertionTime::Now(()));
        assert_eq!(it2, InsertionTime::Now(()));

        let expected = vec![
            (1000, index0, 200),
            (2000, index1, 300),
            (3000, index2, 400),
        ];
        #[allow(clippy::deref_addrof)]
        for (key, index, val) in expected {
            assert_eq!(*map.get_with_index(index).unwrap(), val);
            assert_eq!(*map.get(&key).unwrap(), val);
            assert_eq!(*(&map[&key]), val);

            assert_eq!(*map.get_mut(&key).unwrap(), val);
            assert_eq!(*map.get_mut_with_index(index).unwrap(), val);
            assert_eq!(*(&mut map[&key]), val);
        }
    }
}
668}