openrr_planner/collision/
self_collision_checker.rs1use std::{path::Path, sync::Arc, time::Duration};
2
3use k::nalgebra as na;
4use na::RealField;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use tracing::debug;
8
9use crate::{
10 CollisionDetector, TrajectoryPoint,
11 collision::{RobotCollisionDetector, parse_colon_separated_pairs},
12 errors::*,
13 funcs::create_chain_from_joint_names,
14 interpolate,
15};
16
17pub struct SelfCollisionChecker<N>
18where
19 N: RealField + Copy + k::SubsetOf<f64>,
20{
21 reference_robot: Arc<k::Chain<N>>,
23 robot_collision_detector: RobotCollisionDetector<N>,
25 time_interpolate_rate: N,
27}
28
29impl<N> SelfCollisionChecker<N>
30where
31 N: RealField + k::SubsetOf<f64> + num_traits::Float,
32{
33 #[track_caller]
34 pub fn new(
35 reference_robot: Arc<k::Chain<N>>,
36 robot_collision_detector: RobotCollisionDetector<N>,
37 time_interpolate_rate: N,
38 ) -> Self {
39 assert!(
40 time_interpolate_rate > na::convert(0.0) && time_interpolate_rate <= na::convert(1.0),
41 "time_interpolate_rate must be 0.0~1.0 but {time_interpolate_rate}",
42 );
43
44 Self {
45 reference_robot,
46 robot_collision_detector,
47 time_interpolate_rate,
48 }
49 }
50
51 pub fn check_joint_positions(
52 &self,
53 current: &[N],
54 positions: &[N],
55 duration: std::time::Duration,
56 ) -> Result<()> {
57 self.check_partial_joint_positions_inner(None, current, positions, duration)
58 }
59
60 pub fn check_partial_joint_positions(
61 &self,
62 using_joint_names: &[String],
63 current: &[N],
64 positions: &[N],
65 duration: std::time::Duration,
66 ) -> Result<()> {
67 self.check_partial_joint_positions_inner(
68 Some(using_joint_names),
69 current,
70 positions,
71 duration,
72 )
73 }
74
75 fn check_partial_joint_positions_inner(
76 &self,
77 using_joint_names: Option<&[String]>,
78 current: &[N],
79 positions: &[N],
80 duration: std::time::Duration,
81 ) -> Result<()> {
82 let duration_f64 = num_traits::NumCast::from::<f64>(duration.as_secs_f64()).unwrap();
83 match interpolate(
84 &[current.to_vec(), positions.to_vec()],
85 duration_f64,
86 self.time_interpolate_rate.mul(duration_f64),
87 ) {
88 Some(interpolated) => {
89 debug!("interpolated len={}", interpolated.len());
90 self.check_partial_joint_trajectory_inner(using_joint_names, &interpolated)
91 }
92 None => Err(Error::InterpolationError(
93 "failed to interpolate".to_owned(),
94 )),
95 }
96 }
97
98 pub fn check_joint_trajectory(&self, trajectory: &[TrajectoryPoint<N>]) -> Result<()> {
99 self.check_partial_joint_trajectory_inner(None, trajectory)
100 }
101
102 pub fn check_partial_joint_trajectory(
103 &self,
104 using_joint_names: &[String],
105 trajectory: &[TrajectoryPoint<N>],
106 ) -> Result<()> {
107 self.check_partial_joint_trajectory_inner(Some(using_joint_names), trajectory)
108 }
109
110 fn check_partial_joint_trajectory_inner(
111 &self,
112 using_joint_names: Option<&[String]>,
113 trajectory: &[TrajectoryPoint<N>],
114 ) -> Result<()> {
115 self.collision_check_robot()
117 .set_joint_positions_clamped(self.reference_robot.joint_positions().as_slice());
118
119 let using_joints = match using_joint_names {
120 Some(joint_names) => {
121 create_chain_from_joint_names(self.collision_check_robot(), joint_names).unwrap()
122 }
123 None => {
124 let nodes = self
125 .collision_check_robot()
126 .iter()
127 .map(|node| (*node).clone())
128 .collect::<Vec<k::Node<N>>>();
129 k::Chain::from_nodes(nodes)
130 }
131 };
132
133 let last_index = trajectory.len() - 1;
135 for (i, v) in trajectory.iter().enumerate() {
136 using_joints.set_joint_positions(&v.position)?;
137 self.collision_check_robot().update_transforms();
138
139 let mut self_checker = self.robot_collision_detector.detect_self();
140 if let Some(names) = self_checker.next() {
141 return Err(Error::Collision {
142 point: match i {
143 0 => UnfeasibleTrajectoryPoint::Start,
144 index if index == last_index => UnfeasibleTrajectoryPoint::Goal,
145 _ => UnfeasibleTrajectoryPoint::WayPoint,
146 },
147 collision_link_names: vec![names.0, names.1],
148 });
149 }
150
151 let mut vec_used: Vec<_> = self_checker.used_duration().iter().collect();
153 vec_used.sort_by(|a, b| b.1.cmp(a.1));
154 let sum_duration: Duration = self_checker.used_duration().values().sum();
155 debug!("total: {sum_duration:?}");
156 debug!("detailed: {vec_used:?}");
157 }
158 Ok(())
159 }
160
161 pub fn collision_check_robot(&self) -> &k::Chain<N> {
163 &self.robot_collision_detector.robot
164 }
165}
166
167#[derive(Clone, Serialize, Deserialize, Debug, JsonSchema)]
168#[serde(deny_unknown_fields)]
169pub struct SelfCollisionCheckerConfig {
170 #[serde(default = "default_prediction")]
171 pub prediction: f64,
172 #[serde(default = "default_time_interpolate_rate")]
173 pub time_interpolate_rate: f64,
174}
175
/// Default value of `SelfCollisionCheckerConfig::prediction`.
fn default_prediction() -> f64 {
    1e-3
}
179
/// Default value of `SelfCollisionCheckerConfig::time_interpolate_rate`.
fn default_time_interpolate_rate() -> f64 {
    0.5_f64
}
183
184impl Default for SelfCollisionCheckerConfig {
185 fn default() -> Self {
186 Self {
187 prediction: default_prediction(),
188 time_interpolate_rate: default_time_interpolate_rate(),
189 }
190 }
191}
192
193pub fn create_self_collision_checker<P: AsRef<Path>>(
194 urdf_path: P,
195 self_collision_check_pairs: &[String],
196 config: &SelfCollisionCheckerConfig,
197 robot: Arc<k::Chain<f64>>,
198) -> SelfCollisionChecker<f64> {
199 let collision_detector = CollisionDetector::from_urdf_robot_with_base_dir(
200 &urdf_rs::utils::read_urdf_or_xacro(urdf_path.as_ref()).unwrap(),
201 urdf_path.as_ref().parent(),
202 config.prediction,
203 );
204 let robot_collision_detector = RobotCollisionDetector::new(
205 (*robot).clone(),
206 collision_detector,
207 parse_colon_separated_pairs(self_collision_check_pairs).unwrap(),
208 );
209
210 SelfCollisionChecker::new(
211 robot,
212 robot_collision_detector,
213 config.time_interpolate_rate,
214 )
215}
216
#[test]
fn test_create_self_collision_checker() {
    let urdf_path = Path::new("sample.urdf");
    let robot = Arc::new(k::Chain::from_urdf_file(urdf_path).unwrap());
    let checker = create_self_collision_checker(
        urdf_path,
        &["root:l_shoulder_roll".into()],
        &SelfCollisionCheckerConfig::default(),
        robot,
    );
    let one_second = Duration::from_secs(1);

    // Staying at the all-zero pose must pass.
    let home = [0.0; 16];
    assert!(checker.check_joint_positions(&home, &home, one_second).is_ok());

    // Moving only the first joint to 1.57 is expected to be rejected.
    let mut swung = [0.0; 16];
    swung[0] = 1.57;
    assert!(checker.check_joint_positions(&home, &swung, one_second).is_err());

    // Same scenario, but driving only the joints up to `l_shoulder_yaw`.
    let shoulder_yaw = checker
        .collision_check_robot()
        .find("l_shoulder_yaw")
        .unwrap();
    let arm = k::SerialChain::from_end(shoulder_yaw);
    let joint_names: Vec<String> = arm.iter_joints().map(|joint| joint.name.to_owned()).collect();

    assert!(
        checker
            .check_partial_joint_positions(&joint_names, &[0.0], &[0.0], one_second)
            .is_ok()
    );
    assert!(
        checker
            .check_partial_joint_positions(&joint_names, &[0.0], &[1.57], one_second)
            .is_err()
    );
}