1#![doc = include_str!("../README.md")]
18#![warn(missing_debug_implementations, missing_docs, rust_2018_idioms)]
19
20use kdtree::distance::squared_euclidean;
21use num_traits::float::Float;
22use num_traits::identities::Zero;
23use rand::distributions::{Distribution, Uniform};
24use std::fmt::Debug;
25use std::mem;
26use tracing::debug;
27
/// Outcome of a single `Tree::extend` step.
///
/// The indices carried by `Reached` and `Advanced` refer to the extended tree's vertex list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExtendStatus {
    /// A vertex was added at this index and lies closer than `extend_length` to the target.
    Reached(usize),
    /// A vertex was added at this index, but the target is still at least `extend_length`
    /// away from it.
    Advanced(usize),
    /// The candidate configuration was rejected by the free-space check; nothing was added.
    Trapped,
}
34
/// A tree vertex: a stored payload plus an optional link to its parent vertex.
#[derive(Clone, Debug)]
struct Node<T> {
    /// Index of the parent vertex in the owning tree's vertex list; `None` until a parent
    /// is assigned.
    parent_index: Option<usize>,
    /// The payload stored at this vertex.
    data: T,
}

impl<T> Node<T> {
    /// Wraps `data` in a vertex that has no parent yet.
    fn new(data: T) -> Self {
        Self {
            data,
            parent_index: None,
        }
    }
}
50
51#[derive(Debug)]
53struct Tree<N>
54where
55 N: Float + Zero + Debug,
56{
57 kdtree: kdtree::KdTree<N, usize, Vec<N>>,
58 vertices: Vec<Node<Vec<N>>>,
59 name: &'static str,
60}
61
62impl<N> Tree<N>
63where
64 N: Float + Zero + Debug,
65{
66 fn new(name: &'static str, dim: usize) -> Self {
67 Tree {
68 kdtree: kdtree::KdTree::new(dim),
69 vertices: Vec::new(),
70 name,
71 }
72 }
73 fn add_vertex(&mut self, q: &[N]) -> usize {
74 let index = self.vertices.len();
75 self.kdtree.add(q.to_vec(), index).unwrap();
76 self.vertices.push(Node::new(q.to_vec()));
77 index
78 }
79 fn add_edge(&mut self, q1_index: usize, q2_index: usize) {
80 self.vertices[q2_index].parent_index = Some(q1_index);
81 }
82 fn get_nearest_index(&self, q: &[N]) -> usize {
83 *self.kdtree.nearest(q, 1, &squared_euclidean).unwrap()[0].1
84 }
85 fn extend<FF>(&mut self, q_target: &[N], extend_length: N, is_free: &mut FF) -> ExtendStatus
86 where
87 FF: FnMut(&[N]) -> bool,
88 {
89 assert!(extend_length > N::zero());
90 let nearest_index = self.get_nearest_index(q_target);
91 let nearest_q = &self.vertices[nearest_index].data;
92 let diff_dist = squared_euclidean(q_target, nearest_q).sqrt();
93 let q_new = if diff_dist < extend_length {
94 q_target.to_vec()
95 } else {
96 nearest_q
97 .iter()
98 .zip(q_target)
99 .map(|(near, target)| *near + (*target - *near) * extend_length / diff_dist)
100 .collect::<Vec<_>>()
101 };
102 debug!("q_new={q_new:?}");
103 if is_free(&q_new) {
104 let new_index = self.add_vertex(&q_new);
105 self.add_edge(nearest_index, new_index);
106 if squared_euclidean(&q_new, q_target).sqrt() < extend_length {
107 return ExtendStatus::Reached(new_index);
108 }
109 debug!("target = {q_target:?}");
110 debug!("advanced to {q_target:?}");
111 return ExtendStatus::Advanced(new_index);
112 }
113 ExtendStatus::Trapped
114 }
115 fn connect<FF>(&mut self, q_target: &[N], extend_length: N, is_free: &mut FF) -> ExtendStatus
116 where
117 FF: FnMut(&[N]) -> bool,
118 {
119 loop {
120 debug!("connecting...{q_target:?}");
121 match self.extend(q_target, extend_length, is_free) {
122 ExtendStatus::Trapped => return ExtendStatus::Trapped,
123 ExtendStatus::Reached(index) => return ExtendStatus::Reached(index),
124 ExtendStatus::Advanced(_) => {}
125 };
126 }
127 }
128 fn get_until_root(&self, index: usize) -> Vec<Vec<N>> {
129 let mut nodes = Vec::new();
130 let mut cur_index = index;
131 while let Some(parent_index) = self.vertices[cur_index].parent_index {
132 cur_index = parent_index;
133 nodes.push(self.vertices[cur_index].data.clone())
134 }
135 nodes
136 }
137}
138
139pub fn dual_rrt_connect<FF, FR, N>(
141 start: &[N],
142 goal: &[N],
143 mut is_free: FF,
144 random_sample: FR,
145 extend_length: N,
146 num_max_try: usize,
147) -> Result<Vec<Vec<N>>, String>
148where
149 FF: FnMut(&[N]) -> bool,
150 FR: Fn() -> Vec<N>,
151 N: Float + Debug,
152{
153 assert_eq!(start.len(), goal.len());
154 let mut tree_a = Tree::new("start", start.len());
155 let mut tree_b = Tree::new("goal", start.len());
156 tree_a.add_vertex(start);
157 tree_b.add_vertex(goal);
158 for _ in 0..num_max_try {
159 debug!("tree_a = {:?}", tree_a.vertices.len());
160 debug!("tree_b = {:?}", tree_b.vertices.len());
161 let q_rand = random_sample();
162 let extend_status = tree_a.extend(&q_rand, extend_length, &mut is_free);
163 match extend_status {
164 ExtendStatus::Trapped => {}
165 ExtendStatus::Advanced(new_index) | ExtendStatus::Reached(new_index) => {
166 let q_new = &tree_a.vertices[new_index].data;
167 if let ExtendStatus::Reached(reach_index) =
168 tree_b.connect(q_new, extend_length, &mut is_free)
169 {
170 let mut a_all = tree_a.get_until_root(new_index);
171 let mut b_all = tree_b.get_until_root(reach_index);
172 a_all.reverse();
173 a_all.append(&mut b_all);
174 if tree_b.name == "start" {
175 a_all.reverse();
176 }
177 return Ok(a_all);
178 }
179 }
180 }
181 mem::swap(&mut tree_a, &mut tree_b);
182 }
183 Err("failed".to_string())
184}
185
186pub fn smooth_path<FF, N>(
188 path: &mut Vec<Vec<N>>,
189 mut is_free: FF,
190 extend_length: N,
191 num_max_try: usize,
192) where
193 FF: FnMut(&[N]) -> bool,
194 N: Float + Debug,
195{
196 if path.len() < 3 {
197 return;
198 }
199 let mut rng = rand::thread_rng();
200 for _ in 0..num_max_try {
201 let range1 = Uniform::new(0, path.len() - 2);
202 let ind1 = range1.sample(&mut rng);
203 let range2 = Uniform::new(ind1 + 2, path.len());
204 let ind2 = range2.sample(&mut rng);
205 let mut base_point = path[ind1].clone();
206 let point2 = path[ind2].clone();
207 let mut is_searching = true;
208 while is_searching {
209 let diff_dist = squared_euclidean(&base_point, &point2).sqrt();
210 if diff_dist < extend_length {
211 let remove_index = ind1 + 1;
214 for _ in 0..(ind2 - ind1 - 1) {
215 path.remove(remove_index);
216 }
217 if path.len() == 2 {
218 return;
219 }
220 is_searching = false;
221 } else {
222 let check_point = base_point
223 .iter()
224 .zip(point2.iter())
225 .map(|(near, target)| *near + (*target - *near) * extend_length / diff_dist)
226 .collect::<Vec<_>>();
227 if !is_free(&check_point) {
228 is_searching = false;
230 } else {
231 base_point = check_point;
233 }
234 }
235 }
236 }
237}
238
#[test]
fn it_works() {
    // Obstacle: the open square (-1, 1) x (-1, 1).
    fn is_free(p: &[f64]) -> bool {
        !(p[0].abs() < 1.0 && p[1].abs() < 1.0)
    }
    let start = [-1.2_f64, 0.0];
    let goal = [1.2_f64, 0.0];
    let mut result = dual_rrt_connect(
        &start,
        &goal,
        is_free,
        || {
            let between = Uniform::new(-2.0, 2.0);
            let mut rng = rand::thread_rng();
            vec![between.sample(&mut rng), between.sample(&mut rng)]
        },
        0.2,
        1000,
    )
    .unwrap();
    println!("{result:?}");
    assert!(result.len() >= 4);
    assert_eq!(result[0], start);
    assert_eq!(result[result.len() - 1], goal);
    assert!(result.iter().all(|p| is_free(p)));

    smooth_path(&mut result, is_free, 0.2, 100);
    println!("{result:?}");
    assert!(result.len() >= 3);
    // Smoothing only removes interior waypoints, so the endpoints survive and every
    // remaining waypoint is still free.
    assert_eq!(result[0], start);
    assert_eq!(result[result.len() - 1], goal);
    assert!(result.iter().all(|p| is_free(p)));
}