1use std::time::Duration;
2
3use crate::{error::Error, traits::JointTrajectoryClient, TrajectoryPoint, WaitFuture};
4
/// Velocity magnitude threshold used by `should_interpolate_joint_trajectory`:
/// a velocity whose absolute value does not exceed this is treated as zero.
const ZERO_VELOCITY_THRESHOLD: f64 = 1.0e-6;
6
/// A [`JointTrajectoryClient`] wrapper that carries one position difference
/// limit per joint.
///
/// The limits are used when sending positions or trajectories through this
/// wrapper: a move whose per-joint difference exceeds its limit is subdivided
/// into intermediate trajectory points before being handed to the wrapped
/// client.
#[derive(Debug)]
pub struct JointPositionDifferenceLimiter<C>
where
    C: JointTrajectoryClient,
{
    /// The wrapped client that every call is eventually forwarded to.
    client: C,
    /// One limit per joint; `new` stores the absolute value of each entry.
    position_difference_limits: Vec<f64>,
}
21
22impl<C> JointPositionDifferenceLimiter<C>
23where
24 C: JointTrajectoryClient,
25{
26 pub fn new(client: C, mut position_difference_limits: Vec<f64>) -> Result<Self, Error> {
28 if client.joint_names().len() == position_difference_limits.len() {
29 let mut is_valid = true;
30 position_difference_limits.iter_mut().for_each(|f| {
31 let f_abs = f.abs();
32 if f_abs < f64::MIN_POSITIVE {
33 is_valid = false;
34 }
35 *f = f_abs
36 });
37 if !is_valid {
38 return Err(Error::Other(anyhow::format_err!(
39 "Too small position difference limit",
40 )));
41 }
42 Ok(Self {
43 client,
44 position_difference_limits,
45 })
46 } else {
47 Err(Error::LengthMismatch {
48 model: client.joint_names().len(),
49 input: position_difference_limits.len(),
50 })
51 }
52 }
53}
54
55impl<C> JointTrajectoryClient for JointPositionDifferenceLimiter<C>
56where
57 C: JointTrajectoryClient,
58{
59 fn joint_names(&self) -> Vec<String> {
60 self.client.joint_names()
61 }
62
63 fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
64 self.client.current_joint_positions()
65 }
66
67 fn send_joint_positions(
68 &self,
69 positions: Vec<f64>,
70 duration: Duration,
71 ) -> Result<WaitFuture, Error> {
72 let current = self.client.current_joint_positions()?;
73 if current.len() != positions.len() {
74 return Err(Error::LengthMismatch {
75 model: positions.len(),
76 input: current.len(),
77 });
78 }
79 match interpolate(
80 current,
81 &self.position_difference_limits,
82 &positions,
83 &Duration::from_secs(0),
84 &duration,
85 )? {
86 Some(trajectory) => self.client.send_joint_trajectory(trajectory),
87 None => self.client.send_joint_positions(positions, duration),
88 }
89 }
90
91 fn send_joint_trajectory(&self, trajectory: Vec<TrajectoryPoint>) -> Result<WaitFuture, Error> {
94 let fixed_trajectory = if should_interpolate_joint_trajectory(&trajectory) {
95 interpolate_joint_trajectory(
96 self.client.current_joint_positions()?,
97 &self.position_difference_limits,
98 trajectory,
99 )?
100 } else {
101 trajectory
102 };
103 self.client.send_joint_trajectory(fixed_trajectory)
104 }
105}
106
107fn interpolate(
108 mut current: Vec<f64>,
109 position_difference_limits: &[f64],
110 positions: &[f64],
111 first_time_from_start: &Duration,
112 last_time_from_start: &Duration,
113) -> Result<Option<Vec<TrajectoryPoint>>, Error> {
114 let mut max_diff_step: f64 = 0.0;
115 let mut diff = vec![0.0; current.len()];
116 for (i, p) in current.iter().enumerate() {
117 diff[i] = positions[i] - p;
118 let step = diff[i].abs() / position_difference_limits[i].abs();
119 if step.is_infinite() {
120 return Err(Error::Other(anyhow::format_err!(
121 "Invalid position difference limits {} for joint {i} ",
122 position_difference_limits[i],
123 )));
124 }
125 max_diff_step = max_diff_step.max(step);
126 }
127 let max_diff_step = max_diff_step.ceil() as usize;
128 Ok(if max_diff_step <= 1 {
129 None
130 } else {
131 diff.iter_mut().for_each(|d| *d /= max_diff_step as f64);
132 let step_duration = Duration::from_secs_f64(
133 (last_time_from_start.as_secs_f64() - first_time_from_start.as_secs_f64())
134 / max_diff_step as f64,
135 );
136 let mut trajectory = vec![];
137 for i in 1..max_diff_step {
138 current
139 .iter_mut()
140 .enumerate()
141 .for_each(|(i, c)| *c += diff[i]);
142 trajectory.push(TrajectoryPoint {
143 positions: current.to_owned(),
144 velocities: None,
145 time_from_start: *first_time_from_start + step_duration * i as u32,
146 })
147 }
148 trajectory.push(TrajectoryPoint {
149 positions: positions.to_vec(),
150 velocities: None,
151 time_from_start: *last_time_from_start,
152 });
153 Some(trajectory)
154 })
155}
156
157fn should_interpolate_joint_trajectory(trajectory: &[TrajectoryPoint]) -> bool {
158 if trajectory.is_empty() {
159 return false;
160 };
161 match trajectory.iter().position(|p| p.velocities.is_some()) {
162 Some(first_index_of_valid_velocity) => {
163 let last_index = trajectory.len() - 1;
164 if first_index_of_valid_velocity != last_index {
165 false
166 } else {
167 !trajectory[last_index]
168 .velocities
169 .as_ref()
170 .unwrap()
171 .iter()
172 .any(|x| x.abs() > ZERO_VELOCITY_THRESHOLD)
173 }
174 }
175 None => true,
176 }
177}
178
179fn interpolate_joint_trajectory(
180 current: Vec<f64>,
181 position_difference_limits: &[f64],
182 trajectory: Vec<TrajectoryPoint>,
183) -> Result<Vec<TrajectoryPoint>, Error> {
184 let mut fixed_trajectory = vec![];
185 let mut previous_joint_positions = current;
186 let mut previous_time_from_start = Duration::from_secs(0);
187
188 for p in trajectory {
189 let target = p.positions.clone();
190 let velocity = p.velocities.clone();
191 let time_from_start = p.time_from_start;
192 fixed_trajectory.extend(
193 match interpolate(
194 previous_joint_positions,
195 position_difference_limits,
196 &target,
197 &previous_time_from_start,
198 &time_from_start,
199 )? {
200 Some(mut interpolated) => {
201 interpolated.last_mut().unwrap().velocities = velocity;
202 interpolated
203 }
204 None => vec![p],
205 },
206 );
207
208 previous_joint_positions = target;
209 previous_time_from_start = time_from_start;
210 }
211 Ok(fixed_trajectory)
212}
213
// Unit tests for the free interpolation helpers and for the limiter wrapped
// around a `DummyJointTrajectoryClient`.
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use assert_approx_eq::assert_approx_eq;

    use super::{
        interpolate, interpolate_joint_trajectory, should_interpolate_joint_trajectory,
        JointPositionDifferenceLimiter, JointTrajectoryClient, TrajectoryPoint,
    };
    use crate::DummyJointTrajectoryClient;

    // Every joint moves by exactly its limit, so one step is enough and
    // `interpolate` returns `None`. Zero limits with a non-zero move are an error.
    #[test]
    fn interpolate_no_interpolation() {
        let interpolated = interpolate(
            vec![0.0, 1.0],
            &[1.0, 1.0],
            &[-1.0, 2.0],
            &Duration::from_secs(0),
            &Duration::from_secs(1),
        );
        assert!(interpolated.is_ok());
        assert!(interpolated.unwrap().is_none());
        assert!(interpolate(
            vec![0.0, 1.0],
            &[0.0, 0.0],
            &[-1.0, 2.0],
            &Duration::from_secs(0),
            &Duration::from_secs(1),
        )
        .is_err());
    }
    // The second joint moves by twice its limit, so the move is split into two
    // steps: one midpoint at half the duration plus the final target.
    #[test]
    fn interpolate_interpolated() {
        let interpolated = interpolate(
            vec![0.0, 1.0],
            &[1.0, 0.5],
            &[-1.0, 2.0],
            &Duration::from_secs(0),
            &Duration::from_secs(1),
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();
        assert!(interpolated.is_some());
        let interpolated = interpolated.unwrap();
        assert_eq!(interpolated.len(), 2);
        assert_eq!(interpolated[0].positions, vec![-0.5, 1.5]);
        assert!(interpolated[0].velocities.is_none());
        assert_approx_eq!(interpolated[0].time_from_start.as_secs_f64(), 0.5);

        assert_eq!(interpolated[1].positions, vec![-1.0, 2.0]);
        assert!(interpolated[1].velocities.is_none());
        assert_approx_eq!(interpolated[1].time_from_start.as_secs_f64(), 1.0);
    }

    // `new` rejects a limit vector of the wrong length and a zero limit.
    #[test]
    fn joint_position_difference_limiter_new_error() {
        let wrapped_client = Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]));
        assert!(
            JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![3.0, 1.0, 2.0])
                .is_err()
        );
        assert!(JointPositionDifferenceLimiter::new(wrapped_client, vec![1.0, 0.0]).is_err());
    }
    // Every point of this trajectory carries velocities, so the limiter forwards
    // it to the wrapped client without modification.
    #[test]
    fn joint_position_difference_limiter_send_joint_trajectory() {
        let wrapped_client = Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]));
        let client = JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, 2.0]);
        assert!(client.is_ok());
        let client = client.unwrap();
        // The wrapper must expose the same joint names as the wrapped client.
        assert_eq!(
            client.joint_names().len(),
            wrapped_client.joint_names().len()
        );
        for (c, w) in client
            .joint_names()
            .iter()
            .zip(wrapped_client.joint_names().iter())
        {
            assert_eq!(c, w);
        }

        let trajectory = vec![
            TrajectoryPoint {
                positions: vec![1.0, 2.0],
                velocities: Some(vec![3.0, 4.0]),
                time_from_start: std::time::Duration::from_secs_f64(4.0),
            },
            TrajectoryPoint {
                positions: vec![3.0, 6.0],
                velocities: Some(vec![3.0, 4.0]),
                time_from_start: std::time::Duration::from_secs_f64(8.0),
            },
        ];
        assert!(
            tokio_test::block_on(client.send_joint_trajectory(trajectory.clone()).unwrap()).is_ok()
        );
        // The trajectory recorded by the dummy client must match what was sent.
        for (c, w) in trajectory.iter().zip(
            wrapped_client
                .last_trajectory
                .lock()
                .unwrap()
                .clone()
                .iter(),
        ) {
            assert_eq!(c.positions, w.positions);
            assert_eq!(c.velocities, w.velocities);
            assert_eq!(c.time_from_start, w.time_from_start);
        }
        // NOTE(review): relies on the dummy client reporting the last trajectory
        // point as its current positions -- confirm against DummyJointTrajectoryClient.
        assert_eq!(
            trajectory.last().unwrap().positions,
            client.current_joint_positions().unwrap()
        );
    }
    // Each joint moves by exactly its limit, so no trajectory is generated and
    // the dummy client's recorded trajectory stays empty.
    #[test]
    fn joint_position_difference_limiter_send_joint_positions_no_interpolation() {
        let wrapped_client = Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]));
        let client = JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, 1.0]);
        assert!(client.is_ok());
        let client = client.unwrap();
        assert_eq!(
            client.joint_names().len(),
            wrapped_client.joint_names().len()
        );
        for (c, w) in client
            .joint_names()
            .iter()
            .zip(wrapped_client.joint_names().iter())
        {
            assert_eq!(c, w);
        }
        // Set the starting positions the limiter will read as "current".
        *wrapped_client.positions.lock().unwrap() = vec![0.0, 1.0];
        assert!(tokio_test::block_on(
            client
                .send_joint_positions(vec![-1.0, 2.0], Duration::from_secs(1))
                .unwrap()
        )
        .is_ok());
        assert!(wrapped_client.last_trajectory.lock().unwrap().is_empty());
        assert_eq!(
            wrapped_client.current_joint_positions().unwrap(),
            vec![-1.0, 2.0]
        );
    }
    // The second limit is given as -0.5 (stored as 0.5 by `new`), so the move is
    // split into two trajectory points that reach the wrapped client.
    #[test]
    fn joint_position_difference_limiter_send_joint_positions_interpolated() {
        let wrapped_client = Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]));
        let client = JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, -0.5]);
        assert!(client.is_ok());
        let client = client.unwrap();
        assert_eq!(
            client.joint_names().len(),
            wrapped_client.joint_names().len()
        );
        for (c, w) in client
            .joint_names()
            .iter()
            .zip(wrapped_client.joint_names().iter())
        {
            assert_eq!(c, w);
        }
        // Set the starting positions the limiter will read as "current".
        *wrapped_client.positions.lock().unwrap() = vec![0.0, 1.0];
        assert!(tokio_test::block_on(
            client
                .send_joint_positions(vec![-1.0, 2.0], Duration::from_secs(1))
                .unwrap()
        )
        .is_ok());
        let actual_trajectory = wrapped_client.last_trajectory.lock().unwrap().clone();
        assert_eq!(actual_trajectory.len(), 2);
        assert_eq!(actual_trajectory[0].positions, vec![-0.5, 1.5]);
        assert_eq!(actual_trajectory[1].positions, vec![-1.0, 2.0]);
        assert!(actual_trajectory[0].velocities.is_none());
        assert!(actual_trajectory[1].velocities.is_none());
        assert_approx_eq!(actual_trajectory[0].time_from_start.as_secs_f64(), 0.5);
        assert_approx_eq!(actual_trajectory[1].time_from_start.as_secs_f64(), 1.0);

        assert_eq!(
            wrapped_client.current_joint_positions().unwrap(),
            vec![-1.0, 2.0]
        );
    }

    // Covers the decision table of `should_interpolate_joint_trajectory`.
    #[test]
    fn test_should_interpolate_joint_trajectory() {
        // Empty trajectory: never interpolate.
        assert!(!should_interpolate_joint_trajectory(&[]));
        // Velocities on a point before the last one: do not interpolate.
        assert!(!should_interpolate_joint_trajectory(&[
            TrajectoryPoint {
                positions: vec![],
                velocities: Some(vec![0.0, 0.0]),
                time_from_start: std::time::Duration::from_secs(0),
            },
            TrajectoryPoint {
                positions: vec![],
                velocities: Some(vec![0.0, 0.0]),
                time_from_start: std::time::Duration::from_secs(0),
            }
        ]));
        // Only the last point has velocities, but one is above the threshold.
        assert!(!should_interpolate_joint_trajectory(&[
            TrajectoryPoint {
                positions: vec![],
                velocities: None,
                time_from_start: std::time::Duration::from_secs(0),
            },
            TrajectoryPoint {
                positions: vec![],
                velocities: Some(vec![0.0, 0.01]),
                time_from_start: std::time::Duration::from_secs(0),
            }
        ]));
        // Only the last point has velocities and all of them are zero.
        assert!(should_interpolate_joint_trajectory(&[
            TrajectoryPoint {
                positions: vec![],
                velocities: None,
                time_from_start: std::time::Duration::from_secs(0),
            },
            TrajectoryPoint {
                positions: vec![],
                velocities: Some(vec![0.0, 0.0]),
                time_from_start: std::time::Duration::from_secs(0),
            }
        ]));
    }

    // Despite its name, this exercises `interpolate_joint_trajectory`: with every
    // segment within the limits, the points come back unchanged.
    #[test]
    fn test_should_interpolate_joint_trajectory_no_interpolation() {
        let interpolated = interpolate_joint_trajectory(
            vec![0.0, 1.0],
            &[1.0, 1.0],
            vec![
                TrajectoryPoint {
                    positions: vec![-1.0, 2.0],
                    velocities: None,
                    time_from_start: std::time::Duration::from_secs(1),
                },
                TrajectoryPoint {
                    positions: vec![-2.0, 3.0],
                    velocities: Some(vec![0.0, 0.0]),
                    time_from_start: std::time::Duration::from_secs(2),
                },
            ],
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();

        assert_eq!(interpolated.len(), 2);

        assert_eq!(interpolated[0].positions, vec![-1.0, 2.0]);
        assert!(interpolated[0].velocities.is_none());
        assert_approx_eq!(interpolated[0].time_from_start.as_secs_f64(), 1.0);

        assert_eq!(interpolated[1].positions, vec![-2.0, 3.0]);
        assert!(interpolated[1].velocities.is_some());
        assert_approx_eq!(interpolated[1].time_from_start.as_secs_f64(), 2.0);
    }

    // Despite its name, this exercises `interpolate_joint_trajectory`: each of the
    // two segments is split in half, and the final point keeps its velocities.
    #[test]
    fn test_should_interpolate_joint_trajectory_interpolated() {
        let interpolated = interpolate_joint_trajectory(
            vec![0.0, 1.0],
            &[1.0, 0.5],
            vec![
                TrajectoryPoint {
                    positions: vec![-1.0, 2.0],
                    velocities: None,
                    time_from_start: std::time::Duration::from_secs(1),
                },
                TrajectoryPoint {
                    positions: vec![-2.0, 3.0],
                    velocities: Some(vec![0.0, 0.0]),
                    time_from_start: std::time::Duration::from_secs(2),
                },
            ],
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();

        assert_eq!(interpolated.len(), 4);

        assert_eq!(interpolated[0].positions, vec![-0.5, 1.5]);
        assert!(interpolated[0].velocities.is_none());
        assert_approx_eq!(interpolated[0].time_from_start.as_secs_f64(), 0.5);

        assert_eq!(interpolated[1].positions, vec![-1.0, 2.0]);
        assert!(interpolated[1].velocities.is_none());
        assert_approx_eq!(interpolated[1].time_from_start.as_secs_f64(), 1.0);

        assert_eq!(interpolated[2].positions, vec![-1.5, 2.5]);
        assert!(interpolated[2].velocities.is_none());
        assert_approx_eq!(interpolated[2].time_from_start.as_secs_f64(), 1.5);

        assert_eq!(interpolated[3].positions, vec![-2.0, 3.0]);
        assert!(interpolated[3].velocities.is_some());
        assert_approx_eq!(interpolated[3].time_from_start.as_secs_f64(), 2.0);
    }
}