openrr_planner/collision/
self_collision_checker.rs

1use std::{path::Path, sync::Arc, time::Duration};
2
3use k::nalgebra as na;
4use na::RealField;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use tracing::debug;
8
9use crate::{
10    collision::{parse_colon_separated_pairs, RobotCollisionDetector},
11    errors::*,
12    funcs::create_chain_from_joint_names,
13    interpolate, CollisionDetector, TrajectoryPoint,
14};
15
16pub struct SelfCollisionChecker<N>
17where
18    N: RealField + Copy + k::SubsetOf<f64>,
19{
20    /// Robot for reference (read only and assumed to hold the latest full states)
21    reference_robot: Arc<k::Chain<N>>,
22    /// Robot collision detector
23    robot_collision_detector: RobotCollisionDetector<N>,
24    /// Rate of time interpolation
25    time_interpolate_rate: N,
26}
27
28impl<N> SelfCollisionChecker<N>
29where
30    N: RealField + k::SubsetOf<f64> + num_traits::Float,
31{
32    #[track_caller]
33    pub fn new(
34        reference_robot: Arc<k::Chain<N>>,
35        robot_collision_detector: RobotCollisionDetector<N>,
36        time_interpolate_rate: N,
37    ) -> Self {
38        assert!(
39            time_interpolate_rate > na::convert(0.0) && time_interpolate_rate <= na::convert(1.0),
40            "time_interpolate_rate must be 0.0~1.0 but {time_interpolate_rate}",
41        );
42
43        Self {
44            reference_robot,
45            robot_collision_detector,
46            time_interpolate_rate,
47        }
48    }
49
50    pub fn check_joint_positions(
51        &self,
52        current: &[N],
53        positions: &[N],
54        duration: std::time::Duration,
55    ) -> Result<()> {
56        self.check_partial_joint_positions_inner(None, current, positions, duration)
57    }
58
59    pub fn check_partial_joint_positions(
60        &self,
61        using_joint_names: &[String],
62        current: &[N],
63        positions: &[N],
64        duration: std::time::Duration,
65    ) -> Result<()> {
66        self.check_partial_joint_positions_inner(
67            Some(using_joint_names),
68            current,
69            positions,
70            duration,
71        )
72    }
73
74    fn check_partial_joint_positions_inner(
75        &self,
76        using_joint_names: Option<&[String]>,
77        current: &[N],
78        positions: &[N],
79        duration: std::time::Duration,
80    ) -> Result<()> {
81        let duration_f64 = num_traits::NumCast::from::<f64>(duration.as_secs_f64()).unwrap();
82        match interpolate(
83            &[current.to_vec(), positions.to_vec()],
84            duration_f64,
85            self.time_interpolate_rate.mul(duration_f64),
86        ) {
87            Some(interpolated) => {
88                debug!("interpolated len={}", interpolated.len());
89                self.check_partial_joint_trajectory_inner(using_joint_names, &interpolated)
90            }
91            None => Err(Error::InterpolationError(
92                "failed to interpolate".to_owned(),
93            )),
94        }
95    }
96
97    pub fn check_joint_trajectory(&self, trajectory: &[TrajectoryPoint<N>]) -> Result<()> {
98        self.check_partial_joint_trajectory_inner(None, trajectory)
99    }
100
101    pub fn check_partial_joint_trajectory(
102        &self,
103        using_joint_names: &[String],
104        trajectory: &[TrajectoryPoint<N>],
105    ) -> Result<()> {
106        self.check_partial_joint_trajectory_inner(Some(using_joint_names), trajectory)
107    }
108
109    fn check_partial_joint_trajectory_inner(
110        &self,
111        using_joint_names: Option<&[String]>,
112        trajectory: &[TrajectoryPoint<N>],
113    ) -> Result<()> {
114        // Synchronize with the reference robot states for joints not included using_joints
115        self.collision_check_robot()
116            .set_joint_positions_clamped(self.reference_robot.joint_positions().as_slice());
117
118        let using_joints = match using_joint_names {
119            Some(joint_names) => {
120                create_chain_from_joint_names(self.collision_check_robot(), joint_names).unwrap()
121            }
122            None => {
123                let nodes = self
124                    .collision_check_robot()
125                    .iter()
126                    .map(|node| (*node).clone())
127                    .collect::<Vec<k::Node<N>>>();
128                k::Chain::from_nodes(nodes)
129            }
130        };
131
132        // Check the partial trajectory
133        let last_index = trajectory.len() - 1;
134        for (i, v) in trajectory.iter().enumerate() {
135            using_joints.set_joint_positions(&v.position)?;
136            self.collision_check_robot().update_transforms();
137
138            let mut self_checker = self.robot_collision_detector.detect_self();
139            if let Some(names) = self_checker.next() {
140                return Err(Error::Collision {
141                    point: match i {
142                        0 => UnfeasibleTrajectoryPoint::Start,
143                        index if index == last_index => UnfeasibleTrajectoryPoint::Goal,
144                        _ => UnfeasibleTrajectoryPoint::WayPoint,
145                    },
146                    collision_link_names: vec![names.0, names.1],
147                });
148            }
149
150            // Summarize the calculation time
151            let mut vec_used: Vec<_> = self_checker.used_duration().iter().collect();
152            vec_used.sort_by(|a, b| b.1.cmp(a.1));
153            let sum_duration: Duration = self_checker.used_duration().iter().map(|(_k, v)| v).sum();
154            debug!("total: {sum_duration:?}");
155            debug!("detailed: {vec_used:?}");
156        }
157        Ok(())
158    }
159
160    /// Get the robot model used for collision checking
161    pub fn collision_check_robot(&self) -> &k::Chain<N> {
162        &self.robot_collision_detector.robot
163    }
164}
165
166#[derive(Clone, Serialize, Deserialize, Debug, JsonSchema)]
167#[serde(deny_unknown_fields)]
168pub struct SelfCollisionCheckerConfig {
169    #[serde(default = "default_prediction")]
170    pub prediction: f64,
171    #[serde(default = "default_time_interpolate_rate")]
172    pub time_interpolate_rate: f64,
173}
174
/// Default for [`SelfCollisionCheckerConfig::prediction`].
fn default_prediction() -> f64 {
    1e-3
}
178
/// Default for [`SelfCollisionCheckerConfig::time_interpolate_rate`]: half of the duration.
fn default_time_interpolate_rate() -> f64 {
    0.5_f64
}
182
183impl Default for SelfCollisionCheckerConfig {
184    fn default() -> Self {
185        Self {
186            prediction: default_prediction(),
187            time_interpolate_rate: default_time_interpolate_rate(),
188        }
189    }
190}
191
192pub fn create_self_collision_checker<P: AsRef<Path>>(
193    urdf_path: P,
194    self_collision_check_pairs: &[String],
195    config: &SelfCollisionCheckerConfig,
196    robot: Arc<k::Chain<f64>>,
197) -> SelfCollisionChecker<f64> {
198    let collision_detector = CollisionDetector::from_urdf_robot_with_base_dir(
199        &urdf_rs::utils::read_urdf_or_xacro(urdf_path.as_ref()).unwrap(),
200        urdf_path.as_ref().parent(),
201        config.prediction,
202    );
203    let robot_collision_detector = RobotCollisionDetector::new(
204        (*robot).clone(),
205        collision_detector,
206        parse_colon_separated_pairs(self_collision_check_pairs).unwrap(),
207    );
208
209    SelfCollisionChecker::new(
210        robot,
211        robot_collision_detector,
212        config.time_interpolate_rate,
213    )
214}
215
#[test]
fn test_create_self_collision_checker() {
    let urdf_path = Path::new("sample.urdf");
    let robot = Arc::new(k::Chain::from_urdf_file(urdf_path).unwrap());
    let checker = create_self_collision_checker(
        urdf_path,
        &["root:l_shoulder_roll".into()],
        &SelfCollisionCheckerConfig::default(),
        robot,
    );
    let one_second = Duration::from_secs(1);

    // Whole-robot motions: staying at the zero pose passes, moving the first joint fails.
    let zero_pose = [0.0; 16];
    let mut moved_pose = zero_pose;
    moved_pose[0] = 1.57;
    assert!(checker
        .check_joint_positions(&zero_pose, &zero_pose, one_second)
        .is_ok());
    assert!(checker
        .check_joint_positions(&zero_pose, &moved_pose, one_second)
        .is_err());

    // Partial motions driving only the joints of the serial chain ending at l_shoulder_yaw.
    let l_shoulder_yaw = checker
        .collision_check_robot()
        .find("l_shoulder_yaw")
        .unwrap();
    let using_joint_names: Vec<String> = k::SerialChain::from_end(l_shoulder_yaw)
        .iter_joints()
        .map(|joint| joint.name.clone())
        .collect();

    assert!(checker
        .check_partial_joint_positions(&using_joint_names, &[0.0], &[0.0], one_second)
        .is_ok());
    assert!(checker
        .check_partial_joint_positions(&using_joint_names, &[0.0], &[1.57], one_second)
        .is_err());
}