1use std::time::Duration;
2
3use crate::{TrajectoryPoint, WaitFuture, error::Error, traits::JointTrajectoryClient};
4
/// Velocities whose absolute value does not exceed this threshold are treated as zero when
/// deciding whether the final point of a trajectory is at rest.
const ZERO_VELOCITY_THRESHOLD: f64 = 1e-6;
6
7#[derive(Debug)]
14pub struct JointPositionDifferenceLimiter<C>
15where
16 C: JointTrajectoryClient,
17{
18 client: C,
19 position_difference_limits: Vec<f64>,
20}
21
22impl<C> JointPositionDifferenceLimiter<C>
23where
24 C: JointTrajectoryClient,
25{
26 pub fn new(client: C, mut position_difference_limits: Vec<f64>) -> Result<Self, Error> {
28 if client.joint_names().len() == position_difference_limits.len() {
29 let mut is_valid = true;
30 position_difference_limits.iter_mut().for_each(|f| {
31 let f_abs = f.abs();
32 if f_abs < f64::MIN_POSITIVE {
33 is_valid = false;
34 }
35 *f = f_abs
36 });
37 if !is_valid {
38 return Err(Error::Other(anyhow::format_err!(
39 "Too small position difference limit",
40 )));
41 }
42 Ok(Self {
43 client,
44 position_difference_limits,
45 })
46 } else {
47 Err(Error::LengthMismatch {
48 model: client.joint_names().len(),
49 input: position_difference_limits.len(),
50 })
51 }
52 }
53}
54
55impl<C> JointTrajectoryClient for JointPositionDifferenceLimiter<C>
56where
57 C: JointTrajectoryClient,
58{
59 fn joint_names(&self) -> Vec<String> {
60 self.client.joint_names()
61 }
62
63 fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
64 self.client.current_joint_positions()
65 }
66
67 fn send_joint_positions(
68 &self,
69 positions: Vec<f64>,
70 duration: Duration,
71 ) -> Result<WaitFuture, Error> {
72 let current = self.client.current_joint_positions()?;
73 if current.len() != positions.len() {
74 return Err(Error::LengthMismatch {
75 model: positions.len(),
76 input: current.len(),
77 });
78 }
79 match interpolate(
80 current,
81 &self.position_difference_limits,
82 &positions,
83 &Duration::from_secs(0),
84 &duration,
85 )? {
86 Some(trajectory) => self.client.send_joint_trajectory(trajectory),
87 None => self.client.send_joint_positions(positions, duration),
88 }
89 }
90
91 fn send_joint_trajectory(&self, trajectory: Vec<TrajectoryPoint>) -> Result<WaitFuture, Error> {
94 let fixed_trajectory = if should_interpolate_joint_trajectory(&trajectory) {
95 interpolate_joint_trajectory(
96 self.client.current_joint_positions()?,
97 &self.position_difference_limits,
98 trajectory,
99 )?
100 } else {
101 trajectory
102 };
103 self.client.send_joint_trajectory(fixed_trajectory)
104 }
105}
106
107fn interpolate(
108 mut current: Vec<f64>,
109 position_difference_limits: &[f64],
110 positions: &[f64],
111 first_time_from_start: &Duration,
112 last_time_from_start: &Duration,
113) -> Result<Option<Vec<TrajectoryPoint>>, Error> {
114 let mut max_diff_step: f64 = 0.0;
115 let mut diff = vec![0.0; current.len()];
116 for (i, p) in current.iter().enumerate() {
117 diff[i] = positions[i] - p;
118 let step = diff[i].abs() / position_difference_limits[i].abs();
119 if step.is_infinite() {
120 return Err(Error::Other(anyhow::format_err!(
121 "Invalid position difference limits {} for joint {i} ",
122 position_difference_limits[i],
123 )));
124 }
125 max_diff_step = max_diff_step.max(step);
126 }
127 let max_diff_step = max_diff_step.ceil() as usize;
128 Ok(if max_diff_step <= 1 {
129 None
130 } else {
131 diff.iter_mut().for_each(|d| *d /= max_diff_step as f64);
132 let step_duration = Duration::from_secs_f64(
133 (last_time_from_start.as_secs_f64() - first_time_from_start.as_secs_f64())
134 / max_diff_step as f64,
135 );
136 let mut trajectory = vec![];
137 for i in 1..max_diff_step {
138 current
139 .iter_mut()
140 .enumerate()
141 .for_each(|(i, c)| *c += diff[i]);
142 trajectory.push(TrajectoryPoint {
143 positions: current.to_owned(),
144 velocities: None,
145 time_from_start: *first_time_from_start + step_duration * i as u32,
146 })
147 }
148 trajectory.push(TrajectoryPoint {
149 positions: positions.to_vec(),
150 velocities: None,
151 time_from_start: *last_time_from_start,
152 });
153 Some(trajectory)
154 })
155}
156
157fn should_interpolate_joint_trajectory(trajectory: &[TrajectoryPoint]) -> bool {
158 if trajectory.is_empty() {
159 return false;
160 };
161 match trajectory.iter().position(|p| p.velocities.is_some()) {
162 Some(first_index_of_valid_velocity) => {
163 let last_index = trajectory.len() - 1;
164 if first_index_of_valid_velocity != last_index {
165 false
166 } else {
167 !trajectory[last_index]
168 .velocities
169 .as_ref()
170 .unwrap()
171 .iter()
172 .any(|x| x.abs() > ZERO_VELOCITY_THRESHOLD)
173 }
174 }
175 None => true,
176 }
177}
178
179fn interpolate_joint_trajectory(
180 current: Vec<f64>,
181 position_difference_limits: &[f64],
182 trajectory: Vec<TrajectoryPoint>,
183) -> Result<Vec<TrajectoryPoint>, Error> {
184 let mut fixed_trajectory = vec![];
185 let mut previous_joint_positions = current;
186 let mut previous_time_from_start = Duration::from_secs(0);
187
188 for p in trajectory {
189 let target = p.positions.clone();
190 let velocity = p.velocities.clone();
191 let time_from_start = p.time_from_start;
192 fixed_trajectory.extend(
193 match interpolate(
194 previous_joint_positions,
195 position_difference_limits,
196 &target,
197 &previous_time_from_start,
198 &time_from_start,
199 )? {
200 Some(mut interpolated) => {
201 interpolated.last_mut().unwrap().velocities = velocity;
202 interpolated
203 }
204 None => vec![p],
205 },
206 );
207
208 previous_joint_positions = target;
209 previous_time_from_start = time_from_start;
210 }
211 Ok(fixed_trajectory)
212}
213
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use assert_approx_eq::assert_approx_eq;

    use super::{
        JointPositionDifferenceLimiter, JointTrajectoryClient, TrajectoryPoint, interpolate,
        interpolate_joint_trajectory, should_interpolate_joint_trajectory,
    };
    use crate::DummyJointTrajectoryClient;

    /// Builds a shared dummy client with the two joints `a` and `b`.
    fn two_joint_client() -> Arc<DummyJointTrajectoryClient> {
        Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]))
    }

    /// Builds a trajectory point whose timestamp is a whole number of seconds.
    fn point(positions: Vec<f64>, velocities: Option<Vec<f64>>, secs: u64) -> TrajectoryPoint {
        TrajectoryPoint {
            positions,
            velocities,
            time_from_start: Duration::from_secs(secs),
        }
    }

    /// Checks that two joint name lists have the same length and the same entries.
    fn assert_same_joint_names(actual: Vec<String>, expected: Vec<String>) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert_eq!(a, e);
        }
    }

    /// Checks positions, presence of velocities and (approximately) the timestamp of a point.
    fn assert_point(p: &TrajectoryPoint, positions: &[f64], has_velocities: bool, secs: f64) {
        assert_eq!(p.positions.as_slice(), positions);
        assert_eq!(p.velocities.is_some(), has_velocities);
        assert_approx_eq!(p.time_from_start.as_secs_f64(), secs);
    }

    #[test]
    fn interpolate_no_interpolation() {
        let start = Duration::from_secs(0);
        let end = Duration::from_secs(1);

        let within_limits = interpolate(vec![0.0, 1.0], &[1.0, 1.0], &[-1.0, 2.0], &start, &end);
        assert!(within_limits.is_ok());
        assert!(within_limits.unwrap().is_none());

        let zero_limits = interpolate(vec![0.0, 1.0], &[0.0, 0.0], &[-1.0, 2.0], &start, &end);
        assert!(zero_limits.is_err());
    }

    #[test]
    fn interpolate_interpolated() {
        let result = interpolate(
            vec![0.0, 1.0],
            &[1.0, 0.5],
            &[-1.0, 2.0],
            &Duration::from_secs(0),
            &Duration::from_secs(1),
        );
        assert!(result.is_ok());
        let maybe_points = result.unwrap();
        assert!(maybe_points.is_some());
        let points = maybe_points.unwrap();

        assert_eq!(points.len(), 2);
        assert_point(&points[0], &[-0.5, 1.5], false, 0.5);
        assert_point(&points[1], &[-1.0, 2.0], false, 1.0);
    }

    #[test]
    fn joint_position_difference_limiter_new_error() {
        let wrapped_client = two_joint_client();
        // Wrong number of limits.
        assert!(
            JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![3.0, 1.0, 2.0])
                .is_err()
        );
        // A zero limit is too small.
        assert!(JointPositionDifferenceLimiter::new(wrapped_client, vec![1.0, 0.0]).is_err());
    }

    #[test]
    fn joint_position_difference_limiter_send_joint_trajectory() {
        let wrapped_client = two_joint_client();
        let client = JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, 2.0]);
        assert!(client.is_ok());
        let client = client.unwrap();
        assert_same_joint_names(client.joint_names(), wrapped_client.joint_names());

        let trajectory = vec![
            point(vec![1.0, 2.0], Some(vec![3.0, 4.0]), 4),
            point(vec![3.0, 6.0], Some(vec![3.0, 4.0]), 8),
        ];
        let wait = client.send_joint_trajectory(trajectory.clone()).unwrap();
        assert!(tokio_test::block_on(wait).is_ok());

        let sent = wrapped_client.last_trajectory.lock().unwrap().clone();
        for (expected, actual) in trajectory.iter().zip(sent.iter()) {
            assert_eq!(expected.positions, actual.positions);
            assert_eq!(expected.velocities, actual.velocities);
            assert_eq!(expected.time_from_start, actual.time_from_start);
        }
        assert_eq!(
            trajectory.last().unwrap().positions,
            client.current_joint_positions().unwrap()
        );
    }

    #[test]
    fn joint_position_difference_limiter_send_joint_positions_no_interpolation() {
        let wrapped_client = two_joint_client();
        let client = JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, 1.0]);
        assert!(client.is_ok());
        let client = client.unwrap();
        assert_same_joint_names(client.joint_names(), wrapped_client.joint_names());

        *wrapped_client.positions.lock().unwrap() = vec![0.0, 1.0];
        let wait = client
            .send_joint_positions(vec![-1.0, 2.0], Duration::from_secs(1))
            .unwrap();
        assert!(tokio_test::block_on(wait).is_ok());

        assert!(wrapped_client.last_trajectory.lock().unwrap().is_empty());
        assert_eq!(
            wrapped_client.current_joint_positions().unwrap(),
            vec![-1.0, 2.0]
        );
    }

    #[test]
    fn joint_position_difference_limiter_send_joint_positions_interpolated() {
        let wrapped_client = two_joint_client();
        let client = JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, -0.5]);
        assert!(client.is_ok());
        let client = client.unwrap();
        assert_same_joint_names(client.joint_names(), wrapped_client.joint_names());

        *wrapped_client.positions.lock().unwrap() = vec![0.0, 1.0];
        let wait = client
            .send_joint_positions(vec![-1.0, 2.0], Duration::from_secs(1))
            .unwrap();
        assert!(tokio_test::block_on(wait).is_ok());

        let actual_trajectory = wrapped_client.last_trajectory.lock().unwrap().clone();
        assert_eq!(actual_trajectory.len(), 2);
        assert_point(&actual_trajectory[0], &[-0.5, 1.5], false, 0.5);
        assert_point(&actual_trajectory[1], &[-1.0, 2.0], false, 1.0);

        assert_eq!(
            wrapped_client.current_joint_positions().unwrap(),
            vec![-1.0, 2.0]
        );
    }

    #[test]
    fn test_should_interpolate_joint_trajectory() {
        // Empty trajectory.
        assert!(!should_interpolate_joint_trajectory(&[]));
        // Velocities on a point other than the last.
        assert!(!should_interpolate_joint_trajectory(&[
            point(vec![], Some(vec![0.0, 0.0]), 0),
            point(vec![], Some(vec![0.0, 0.0]), 0),
        ]));
        // Non-zero velocity on the last point.
        assert!(!should_interpolate_joint_trajectory(&[
            point(vec![], None, 0),
            point(vec![], Some(vec![0.0, 0.01]), 0),
        ]));
        // Only the last point has velocities and they are all zero.
        assert!(should_interpolate_joint_trajectory(&[
            point(vec![], None, 0),
            point(vec![], Some(vec![0.0, 0.0]), 0),
        ]));
    }

    #[test]
    fn test_should_interpolate_joint_trajectory_no_interpolation() {
        let interpolated = interpolate_joint_trajectory(
            vec![0.0, 1.0],
            &[1.0, 1.0],
            vec![
                point(vec![-1.0, 2.0], None, 1),
                point(vec![-2.0, 3.0], Some(vec![0.0, 0.0]), 2),
            ],
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();

        assert_eq!(interpolated.len(), 2);
        assert_point(&interpolated[0], &[-1.0, 2.0], false, 1.0);
        assert_point(&interpolated[1], &[-2.0, 3.0], true, 2.0);
    }

    #[test]
    fn test_should_interpolate_joint_trajectory_interpolated() {
        let interpolated = interpolate_joint_trajectory(
            vec![0.0, 1.0],
            &[1.0, 0.5],
            vec![
                point(vec![-1.0, 2.0], None, 1),
                point(vec![-2.0, 3.0], Some(vec![0.0, 0.0]), 2),
            ],
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();

        assert_eq!(interpolated.len(), 4);
        assert_point(&interpolated[0], &[-0.5, 1.5], false, 0.5);
        assert_point(&interpolated[1], &[-1.0, 2.0], false, 1.0);
        assert_point(&interpolated[2], &[-1.5, 2.5], false, 1.5);
        assert_point(&interpolated[3], &[-2.0, 3.0], true, 2.0);
    }
}