rrt/
lib.rs

/*
  Copyright 2017 Takashi Ogura

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
*/
16
17#![doc = include_str!("../README.md")]
18#![warn(missing_debug_implementations, missing_docs, rust_2018_idioms)]
19
20use kdtree::distance::squared_euclidean;
21use num_traits::float::Float;
22use num_traits::identities::Zero;
23use rand::distributions::{Distribution, Uniform};
24use std::fmt::Debug;
25use std::mem;
26use tracing::debug;
27
/// Outcome of growing a tree by one step towards a target configuration.
#[derive(Debug)]
enum ExtendStatus {
    /// A vertex was added (its index is carried) and it lies within
    /// `extend_length` of the target.
    Reached(usize),
    /// A vertex was added (its index is carried), but the target is still at
    /// least `extend_length` away from it.
    Advanced(usize),
    /// The candidate configuration was rejected by the collision check, so
    /// nothing was added to the tree.
    Trapped,
}
34
/// A tree vertex: the user payload plus an optional link to the vertex it was
/// grown from.
#[derive(Debug, Clone)]
struct Node<T> {
    /// Index of the parent vertex in the owning tree's vertex list; stays
    /// `None` until an edge is recorded (so tree roots keep `None`).
    parent_index: Option<usize>,
    /// Payload stored at this vertex.
    data: T,
}

impl<T> Node<T> {
    /// Wraps `data` in a vertex that has no parent yet.
    fn new(data: T) -> Self {
        Self {
            data,
            parent_index: None,
        }
    }
}
50
51/// RRT
52#[derive(Debug)]
53struct Tree<N>
54where
55    N: Float + Zero + Debug,
56{
57    kdtree: kdtree::KdTree<N, usize, Vec<N>>,
58    vertices: Vec<Node<Vec<N>>>,
59    name: &'static str,
60}
61
62impl<N> Tree<N>
63where
64    N: Float + Zero + Debug,
65{
66    fn new(name: &'static str, dim: usize) -> Self {
67        Tree {
68            kdtree: kdtree::KdTree::new(dim),
69            vertices: Vec::new(),
70            name,
71        }
72    }
73    fn add_vertex(&mut self, q: &[N]) -> usize {
74        let index = self.vertices.len();
75        self.kdtree.add(q.to_vec(), index).unwrap();
76        self.vertices.push(Node::new(q.to_vec()));
77        index
78    }
79    fn add_edge(&mut self, q1_index: usize, q2_index: usize) {
80        self.vertices[q2_index].parent_index = Some(q1_index);
81    }
82    fn get_nearest_index(&self, q: &[N]) -> usize {
83        *self.kdtree.nearest(q, 1, &squared_euclidean).unwrap()[0].1
84    }
85    fn extend<FF>(&mut self, q_target: &[N], extend_length: N, is_free: &mut FF) -> ExtendStatus
86    where
87        FF: FnMut(&[N]) -> bool,
88    {
89        assert!(extend_length > N::zero());
90        let nearest_index = self.get_nearest_index(q_target);
91        let nearest_q = &self.vertices[nearest_index].data;
92        let diff_dist = squared_euclidean(q_target, nearest_q).sqrt();
93        let q_new = if diff_dist < extend_length {
94            q_target.to_vec()
95        } else {
96            nearest_q
97                .iter()
98                .zip(q_target)
99                .map(|(near, target)| *near + (*target - *near) * extend_length / diff_dist)
100                .collect::<Vec<_>>()
101        };
102        debug!("q_new={q_new:?}");
103        if is_free(&q_new) {
104            let new_index = self.add_vertex(&q_new);
105            self.add_edge(nearest_index, new_index);
106            if squared_euclidean(&q_new, q_target).sqrt() < extend_length {
107                return ExtendStatus::Reached(new_index);
108            }
109            debug!("target = {q_target:?}");
110            debug!("advanced to {q_target:?}");
111            return ExtendStatus::Advanced(new_index);
112        }
113        ExtendStatus::Trapped
114    }
115    fn connect<FF>(&mut self, q_target: &[N], extend_length: N, is_free: &mut FF) -> ExtendStatus
116    where
117        FF: FnMut(&[N]) -> bool,
118    {
119        loop {
120            debug!("connecting...{q_target:?}");
121            match self.extend(q_target, extend_length, is_free) {
122                ExtendStatus::Trapped => return ExtendStatus::Trapped,
123                ExtendStatus::Reached(index) => return ExtendStatus::Reached(index),
124                ExtendStatus::Advanced(_) => {}
125            };
126        }
127    }
128    fn get_until_root(&self, index: usize) -> Vec<Vec<N>> {
129        let mut nodes = Vec::new();
130        let mut cur_index = index;
131        while let Some(parent_index) = self.vertices[cur_index].parent_index {
132            cur_index = parent_index;
133            nodes.push(self.vertices[cur_index].data.clone())
134        }
135        nodes
136    }
137}
138
139/// search the path from start to goal which is free, using random_sample function
140pub fn dual_rrt_connect<FF, FR, N>(
141    start: &[N],
142    goal: &[N],
143    mut is_free: FF,
144    random_sample: FR,
145    extend_length: N,
146    num_max_try: usize,
147) -> Result<Vec<Vec<N>>, String>
148where
149    FF: FnMut(&[N]) -> bool,
150    FR: Fn() -> Vec<N>,
151    N: Float + Debug,
152{
153    assert_eq!(start.len(), goal.len());
154    let mut tree_a = Tree::new("start", start.len());
155    let mut tree_b = Tree::new("goal", start.len());
156    tree_a.add_vertex(start);
157    tree_b.add_vertex(goal);
158    for _ in 0..num_max_try {
159        debug!("tree_a = {:?}", tree_a.vertices.len());
160        debug!("tree_b = {:?}", tree_b.vertices.len());
161        let q_rand = random_sample();
162        let extend_status = tree_a.extend(&q_rand, extend_length, &mut is_free);
163        match extend_status {
164            ExtendStatus::Trapped => {}
165            ExtendStatus::Advanced(new_index) | ExtendStatus::Reached(new_index) => {
166                let q_new = &tree_a.vertices[new_index].data;
167                if let ExtendStatus::Reached(reach_index) =
168                    tree_b.connect(q_new, extend_length, &mut is_free)
169                {
170                    let mut a_all = tree_a.get_until_root(new_index);
171                    let mut b_all = tree_b.get_until_root(reach_index);
172                    a_all.reverse();
173                    a_all.append(&mut b_all);
174                    if tree_b.name == "start" {
175                        a_all.reverse();
176                    }
177                    return Ok(a_all);
178                }
179            }
180        }
181        mem::swap(&mut tree_a, &mut tree_b);
182    }
183    Err("failed".to_string())
184}
185
186/// select random two points, and try to connect.
187pub fn smooth_path<FF, N>(
188    path: &mut Vec<Vec<N>>,
189    mut is_free: FF,
190    extend_length: N,
191    num_max_try: usize,
192) where
193    FF: FnMut(&[N]) -> bool,
194    N: Float + Debug,
195{
196    if path.len() < 3 {
197        return;
198    }
199    let mut rng = rand::thread_rng();
200    for _ in 0..num_max_try {
201        let range1 = Uniform::new(0, path.len() - 2);
202        let ind1 = range1.sample(&mut rng);
203        let range2 = Uniform::new(ind1 + 2, path.len());
204        let ind2 = range2.sample(&mut rng);
205        let mut base_point = path[ind1].clone();
206        let point2 = path[ind2].clone();
207        let mut is_searching = true;
208        while is_searching {
209            let diff_dist = squared_euclidean(&base_point, &point2).sqrt();
210            if diff_dist < extend_length {
211                // reached!
212                // remove path[ind1+1] ... path[ind2-1]
213                let remove_index = ind1 + 1;
214                for _ in 0..(ind2 - ind1 - 1) {
215                    path.remove(remove_index);
216                }
217                if path.len() == 2 {
218                    return;
219                }
220                is_searching = false;
221            } else {
222                let check_point = base_point
223                    .iter()
224                    .zip(point2.iter())
225                    .map(|(near, target)| *near + (*target - *near) * extend_length / diff_dist)
226                    .collect::<Vec<_>>();
227                if !is_free(&check_point) {
228                    // trapped
229                    is_searching = false;
230                } else {
231                    // continue to extend
232                    base_point = check_point;
233                }
234            }
235        }
236    }
237}
238
/// End-to-end check: plan around a square obstacle, then smooth the path.
///
/// Verifies that both the raw and the smoothed path keep the requested endpoints
/// and consist only of collision-free waypoints.
#[test]
fn it_works() {
    use rand::distributions::{Distribution, Uniform};
    // Free everywhere except the open square (-1, 1) x (-1, 1) between start and goal.
    let is_free = |p: &[f64]| !(p[0].abs() < 1.0 && p[1].abs() < 1.0);
    let start = [-1.2, 0.0];
    let goal = [1.2, 0.0];
    let between = Uniform::new(-2.0, 2.0);
    let mut result = dual_rrt_connect(
        &start,
        &goal,
        is_free,
        || {
            let mut rng = rand::thread_rng();
            vec![between.sample(&mut rng), between.sample(&mut rng)]
        },
        0.2,
        1000,
    )
    .unwrap();
    println!("{result:?}");
    assert!(result.len() >= 4);
    assert_eq!(result[0], start);
    assert_eq!(result[result.len() - 1], goal);
    assert!(result.iter().all(|p| is_free(p.as_slice())));

    smooth_path(&mut result, is_free, 0.2, 100);
    println!("{result:?}");
    // The obstacle blocks the direct start-goal segment, so a detour must remain.
    assert!(result.len() >= 3);
    assert_eq!(result[0], start);
    assert_eq!(result[result.len() - 1], goal);
    assert!(result.iter().all(|p| is_free(p.as_slice())));
}