kdtree/
kdtree.rs

1use std::collections::BinaryHeap;
2
3use num_traits::{Float, One, Zero};
4use thiserror::Error;
5
6use crate::heap_element::HeapElement;
7use crate::util;
8
/// One node of a k-d tree over points with coordinates of type `A`.
///
/// `U` is the stored point type (viewed as a coordinate slice via `AsRef<[A]>`) and `T` is the
/// payload stored alongside each point. A node is either a leaf (holding `points`/`bucket`) or a
/// stem (holding `left`/`right` plus the split description); see `is_leaf`.
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct KdTree<A, T, U>
where
    T: PartialEq,
    U: AsRef<[A]> + PartialEq,
{
    // --- stem: child subtrees, both set once the node has been split ---
    left: Option<Box<KdTree<A, T, U>>>,
    right: Option<Box<KdTree<A, T, U>>>,
    // --- shared by stems and leaves ---
    /// Number of coordinates every point must have.
    dimensions: usize,
    /// Number of entries a leaf may hold before `add_to_bucket` tries to split it.
    capacity: usize,
    /// Number of entries stored in this node and everything below it.
    size: usize,
    /// Per-dimension minimum over the points added through this node.
    min_bounds: Box<[A]>,
    /// Per-dimension maximum over the points added through this node.
    max_bounds: Box<[A]>,
    // --- stem: where the node was split ---
    split_value: Option<A>,
    split_dimension: Option<usize>,
    // --- leaf: parallel vectors, `bucket[i]` is the payload of `points[i]` ---
    points: Option<Vec<U>>,
    bucket: Option<Vec<T>>,
}
28
29#[derive(Error, Debug, PartialEq, Eq)]
30pub enum ErrorKind {
31    #[error("wrong dimension")]
32    WrongDimension,
33    #[error("non-finite coordinate")]
34    NonFiniteCoordinate,
35    #[error("zero capacity")]
36    ZeroCapacity,
37}
38
39impl<A: Float + Zero + One, T: std::cmp::PartialEq, U: AsRef<[A]> + std::cmp::PartialEq> KdTree<A, T, U> {
40    /// Create a new KD tree, specifying the dimension size of each point
41    pub fn new(dims: usize) -> Self {
42        KdTree::with_capacity(dims, 2_usize.pow(4))
43    }
44
45    /// Create a new KD tree, specifying the dimension size of each point and the capacity of leaf nodes
46    pub fn with_capacity(dimensions: usize, capacity: usize) -> Self {
47        let min_bounds = vec![A::infinity(); dimensions];
48        let max_bounds = vec![A::neg_infinity(); dimensions];
49        KdTree {
50            left: None,
51            right: None,
52            dimensions,
53            capacity,
54            size: 0,
55            min_bounds: min_bounds.into_boxed_slice(),
56            max_bounds: max_bounds.into_boxed_slice(),
57            split_value: None,
58            split_dimension: None,
59            points: Some(vec![]),
60            bucket: Some(vec![]),
61        }
62    }
63
64    pub fn size(&self) -> usize {
65        self.size
66    }
67
68    pub fn nearest<F>(&self, point: &[A], num: usize, distance: &F) -> Result<Vec<(A, &T)>, ErrorKind>
69    where
70        F: Fn(&[A], &[A]) -> A,
71    {
72        self.check_point(point)?;
73        let num = std::cmp::min(num, self.size);
74        if num == 0 {
75            return Ok(vec![]);
76        }
77        let mut pending = BinaryHeap::new();
78        let mut evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
79        pending.push(HeapElement {
80            distance: A::zero(),
81            element: self,
82        });
83        while !pending.is_empty()
84            && (evaluated.len() < num || (-pending.peek().unwrap().distance <= evaluated.peek().unwrap().distance))
85        {
86            self.nearest_step(point, num, A::infinity(), distance, &mut pending, &mut evaluated);
87        }
88        Ok(evaluated
89            .into_sorted_vec()
90            .into_iter()
91            .take(num)
92            .map(Into::into)
93            .collect())
94    }
95
96    pub fn within<F>(&self, point: &[A], radius: A, distance: &F) -> Result<Vec<(A, &T)>, ErrorKind>
97    where
98        F: Fn(&[A], &[A]) -> A,
99    {
100        self.check_point(point)?;
101        if self.size == 0 {
102            return Ok(vec![]);
103        }
104        let mut pending = BinaryHeap::new();
105        let mut evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
106        pending.push(HeapElement {
107            distance: A::zero(),
108            element: self,
109        });
110        while !pending.is_empty() && (-pending.peek().unwrap().distance <= radius) {
111            self.nearest_step(point, self.size, radius, distance, &mut pending, &mut evaluated);
112        }
113        Ok(evaluated.into_sorted_vec().into_iter().map(Into::into).collect())
114    }
115
116    fn nearest_step<'b, F>(
117        &self,
118        point: &[A],
119        num: usize,
120        max_dist: A,
121        distance: &F,
122        pending: &mut BinaryHeap<HeapElement<A, &'b Self>>,
123        evaluated: &mut BinaryHeap<HeapElement<A, &'b T>>,
124    ) where
125        F: Fn(&[A], &[A]) -> A,
126    {
127        let mut curr = pending.pop().unwrap().element;
128        debug_assert!(evaluated.len() <= num);
129        let evaluated_dist = if evaluated.len() == num {
130            // We only care about the nearest `num` points, so if we already have `num` points,
131            // any more point we add to `evaluated` must be nearer then one of the point already in
132            // `evaluated`.
133            max_dist.min(evaluated.peek().unwrap().distance)
134        } else {
135            max_dist
136        };
137
138        while !curr.is_leaf() {
139            let candidate;
140            if curr.belongs_in_left(point) {
141                candidate = curr.right.as_ref().unwrap();
142                curr = curr.left.as_ref().unwrap();
143            } else {
144                candidate = curr.left.as_ref().unwrap();
145                curr = curr.right.as_ref().unwrap();
146            }
147            let candidate_to_space =
148                util::distance_to_space(point, &candidate.min_bounds, &candidate.max_bounds, distance);
149            if candidate_to_space <= evaluated_dist {
150                pending.push(HeapElement {
151                    distance: candidate_to_space * -A::one(),
152                    element: &**candidate,
153                });
154            }
155        }
156
157        let points = curr.points.as_ref().unwrap().iter();
158        let bucket = curr.bucket.as_ref().unwrap().iter();
159        let iter = points.zip(bucket).map(|(p, d)| HeapElement {
160            distance: distance(point, p.as_ref()),
161            element: d,
162        });
163        for element in iter {
164            if element <= max_dist {
165                if evaluated.len() < num {
166                    evaluated.push(element);
167                } else if element < *evaluated.peek().unwrap() {
168                    evaluated.pop();
169                    evaluated.push(element);
170                }
171            }
172        }
173    }
174
175    pub fn iter_nearest<'a, 'b, F>(
176        &'b self,
177        point: &'a [A],
178        distance: &'a F,
179    ) -> Result<NearestIter<'a, 'b, A, T, U, F>, ErrorKind>
180    where
181        F: Fn(&[A], &[A]) -> A,
182    {
183        self.check_point(point)?;
184        let mut pending = BinaryHeap::new();
185        let evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
186        pending.push(HeapElement {
187            distance: A::zero(),
188            element: self,
189        });
190        Ok(NearestIter {
191            point,
192            pending,
193            evaluated,
194            distance,
195        })
196    }
197
198    pub fn iter_nearest_mut<'a, 'b, F>(
199        &'b mut self,
200        point: &'a [A],
201        distance: &'a F,
202    ) -> Result<NearestIterMut<'a, 'b, A, T, U, F>, ErrorKind>
203    where
204        F: Fn(&[A], &[A]) -> A,
205    {
206        self.check_point(point)?;
207        let mut pending = BinaryHeap::new();
208        let evaluated = BinaryHeap::<HeapElement<A, &mut T>>::new();
209        pending.push(HeapElement {
210            distance: A::zero(),
211            element: self,
212        });
213        Ok(NearestIterMut {
214            point,
215            pending,
216            evaluated,
217            distance,
218        })
219    }
220
221    pub fn add(&mut self, point: U, data: T) -> Result<(), ErrorKind> {
222        if self.capacity == 0 {
223            return Err(ErrorKind::ZeroCapacity);
224        }
225        self.check_point(point.as_ref())?;
226        self.add_unchecked(point, data)
227    }
228
229    fn add_unchecked(&mut self, point: U, data: T) -> Result<(), ErrorKind> {
230        if self.is_leaf() {
231            self.add_to_bucket(point, data);
232            return Ok(());
233        }
234        self.extend(point.as_ref());
235        self.size += 1;
236        let next = if self.belongs_in_left(point.as_ref()) {
237            self.left.as_mut()
238        } else {
239            self.right.as_mut()
240        };
241        next.unwrap().add_unchecked(point, data)
242    }
243
244    fn add_to_bucket(&mut self, point: U, data: T) {
245        self.extend(point.as_ref());
246        let mut points = self.points.take().unwrap();
247        let mut bucket = self.bucket.take().unwrap();
248        points.push(point);
249        bucket.push(data);
250        self.size += 1;
251        if self.size > self.capacity {
252            self.split(points, bucket);
253        } else {
254            self.points = Some(points);
255            self.bucket = Some(bucket);
256        }
257    }
258
259    pub fn remove(&mut self, point: &U, data: &T) -> Result<usize, ErrorKind> {
260        let mut removed = 0;
261        self.check_point(point.as_ref())?;
262        if let (Some(mut points), Some(mut bucket)) = (self.points.take(), self.bucket.take()) {
263            while let Some(p_index) = points.iter().position(|x| x == point) {
264                if &bucket[p_index] == data {
265                    points.remove(p_index);
266                    bucket.remove(p_index);
267                    removed += 1;
268                    self.size -= 1;
269                }
270            }
271            self.points = Some(points);
272            self.bucket = Some(bucket);
273        } else {
274            if let Some(right) = self.right.as_mut() {
275                let right_removed = right.remove(point, data)?;
276                if right_removed > 0 {
277                    self.size -= right_removed;
278                    removed += right_removed;
279                }
280            }
281            if let Some(left) = self.left.as_mut() {
282                let left_removed = left.remove(point, data)?;
283                if left_removed > 0 {
284                    self.size -= left_removed;
285                    removed += left_removed;
286                }
287            }
288        }
289        Ok(removed)
290    }
291
292    fn split(&mut self, mut points: Vec<U>, mut bucket: Vec<T>) {
293        let mut max = A::zero();
294        for dim in 0..self.dimensions {
295            let diff = self.max_bounds[dim] - self.min_bounds[dim];
296            if !diff.is_nan() && diff > max {
297                max = diff;
298                self.split_dimension = Some(dim);
299            }
300        }
301        match self.split_dimension {
302            None => {
303                self.points = Some(points);
304                self.bucket = Some(bucket);
305                return;
306            }
307            Some(dim) => {
308                let min = self.min_bounds[dim];
309                let max = self.max_bounds[dim];
310                self.split_value = Some(min + (max - min) / A::from(2.0).unwrap());
311            }
312        };
313        let mut left = Box::new(KdTree::with_capacity(self.dimensions, self.capacity));
314        let mut right = Box::new(KdTree::with_capacity(self.dimensions, self.capacity));
315        while !points.is_empty() {
316            let point = points.swap_remove(0);
317            let data = bucket.swap_remove(0);
318            if self.belongs_in_left(point.as_ref()) {
319                left.add_to_bucket(point, data);
320            } else {
321                right.add_to_bucket(point, data);
322            }
323        }
324        self.left = Some(left);
325        self.right = Some(right);
326    }
327
328    fn belongs_in_left(&self, point: &[A]) -> bool {
329        if self.min_bounds[self.split_dimension.unwrap()] == self.split_value.unwrap() {
330            point[self.split_dimension.unwrap()] <= self.split_value.unwrap()
331        } else {
332            point[self.split_dimension.unwrap()] < self.split_value.unwrap()
333        }
334    }
335
336    fn extend(&mut self, point: &[A]) {
337        let min = self.min_bounds.iter_mut();
338        let max = self.max_bounds.iter_mut();
339        for ((l, h), v) in min.zip(max).zip(point.iter()) {
340            if v < l {
341                *l = *v
342            }
343            if v > h {
344                *h = *v
345            }
346        }
347    }
348
349    fn is_leaf(&self) -> bool {
350        self.bucket.is_some()
351            && self.points.is_some()
352            && self.split_value.is_none()
353            && self.split_dimension.is_none()
354            && self.left.is_none()
355            && self.right.is_none()
356    }
357
358    fn check_point(&self, point: &[A]) -> Result<(), ErrorKind> {
359        if self.dimensions != point.len() {
360            return Err(ErrorKind::WrongDimension);
361        }
362        for n in point {
363            if !n.is_finite() {
364                return Err(ErrorKind::NonFiniteCoordinate);
365            }
366        }
367        Ok(())
368    }
369}
370
371pub struct NearestIter<
372    'a,
373    'b,
374    A: 'a + 'b + Float,
375    T: 'b + PartialEq,
376    U: 'b + AsRef<[A]> + std::cmp::PartialEq,
377    F: 'a + Fn(&[A], &[A]) -> A,
378> {
379    point: &'a [A],
380    pending: BinaryHeap<HeapElement<A, &'b KdTree<A, T, U>>>,
381    evaluated: BinaryHeap<HeapElement<A, &'b T>>,
382    distance: &'a F,
383}
384
385impl<'a, 'b, A: Float + Zero + One, T: 'b, U: 'b + AsRef<[A]>, F: 'a> Iterator for NearestIter<'a, 'b, A, T, U, F>
386where
387    F: Fn(&[A], &[A]) -> A,
388    U: PartialEq,
389    T: PartialEq,
390{
391    type Item = (A, &'b T);
392    fn next(&mut self) -> Option<(A, &'b T)> {
393        use util::distance_to_space;
394
395        let distance = self.distance;
396        let point = self.point;
397        while !self.pending.is_empty()
398            && (self.evaluated.peek().map_or(A::infinity(), |x| -x.distance) >= -self.pending.peek().unwrap().distance)
399        {
400            let mut curr = self.pending.pop().unwrap().element;
401            while !curr.is_leaf() {
402                let candidate;
403                if curr.belongs_in_left(point) {
404                    candidate = curr.right.as_ref().unwrap();
405                    curr = curr.left.as_ref().unwrap();
406                } else {
407                    candidate = curr.left.as_ref().unwrap();
408                    curr = curr.right.as_ref().unwrap();
409                }
410                self.pending.push(HeapElement {
411                    distance: -distance_to_space(point, &candidate.min_bounds, &candidate.max_bounds, distance),
412                    element: &**candidate,
413                });
414            }
415            let points = curr.points.as_ref().unwrap().iter();
416            let bucket = curr.bucket.as_ref().unwrap().iter();
417            self.evaluated.extend(points.zip(bucket).map(|(p, d)| HeapElement {
418                distance: -distance(point, p.as_ref()),
419                element: d,
420            }));
421        }
422        self.evaluated.pop().map(|x| (-x.distance, x.element))
423    }
424}
425
426pub struct NearestIterMut<
427    'a,
428    'b,
429    A: 'a + 'b + Float,
430    T: 'b + PartialEq,
431    U: 'b + AsRef<[A]> + PartialEq,
432    F: 'a + Fn(&[A], &[A]) -> A,
433> {
434    point: &'a [A],
435    pending: BinaryHeap<HeapElement<A, &'b mut KdTree<A, T, U>>>,
436    evaluated: BinaryHeap<HeapElement<A, &'b mut T>>,
437    distance: &'a F,
438}
439
440impl<'a, 'b, A: Float + Zero + One, T: 'b, U: 'b + AsRef<[A]>, F: 'a> Iterator for NearestIterMut<'a, 'b, A, T, U, F>
441where
442    F: Fn(&[A], &[A]) -> A,
443    U: PartialEq,
444    T: PartialEq,
445{
446    type Item = (A, &'b mut T);
447    fn next(&mut self) -> Option<(A, &'b mut T)> {
448        use util::distance_to_space;
449
450        let distance = self.distance;
451        let point = self.point;
452        while !self.pending.is_empty()
453            && (self.evaluated.peek().map_or(A::infinity(), |x| -x.distance) >= -self.pending.peek().unwrap().distance)
454        {
455            let mut curr = &mut *self.pending.pop().unwrap().element;
456            while !curr.is_leaf() {
457                let candidate;
458                if curr.belongs_in_left(point) {
459                    candidate = curr.right.as_mut().unwrap();
460                    curr = curr.left.as_mut().unwrap();
461                } else {
462                    candidate = curr.left.as_mut().unwrap();
463                    curr = curr.right.as_mut().unwrap();
464                }
465                self.pending.push(HeapElement {
466                    distance: -distance_to_space(point, &candidate.min_bounds, &candidate.max_bounds, distance),
467                    element: &mut **candidate,
468                });
469            }
470            let points = curr.points.as_ref().unwrap().iter();
471            let bucket = curr.bucket.as_mut().unwrap().iter_mut();
472            self.evaluated.extend(points.zip(bucket).map(|(p, d)| HeapElement {
473                distance: -distance(point, p.as_ref()),
474                element: d,
475            }));
476        }
477        self.evaluated.pop().map(|x| (-x.distance, x.element))
478    }
479}
480
#[cfg(test)]
mod tests {
    extern crate rand;
    use super::KdTree;

    /// Tree shape shared by most tests: 2-D `f64` points with `i32` payloads.
    type Tree2 = KdTree<f64, i32, [f64; 2]>;

    /// Leaf capacity `KdTree::new` is expected to use.
    const DEFAULT_CAPACITY: usize = 16;

    /// Produces a random 2-D point together with a random payload.
    fn random_point() -> ([f64; 2], i32) {
        rand::random()
    }

    /// Inserts one random entry into `tree`, panicking if the insertion is rejected.
    fn add_random(tree: &mut Tree2) {
        let (pos, data) = random_point();
        tree.add(pos, data).unwrap();
    }

    /// Builds a 1-D tree with leaf capacity 2 and inserts `low, high, low, high`, which
    /// forces a split between two nearly identical coordinates.
    fn fill_past_capacity_with(low: f64, high: f64) {
        let mut tree = KdTree::with_capacity(1, 2);
        for &coord in &[low, high, low, high] {
            tree.add([coord], ()).unwrap();
        }
    }

    #[test]
    fn it_has_default_capacity() {
        let tree: Tree2 = KdTree::new(2);
        assert_eq!(tree.capacity, DEFAULT_CAPACITY);
    }

    #[test]
    fn it_can_be_cloned() {
        let mut tree: Tree2 = KdTree::new(2);
        let (pos, data) = random_point();
        tree.add(pos, data).unwrap();

        // Adding to the clone must not affect the original.
        let mut cloned_tree = tree.clone();
        cloned_tree.add(pos, data).unwrap();

        assert_eq!(tree.size(), 1);
        assert_eq!(cloned_tree.size(), 2);
    }

    #[test]
    fn it_holds_on_to_its_capacity_before_splitting() {
        let mut tree: Tree2 = KdTree::new(2);
        for _ in 0..DEFAULT_CAPACITY {
            add_random(&mut tree);
        }
        assert_eq!(tree.size, DEFAULT_CAPACITY);
        assert_eq!(tree.size(), DEFAULT_CAPACITY);
        assert!(tree.left.is_none() && tree.right.is_none());

        // One entry beyond the capacity turns the root into a stem.
        add_random(&mut tree);
        assert_eq!(tree.size, DEFAULT_CAPACITY + 1);
        assert_eq!(tree.size(), DEFAULT_CAPACITY + 1);
        assert!(tree.left.is_some() && tree.right.is_some());
    }

    #[test]
    fn no_items_can_be_added_to_a_zero_capacity_kdtree() {
        let mut tree: Tree2 = KdTree::with_capacity(2, 0);
        let (pos, data) = random_point();
        assert!(tree.add(pos, data).is_err());
    }

    #[test]
    fn avoid_infinite_call_loop_between_add_to_bucket_and_split_due_to_float_accuracy() {
        // Positive pair of nearly identical coordinates.
        fill_past_capacity_with(0.47945351705599926f64, 0.479_453_517_055_999_3_f64);
        // The same pair mirrored into the negative range.
        fill_past_capacity_with(-0.479_453_517_055_999_3_f64, -0.47945351705599926f64);
    }
}