openrr_planner/collision/self_collision_checker.rs

1use std::{path::Path, sync::Arc, time::Duration};
2
3use k::nalgebra as na;
4use na::RealField;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use tracing::debug;
8
9use crate::{
10    CollisionDetector, TrajectoryPoint,
11    collision::{RobotCollisionDetector, parse_colon_separated_pairs},
12    errors::*,
13    funcs::create_chain_from_joint_names,
14    interpolate,
15};
16
17pub struct SelfCollisionChecker<N>
18where
19    N: RealField + Copy + k::SubsetOf<f64>,
20{
21    /// Robot for reference (read only and assumed to hold the latest full states)
22    reference_robot: Arc<k::Chain<N>>,
23    /// Robot collision detector
24    robot_collision_detector: RobotCollisionDetector<N>,
25    /// Rate of time interpolation
26    time_interpolate_rate: N,
27}
28
29impl<N> SelfCollisionChecker<N>
30where
31    N: RealField + k::SubsetOf<f64> + num_traits::Float,
32{
33    #[track_caller]
34    pub fn new(
35        reference_robot: Arc<k::Chain<N>>,
36        robot_collision_detector: RobotCollisionDetector<N>,
37        time_interpolate_rate: N,
38    ) -> Self {
39        assert!(
40            time_interpolate_rate > na::convert(0.0) && time_interpolate_rate <= na::convert(1.0),
41            "time_interpolate_rate must be 0.0~1.0 but {time_interpolate_rate}",
42        );
43
44        Self {
45            reference_robot,
46            robot_collision_detector,
47            time_interpolate_rate,
48        }
49    }
50
51    pub fn check_joint_positions(
52        &self,
53        current: &[N],
54        positions: &[N],
55        duration: std::time::Duration,
56    ) -> Result<()> {
57        self.check_partial_joint_positions_inner(None, current, positions, duration)
58    }
59
60    pub fn check_partial_joint_positions(
61        &self,
62        using_joint_names: &[String],
63        current: &[N],
64        positions: &[N],
65        duration: std::time::Duration,
66    ) -> Result<()> {
67        self.check_partial_joint_positions_inner(
68            Some(using_joint_names),
69            current,
70            positions,
71            duration,
72        )
73    }
74
75    fn check_partial_joint_positions_inner(
76        &self,
77        using_joint_names: Option<&[String]>,
78        current: &[N],
79        positions: &[N],
80        duration: std::time::Duration,
81    ) -> Result<()> {
82        let duration_f64 = num_traits::NumCast::from::<f64>(duration.as_secs_f64()).unwrap();
83        match interpolate(
84            &[current.to_vec(), positions.to_vec()],
85            duration_f64,
86            self.time_interpolate_rate.mul(duration_f64),
87        ) {
88            Some(interpolated) => {
89                debug!("interpolated len={}", interpolated.len());
90                self.check_partial_joint_trajectory_inner(using_joint_names, &interpolated)
91            }
92            None => Err(Error::InterpolationError(
93                "failed to interpolate".to_owned(),
94            )),
95        }
96    }
97
98    pub fn check_joint_trajectory(&self, trajectory: &[TrajectoryPoint<N>]) -> Result<()> {
99        self.check_partial_joint_trajectory_inner(None, trajectory)
100    }
101
102    pub fn check_partial_joint_trajectory(
103        &self,
104        using_joint_names: &[String],
105        trajectory: &[TrajectoryPoint<N>],
106    ) -> Result<()> {
107        self.check_partial_joint_trajectory_inner(Some(using_joint_names), trajectory)
108    }
109
110    fn check_partial_joint_trajectory_inner(
111        &self,
112        using_joint_names: Option<&[String]>,
113        trajectory: &[TrajectoryPoint<N>],
114    ) -> Result<()> {
115        // Synchronize with the reference robot states for joints not included using_joints
116        self.collision_check_robot()
117            .set_joint_positions_clamped(self.reference_robot.joint_positions().as_slice());
118
119        let using_joints = match using_joint_names {
120            Some(joint_names) => {
121                create_chain_from_joint_names(self.collision_check_robot(), joint_names).unwrap()
122            }
123            None => {
124                let nodes = self
125                    .collision_check_robot()
126                    .iter()
127                    .map(|node| (*node).clone())
128                    .collect::<Vec<k::Node<N>>>();
129                k::Chain::from_nodes(nodes)
130            }
131        };
132
133        // Check the partial trajectory
134        let last_index = trajectory.len() - 1;
135        for (i, v) in trajectory.iter().enumerate() {
136            using_joints.set_joint_positions(&v.position)?;
137            self.collision_check_robot().update_transforms();
138
139            let mut self_checker = self.robot_collision_detector.detect_self();
140            if let Some(names) = self_checker.next() {
141                return Err(Error::Collision {
142                    point: match i {
143                        0 => UnfeasibleTrajectoryPoint::Start,
144                        index if index == last_index => UnfeasibleTrajectoryPoint::Goal,
145                        _ => UnfeasibleTrajectoryPoint::WayPoint,
146                    },
147                    collision_link_names: vec![names.0, names.1],
148                });
149            }
150
151            // Summarize the calculation time
152            let mut vec_used: Vec<_> = self_checker.used_duration().iter().collect();
153            vec_used.sort_by(|a, b| b.1.cmp(a.1));
154            let sum_duration: Duration = self_checker.used_duration().values().sum();
155            debug!("total: {sum_duration:?}");
156            debug!("detailed: {vec_used:?}");
157        }
158        Ok(())
159    }
160
161    /// Get the robot model used for collision checking
162    pub fn collision_check_robot(&self) -> &k::Chain<N> {
163        &self.robot_collision_detector.robot
164    }
165}
166
167#[derive(Clone, Serialize, Deserialize, Debug, JsonSchema)]
168#[serde(deny_unknown_fields)]
169pub struct SelfCollisionCheckerConfig {
170    #[serde(default = "default_prediction")]
171    pub prediction: f64,
172    #[serde(default = "default_time_interpolate_rate")]
173    pub time_interpolate_rate: f64,
174}
175
/// Default value of `SelfCollisionCheckerConfig::prediction`.
fn default_prediction() -> f64 {
    1e-3
}
179
/// Default value of `SelfCollisionCheckerConfig::time_interpolate_rate`.
fn default_time_interpolate_rate() -> f64 {
    5e-1
}
183
184impl Default for SelfCollisionCheckerConfig {
185    fn default() -> Self {
186        Self {
187            prediction: default_prediction(),
188            time_interpolate_rate: default_time_interpolate_rate(),
189        }
190    }
191}
192
193pub fn create_self_collision_checker<P: AsRef<Path>>(
194    urdf_path: P,
195    self_collision_check_pairs: &[String],
196    config: &SelfCollisionCheckerConfig,
197    robot: Arc<k::Chain<f64>>,
198) -> SelfCollisionChecker<f64> {
199    let collision_detector = CollisionDetector::from_urdf_robot_with_base_dir(
200        &urdf_rs::utils::read_urdf_or_xacro(urdf_path.as_ref()).unwrap(),
201        urdf_path.as_ref().parent(),
202        config.prediction,
203    );
204    let robot_collision_detector = RobotCollisionDetector::new(
205        (*robot).clone(),
206        collision_detector,
207        parse_colon_separated_pairs(self_collision_check_pairs).unwrap(),
208    );
209
210    SelfCollisionChecker::new(
211        robot,
212        robot_collision_detector,
213        config.time_interpolate_rate,
214    )
215}
216
#[test]
fn test_create_self_collision_checker() {
    let urdf_path = Path::new("sample.urdf");
    let robot = Arc::new(k::Chain::from_urdf_file(urdf_path).unwrap());
    let checker = create_self_collision_checker(
        urdf_path,
        &["root:l_shoulder_roll".into()],
        &SelfCollisionCheckerConfig::default(),
        robot,
    );
    let one_second = Duration::from_secs(1);

    // Staying at the all-zero posture is collision-free.
    let zeros = [0.0; 16];
    assert!(checker
        .check_joint_positions(&zeros, &zeros, one_second)
        .is_ok());

    // Moving only the first joint to 1.57 is expected to collide.
    let mut first_joint_moved = [0.0; 16];
    first_joint_moved[0] = 1.57;
    assert!(checker
        .check_joint_positions(&zeros, &first_joint_moved, one_second)
        .is_err());

    // Partial check: only the joints on the serial chain ending at `l_shoulder_yaw`.
    let l_shoulder_yaw = checker
        .collision_check_robot()
        .find("l_shoulder_yaw")
        .unwrap();
    let using_joint_names: Vec<String> = k::SerialChain::from_end(l_shoulder_yaw)
        .iter_joints()
        .map(|joint| joint.name.to_owned())
        .collect();

    assert!(checker
        .check_partial_joint_positions(&using_joint_names, &[0.0], &[0.0], one_second)
        .is_ok());
    assert!(checker
        .check_partial_joint_positions(&using_joint_names, &[0.0], &[1.57], one_second)
        .is_err());
}