1use std::collections::BinaryHeap;
2
3use num_traits::{Float, One, Zero};
4use thiserror::Error;
5
6use crate::heap_element::HeapElement;
7use crate::util;
8
/// A k-d tree over points of type `U` (read as `&[A]`) that carry payloads of type `T`.
///
/// Each value of this type is a single node, and the root node doubles as the handle
/// for the whole tree. A node is either a leaf, which owns its entries in the parallel
/// `points` / `bucket` vectors, or an inner node, which owns two children together
/// with the split plane that separates them.
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct KdTree<A, T, U>
where
    T: PartialEq,
    U: AsRef<[A]> + PartialEq,
{
    /// Child chosen for points for which `belongs_in_left` is true; `None` on leaves.
    left: Option<Box<KdTree<A, T, U>>>,
    /// Child receiving every other point; `None` on leaves.
    right: Option<Box<KdTree<A, T, U>>>,
    /// Number of coordinates every point must have (enforced by `check_point`).
    dimensions: usize,
    /// Number of entries a leaf may hold before a split is attempted.
    capacity: usize,
    /// Number of entries stored in this node and everything below it.
    size: usize,
    /// Per-dimension lower corner of the bounding box of all inserted points.
    /// Starts at `+inf` and is only ever lowered; removals do not shrink the box.
    min_bounds: Box<[A]>,
    /// Per-dimension upper corner of the bounding box of all inserted points.
    /// Starts at `-inf` and is only ever raised; removals do not shrink the box.
    max_bounds: Box<[A]>,
    /// Coordinate of the split plane along `split_dimension`; `None` until the node splits.
    split_value: Option<A>,
    /// Axis the split plane is perpendicular to; `None` until the node splits.
    split_dimension: Option<usize>,
    /// Coordinates of the entries held by a leaf; `None` once the node has split.
    points: Option<Vec<U>>,
    /// Payloads of the entries held by a leaf, index-aligned with `points`;
    /// `None` once the node has split.
    bucket: Option<Vec<T>>,
}
28
29#[derive(Error, Debug, PartialEq, Eq)]
30pub enum ErrorKind {
31 #[error("wrong dimension")]
32 WrongDimension,
33 #[error("non-finite coordinate")]
34 NonFiniteCoordinate,
35 #[error("zero capacity")]
36 ZeroCapacity,
37}
38
39impl<A: Float + Zero + One, T: std::cmp::PartialEq, U: AsRef<[A]> + std::cmp::PartialEq> KdTree<A, T, U> {
40 pub fn new(dims: usize) -> Self {
42 KdTree::with_capacity(dims, 2_usize.pow(4))
43 }
44
45 pub fn with_capacity(dimensions: usize, capacity: usize) -> Self {
47 let min_bounds = vec![A::infinity(); dimensions];
48 let max_bounds = vec![A::neg_infinity(); dimensions];
49 KdTree {
50 left: None,
51 right: None,
52 dimensions,
53 capacity,
54 size: 0,
55 min_bounds: min_bounds.into_boxed_slice(),
56 max_bounds: max_bounds.into_boxed_slice(),
57 split_value: None,
58 split_dimension: None,
59 points: Some(vec![]),
60 bucket: Some(vec![]),
61 }
62 }
63
64 pub fn size(&self) -> usize {
65 self.size
66 }
67
68 pub fn nearest<F>(&self, point: &[A], num: usize, distance: &F) -> Result<Vec<(A, &T)>, ErrorKind>
69 where
70 F: Fn(&[A], &[A]) -> A,
71 {
72 self.check_point(point)?;
73 let num = std::cmp::min(num, self.size);
74 if num == 0 {
75 return Ok(vec![]);
76 }
77 let mut pending = BinaryHeap::new();
78 let mut evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
79 pending.push(HeapElement {
80 distance: A::zero(),
81 element: self,
82 });
83 while !pending.is_empty()
84 && (evaluated.len() < num || (-pending.peek().unwrap().distance <= evaluated.peek().unwrap().distance))
85 {
86 self.nearest_step(point, num, A::infinity(), distance, &mut pending, &mut evaluated);
87 }
88 Ok(evaluated
89 .into_sorted_vec()
90 .into_iter()
91 .take(num)
92 .map(Into::into)
93 .collect())
94 }
95
96 pub fn within<F>(&self, point: &[A], radius: A, distance: &F) -> Result<Vec<(A, &T)>, ErrorKind>
97 where
98 F: Fn(&[A], &[A]) -> A,
99 {
100 self.check_point(point)?;
101 if self.size == 0 {
102 return Ok(vec![]);
103 }
104 let mut pending = BinaryHeap::new();
105 let mut evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
106 pending.push(HeapElement {
107 distance: A::zero(),
108 element: self,
109 });
110 while !pending.is_empty() && (-pending.peek().unwrap().distance <= radius) {
111 self.nearest_step(point, self.size, radius, distance, &mut pending, &mut evaluated);
112 }
113 Ok(evaluated.into_sorted_vec().into_iter().map(Into::into).collect())
114 }
115
116 fn nearest_step<'b, F>(
117 &self,
118 point: &[A],
119 num: usize,
120 max_dist: A,
121 distance: &F,
122 pending: &mut BinaryHeap<HeapElement<A, &'b Self>>,
123 evaluated: &mut BinaryHeap<HeapElement<A, &'b T>>,
124 ) where
125 F: Fn(&[A], &[A]) -> A,
126 {
127 let mut curr = pending.pop().unwrap().element;
128 debug_assert!(evaluated.len() <= num);
129 let evaluated_dist = if evaluated.len() == num {
130 max_dist.min(evaluated.peek().unwrap().distance)
134 } else {
135 max_dist
136 };
137
138 while !curr.is_leaf() {
139 let candidate;
140 if curr.belongs_in_left(point) {
141 candidate = curr.right.as_ref().unwrap();
142 curr = curr.left.as_ref().unwrap();
143 } else {
144 candidate = curr.left.as_ref().unwrap();
145 curr = curr.right.as_ref().unwrap();
146 }
147 let candidate_to_space =
148 util::distance_to_space(point, &candidate.min_bounds, &candidate.max_bounds, distance);
149 if candidate_to_space <= evaluated_dist {
150 pending.push(HeapElement {
151 distance: candidate_to_space * -A::one(),
152 element: &**candidate,
153 });
154 }
155 }
156
157 let points = curr.points.as_ref().unwrap().iter();
158 let bucket = curr.bucket.as_ref().unwrap().iter();
159 let iter = points.zip(bucket).map(|(p, d)| HeapElement {
160 distance: distance(point, p.as_ref()),
161 element: d,
162 });
163 for element in iter {
164 if element <= max_dist {
165 if evaluated.len() < num {
166 evaluated.push(element);
167 } else if element < *evaluated.peek().unwrap() {
168 evaluated.pop();
169 evaluated.push(element);
170 }
171 }
172 }
173 }
174
175 pub fn iter_nearest<'a, 'b, F>(
176 &'b self,
177 point: &'a [A],
178 distance: &'a F,
179 ) -> Result<NearestIter<'a, 'b, A, T, U, F>, ErrorKind>
180 where
181 F: Fn(&[A], &[A]) -> A,
182 {
183 self.check_point(point)?;
184 let mut pending = BinaryHeap::new();
185 let evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
186 pending.push(HeapElement {
187 distance: A::zero(),
188 element: self,
189 });
190 Ok(NearestIter {
191 point,
192 pending,
193 evaluated,
194 distance,
195 })
196 }
197
198 pub fn iter_nearest_mut<'a, 'b, F>(
199 &'b mut self,
200 point: &'a [A],
201 distance: &'a F,
202 ) -> Result<NearestIterMut<'a, 'b, A, T, U, F>, ErrorKind>
203 where
204 F: Fn(&[A], &[A]) -> A,
205 {
206 self.check_point(point)?;
207 let mut pending = BinaryHeap::new();
208 let evaluated = BinaryHeap::<HeapElement<A, &mut T>>::new();
209 pending.push(HeapElement {
210 distance: A::zero(),
211 element: self,
212 });
213 Ok(NearestIterMut {
214 point,
215 pending,
216 evaluated,
217 distance,
218 })
219 }
220
221 pub fn add(&mut self, point: U, data: T) -> Result<(), ErrorKind> {
222 if self.capacity == 0 {
223 return Err(ErrorKind::ZeroCapacity);
224 }
225 self.check_point(point.as_ref())?;
226 self.add_unchecked(point, data)
227 }
228
229 fn add_unchecked(&mut self, point: U, data: T) -> Result<(), ErrorKind> {
230 if self.is_leaf() {
231 self.add_to_bucket(point, data);
232 return Ok(());
233 }
234 self.extend(point.as_ref());
235 self.size += 1;
236 let next = if self.belongs_in_left(point.as_ref()) {
237 self.left.as_mut()
238 } else {
239 self.right.as_mut()
240 };
241 next.unwrap().add_unchecked(point, data)
242 }
243
244 fn add_to_bucket(&mut self, point: U, data: T) {
245 self.extend(point.as_ref());
246 let mut points = self.points.take().unwrap();
247 let mut bucket = self.bucket.take().unwrap();
248 points.push(point);
249 bucket.push(data);
250 self.size += 1;
251 if self.size > self.capacity {
252 self.split(points, bucket);
253 } else {
254 self.points = Some(points);
255 self.bucket = Some(bucket);
256 }
257 }
258
259 pub fn remove(&mut self, point: &U, data: &T) -> Result<usize, ErrorKind> {
260 let mut removed = 0;
261 self.check_point(point.as_ref())?;
262 if let (Some(mut points), Some(mut bucket)) = (self.points.take(), self.bucket.take()) {
263 while let Some(p_index) = points.iter().position(|x| x == point) {
264 if &bucket[p_index] == data {
265 points.remove(p_index);
266 bucket.remove(p_index);
267 removed += 1;
268 self.size -= 1;
269 }
270 }
271 self.points = Some(points);
272 self.bucket = Some(bucket);
273 } else {
274 if let Some(right) = self.right.as_mut() {
275 let right_removed = right.remove(point, data)?;
276 if right_removed > 0 {
277 self.size -= right_removed;
278 removed += right_removed;
279 }
280 }
281 if let Some(left) = self.left.as_mut() {
282 let left_removed = left.remove(point, data)?;
283 if left_removed > 0 {
284 self.size -= left_removed;
285 removed += left_removed;
286 }
287 }
288 }
289 Ok(removed)
290 }
291
292 fn split(&mut self, mut points: Vec<U>, mut bucket: Vec<T>) {
293 let mut max = A::zero();
294 for dim in 0..self.dimensions {
295 let diff = self.max_bounds[dim] - self.min_bounds[dim];
296 if !diff.is_nan() && diff > max {
297 max = diff;
298 self.split_dimension = Some(dim);
299 }
300 }
301 match self.split_dimension {
302 None => {
303 self.points = Some(points);
304 self.bucket = Some(bucket);
305 return;
306 }
307 Some(dim) => {
308 let min = self.min_bounds[dim];
309 let max = self.max_bounds[dim];
310 self.split_value = Some(min + (max - min) / A::from(2.0).unwrap());
311 }
312 };
313 let mut left = Box::new(KdTree::with_capacity(self.dimensions, self.capacity));
314 let mut right = Box::new(KdTree::with_capacity(self.dimensions, self.capacity));
315 while !points.is_empty() {
316 let point = points.swap_remove(0);
317 let data = bucket.swap_remove(0);
318 if self.belongs_in_left(point.as_ref()) {
319 left.add_to_bucket(point, data);
320 } else {
321 right.add_to_bucket(point, data);
322 }
323 }
324 self.left = Some(left);
325 self.right = Some(right);
326 }
327
328 fn belongs_in_left(&self, point: &[A]) -> bool {
329 if self.min_bounds[self.split_dimension.unwrap()] == self.split_value.unwrap() {
330 point[self.split_dimension.unwrap()] <= self.split_value.unwrap()
331 } else {
332 point[self.split_dimension.unwrap()] < self.split_value.unwrap()
333 }
334 }
335
336 fn extend(&mut self, point: &[A]) {
337 let min = self.min_bounds.iter_mut();
338 let max = self.max_bounds.iter_mut();
339 for ((l, h), v) in min.zip(max).zip(point.iter()) {
340 if v < l {
341 *l = *v
342 }
343 if v > h {
344 *h = *v
345 }
346 }
347 }
348
349 fn is_leaf(&self) -> bool {
350 self.bucket.is_some()
351 && self.points.is_some()
352 && self.split_value.is_none()
353 && self.split_dimension.is_none()
354 && self.left.is_none()
355 && self.right.is_none()
356 }
357
358 fn check_point(&self, point: &[A]) -> Result<(), ErrorKind> {
359 if self.dimensions != point.len() {
360 return Err(ErrorKind::WrongDimension);
361 }
362 for n in point {
363 if !n.is_finite() {
364 return Err(ErrorKind::NonFiniteCoordinate);
365 }
366 }
367 Ok(())
368 }
369}
370
371pub struct NearestIter<
372 'a,
373 'b,
374 A: 'a + 'b + Float,
375 T: 'b + PartialEq,
376 U: 'b + AsRef<[A]> + std::cmp::PartialEq,
377 F: 'a + Fn(&[A], &[A]) -> A,
378> {
379 point: &'a [A],
380 pending: BinaryHeap<HeapElement<A, &'b KdTree<A, T, U>>>,
381 evaluated: BinaryHeap<HeapElement<A, &'b T>>,
382 distance: &'a F,
383}
384
385impl<'a, 'b, A: Float + Zero + One, T: 'b, U: 'b + AsRef<[A]>, F: 'a> Iterator for NearestIter<'a, 'b, A, T, U, F>
386where
387 F: Fn(&[A], &[A]) -> A,
388 U: PartialEq,
389 T: PartialEq,
390{
391 type Item = (A, &'b T);
392 fn next(&mut self) -> Option<(A, &'b T)> {
393 use util::distance_to_space;
394
395 let distance = self.distance;
396 let point = self.point;
397 while !self.pending.is_empty()
398 && (self.evaluated.peek().map_or(A::infinity(), |x| -x.distance) >= -self.pending.peek().unwrap().distance)
399 {
400 let mut curr = self.pending.pop().unwrap().element;
401 while !curr.is_leaf() {
402 let candidate;
403 if curr.belongs_in_left(point) {
404 candidate = curr.right.as_ref().unwrap();
405 curr = curr.left.as_ref().unwrap();
406 } else {
407 candidate = curr.left.as_ref().unwrap();
408 curr = curr.right.as_ref().unwrap();
409 }
410 self.pending.push(HeapElement {
411 distance: -distance_to_space(point, &candidate.min_bounds, &candidate.max_bounds, distance),
412 element: &**candidate,
413 });
414 }
415 let points = curr.points.as_ref().unwrap().iter();
416 let bucket = curr.bucket.as_ref().unwrap().iter();
417 self.evaluated.extend(points.zip(bucket).map(|(p, d)| HeapElement {
418 distance: -distance(point, p.as_ref()),
419 element: d,
420 }));
421 }
422 self.evaluated.pop().map(|x| (-x.distance, x.element))
423 }
424}
425
426pub struct NearestIterMut<
427 'a,
428 'b,
429 A: 'a + 'b + Float,
430 T: 'b + PartialEq,
431 U: 'b + AsRef<[A]> + PartialEq,
432 F: 'a + Fn(&[A], &[A]) -> A,
433> {
434 point: &'a [A],
435 pending: BinaryHeap<HeapElement<A, &'b mut KdTree<A, T, U>>>,
436 evaluated: BinaryHeap<HeapElement<A, &'b mut T>>,
437 distance: &'a F,
438}
439
440impl<'a, 'b, A: Float + Zero + One, T: 'b, U: 'b + AsRef<[A]>, F: 'a> Iterator for NearestIterMut<'a, 'b, A, T, U, F>
441where
442 F: Fn(&[A], &[A]) -> A,
443 U: PartialEq,
444 T: PartialEq,
445{
446 type Item = (A, &'b mut T);
447 fn next(&mut self) -> Option<(A, &'b mut T)> {
448 use util::distance_to_space;
449
450 let distance = self.distance;
451 let point = self.point;
452 while !self.pending.is_empty()
453 && (self.evaluated.peek().map_or(A::infinity(), |x| -x.distance) >= -self.pending.peek().unwrap().distance)
454 {
455 let mut curr = &mut *self.pending.pop().unwrap().element;
456 while !curr.is_leaf() {
457 let candidate;
458 if curr.belongs_in_left(point) {
459 candidate = curr.right.as_mut().unwrap();
460 curr = curr.left.as_mut().unwrap();
461 } else {
462 candidate = curr.left.as_mut().unwrap();
463 curr = curr.right.as_mut().unwrap();
464 }
465 self.pending.push(HeapElement {
466 distance: -distance_to_space(point, &candidate.min_bounds, &candidate.max_bounds, distance),
467 element: &mut **candidate,
468 });
469 }
470 let points = curr.points.as_ref().unwrap().iter();
471 let bucket = curr.bucket.as_mut().unwrap().iter_mut();
472 self.evaluated.extend(points.zip(bucket).map(|(p, d)| HeapElement {
473 distance: -distance(point, p.as_ref()),
474 element: d,
475 }));
476 }
477 self.evaluated.pop().map(|x| (-x.distance, x.element))
478 }
479}
480
#[cfg(test)]
mod tests {
    extern crate rand;
    use super::KdTree;

    /// Leaf capacity that `KdTree::new` is expected to use.
    const DEFAULT_CAPACITY: usize = 16;

    /// Produces a random 2-D position together with a random payload.
    fn random_point() -> ([f64; 2], i32) {
        rand::random::<([f64; 2], i32)>()
    }

    #[test]
    fn it_has_default_capacity() {
        let tree = KdTree::<f64, i32, [f64; 2]>::new(2);
        assert_eq!(tree.capacity, DEFAULT_CAPACITY);
    }

    #[test]
    fn it_can_be_cloned() {
        let mut original = KdTree::<f64, i32, [f64; 2]>::new(2);
        let (pos, data) = random_point();
        original.add(pos, data).unwrap();

        // Adding to the clone must not affect the tree it was cloned from.
        let mut copy = original.clone();
        copy.add(pos, data).unwrap();

        assert_eq!(original.size(), 1);
        assert_eq!(copy.size(), 2);
    }

    #[test]
    fn it_holds_on_to_its_capacity_before_splitting() {
        let mut tree = KdTree::<f64, i32, [f64; 2]>::new(2);

        // Filling the root up to its capacity keeps it a leaf.
        for _ in 0..DEFAULT_CAPACITY {
            let (pos, data) = random_point();
            tree.add(pos, data).unwrap();
        }
        assert_eq!(tree.size, DEFAULT_CAPACITY);
        assert_eq!(tree.size(), DEFAULT_CAPACITY);
        assert!(tree.left.is_none() && tree.right.is_none());

        // One more entry pushes it over capacity and forces a split.
        let (pos, data) = random_point();
        tree.add(pos, data).unwrap();
        assert_eq!(tree.size, DEFAULT_CAPACITY + 1);
        assert_eq!(tree.size(), DEFAULT_CAPACITY + 1);
        assert!(tree.left.is_some() && tree.right.is_some());
    }

    #[test]
    fn no_items_can_be_added_to_a_zero_capacity_kdtree() {
        let mut tree = KdTree::<f64, i32, [f64; 2]>::with_capacity(2, 0);
        let (pos, data) = random_point();
        assert!(tree.add(pos, data).is_err());
    }

    #[test]
    fn avoid_infinite_call_loop_between_add_to_bucket_and_split_due_to_float_accuracy() {
        // Nearly identical coordinates, once positive and once negative; the test
        // passes if all four insertions return.
        let close_pairs = [
            (0.47945351705599926f64, 0.479_453_517_055_999_3_f64),
            (-0.479_453_517_055_999_3_f64, -0.47945351705599926f64),
        ];
        for &(min, max) in close_pairs.iter() {
            let mut tree = KdTree::with_capacity(1, 2);
            tree.add([min], ()).unwrap();
            tree.add([max], ()).unwrap();

            tree.add([min], ()).unwrap();
            tree.add([max], ()).unwrap();
        }
    }
}
560}