arci/clients/
joint_position_difference_limiter.rs

1use std::time::Duration;
2
3use crate::{TrajectoryPoint, WaitFuture, error::Error, traits::JointTrajectoryClient};
4
/// Velocity components whose absolute value does not exceed this are treated as zero when
/// deciding whether a trajectory ending at rest may be interpolated.
const ZERO_VELOCITY_THRESHOLD: f64 = 1e-6;
6
7/// JointPositionDifferenceLimiter limits the difference of position between trajectory points and
8///  trajectory points are interpolated linearly to satisfy the limits
9///  in JointTrajectoryClient::send_joint_positions.
10/// In send_joint_trajectory, if no velocities is specified or zero velocities is specified at the
11///  last point, trajectory points are interpolated, otherwise simply input trajectory is forwarded to client.
12
13#[derive(Debug)]
14pub struct JointPositionDifferenceLimiter<C>
15where
16    C: JointTrajectoryClient,
17{
18    client: C,
19    position_difference_limits: Vec<f64>,
20}
21
22impl<C> JointPositionDifferenceLimiter<C>
23where
24    C: JointTrajectoryClient,
25{
26    /// Create a new `JointPositionDifferenceLimiter` with the given position difference limits.
27    pub fn new(client: C, mut position_difference_limits: Vec<f64>) -> Result<Self, Error> {
28        if client.joint_names().len() == position_difference_limits.len() {
29            let mut is_valid = true;
30            position_difference_limits.iter_mut().for_each(|f| {
31                let f_abs = f.abs();
32                if f_abs < f64::MIN_POSITIVE {
33                    is_valid = false;
34                }
35                *f = f_abs
36            });
37            if !is_valid {
38                return Err(Error::Other(anyhow::format_err!(
39                    "Too small position difference limit",
40                )));
41            }
42            Ok(Self {
43                client,
44                position_difference_limits,
45            })
46        } else {
47            Err(Error::LengthMismatch {
48                model: client.joint_names().len(),
49                input: position_difference_limits.len(),
50            })
51        }
52    }
53}
54
55impl<C> JointTrajectoryClient for JointPositionDifferenceLimiter<C>
56where
57    C: JointTrajectoryClient,
58{
59    fn joint_names(&self) -> Vec<String> {
60        self.client.joint_names()
61    }
62
63    fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
64        self.client.current_joint_positions()
65    }
66
67    fn send_joint_positions(
68        &self,
69        positions: Vec<f64>,
70        duration: Duration,
71    ) -> Result<WaitFuture, Error> {
72        let current = self.client.current_joint_positions()?;
73        if current.len() != positions.len() {
74            return Err(Error::LengthMismatch {
75                model: positions.len(),
76                input: current.len(),
77            });
78        }
79        match interpolate(
80            current,
81            &self.position_difference_limits,
82            &positions,
83            &Duration::from_secs(0),
84            &duration,
85        )? {
86            Some(trajectory) => self.client.send_joint_trajectory(trajectory),
87            None => self.client.send_joint_positions(positions, duration),
88        }
89    }
90
91    /// If no velocities is specified or zero velocities is specified at the last point,
92    /// trajectory points are interpolated, otherwise simply input trajectory is forwarded to client.
93    fn send_joint_trajectory(&self, trajectory: Vec<TrajectoryPoint>) -> Result<WaitFuture, Error> {
94        let fixed_trajectory = if should_interpolate_joint_trajectory(&trajectory) {
95            interpolate_joint_trajectory(
96                self.client.current_joint_positions()?,
97                &self.position_difference_limits,
98                trajectory,
99            )?
100        } else {
101            trajectory
102        };
103        self.client.send_joint_trajectory(fixed_trajectory)
104    }
105}
106
107fn interpolate(
108    mut current: Vec<f64>,
109    position_difference_limits: &[f64],
110    positions: &[f64],
111    first_time_from_start: &Duration,
112    last_time_from_start: &Duration,
113) -> Result<Option<Vec<TrajectoryPoint>>, Error> {
114    let mut max_diff_step: f64 = 0.0;
115    let mut diff = vec![0.0; current.len()];
116    for (i, p) in current.iter().enumerate() {
117        diff[i] = positions[i] - p;
118        let step = diff[i].abs() / position_difference_limits[i].abs();
119        if step.is_infinite() {
120            return Err(Error::Other(anyhow::format_err!(
121                "Invalid position difference limits {} for joint {i} ",
122                position_difference_limits[i],
123            )));
124        }
125        max_diff_step = max_diff_step.max(step);
126    }
127    let max_diff_step = max_diff_step.ceil() as usize;
128    Ok(if max_diff_step <= 1 {
129        None
130    } else {
131        diff.iter_mut().for_each(|d| *d /= max_diff_step as f64);
132        let step_duration = Duration::from_secs_f64(
133            (last_time_from_start.as_secs_f64() - first_time_from_start.as_secs_f64())
134                / max_diff_step as f64,
135        );
136        let mut trajectory = vec![];
137        for i in 1..max_diff_step {
138            current
139                .iter_mut()
140                .enumerate()
141                .for_each(|(i, c)| *c += diff[i]);
142            trajectory.push(TrajectoryPoint {
143                positions: current.to_owned(),
144                velocities: None,
145                time_from_start: *first_time_from_start + step_duration * i as u32,
146            })
147        }
148        trajectory.push(TrajectoryPoint {
149            positions: positions.to_vec(),
150            velocities: None,
151            time_from_start: *last_time_from_start,
152        });
153        Some(trajectory)
154    })
155}
156
157fn should_interpolate_joint_trajectory(trajectory: &[TrajectoryPoint]) -> bool {
158    if trajectory.is_empty() {
159        return false;
160    };
161    match trajectory.iter().position(|p| p.velocities.is_some()) {
162        Some(first_index_of_valid_velocity) => {
163            let last_index = trajectory.len() - 1;
164            if first_index_of_valid_velocity != last_index {
165                false
166            } else {
167                !trajectory[last_index]
168                    .velocities
169                    .as_ref()
170                    .unwrap()
171                    .iter()
172                    .any(|x| x.abs() > ZERO_VELOCITY_THRESHOLD)
173            }
174        }
175        None => true,
176    }
177}
178
179fn interpolate_joint_trajectory(
180    current: Vec<f64>,
181    position_difference_limits: &[f64],
182    trajectory: Vec<TrajectoryPoint>,
183) -> Result<Vec<TrajectoryPoint>, Error> {
184    let mut fixed_trajectory = vec![];
185    let mut previous_joint_positions = current;
186    let mut previous_time_from_start = Duration::from_secs(0);
187
188    for p in trajectory {
189        let target = p.positions.clone();
190        let velocity = p.velocities.clone();
191        let time_from_start = p.time_from_start;
192        fixed_trajectory.extend(
193            match interpolate(
194                previous_joint_positions,
195                position_difference_limits,
196                &target,
197                &previous_time_from_start,
198                &time_from_start,
199            )? {
200                Some(mut interpolated) => {
201                    interpolated.last_mut().unwrap().velocities = velocity;
202                    interpolated
203                }
204                None => vec![p],
205            },
206        );
207
208        previous_joint_positions = target;
209        previous_time_from_start = time_from_start;
210    }
211    Ok(fixed_trajectory)
212}
213
#[cfg(test)]
mod test {
    // Unit tests for `JointPositionDifferenceLimiter` and its interpolation helpers. The
    // underlying client is a `DummyJointTrajectoryClient` wrapped in `Arc`.
    use std::{sync::Arc, time::Duration};

    use assert_approx_eq::assert_approx_eq;

    use super::{
        JointPositionDifferenceLimiter, JointTrajectoryClient, TrajectoryPoint, interpolate,
        interpolate_joint_trajectory, should_interpolate_joint_trajectory,
    };
    use crate::DummyJointTrajectoryClient;

    /// Moves of 1.0 with limits of 1.0 fit in one step, so `interpolate` returns `None`.
    /// Zero limits make the step count infinite, which is reported as an error.
    #[test]
    fn interpolate_no_interpolation() {
        let interpolated = interpolate(
            vec![0.0, 1.0],
            &[1.0, 1.0],
            &[-1.0, 2.0],
            &Duration::from_secs(0),
            &Duration::from_secs(1),
        );
        assert!(interpolated.is_ok());
        assert!(interpolated.unwrap().is_none());
        assert!(
            interpolate(
                vec![0.0, 1.0],
                &[0.0, 0.0],
                &[-1.0, 2.0],
                &Duration::from_secs(0),
                &Duration::from_secs(1),
            )
            .is_err()
        );
    }
    /// Joint b moves 1.0 under a 0.5 limit, so the move is split into two evenly timed points.
    #[test]
    fn interpolate_interpolated() {
        let interpolated = interpolate(
            vec![0.0, 1.0],
            &[1.0, 0.5],
            &[-1.0, 2.0],
            &Duration::from_secs(0),
            &Duration::from_secs(1),
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();
        assert!(interpolated.is_some());
        let interpolated = interpolated.unwrap();
        assert_eq!(interpolated.len(), 2);
        assert_eq!(interpolated[0].positions, vec![-0.5, 1.5]);
        assert!(interpolated[0].velocities.is_none());
        assert_approx_eq!(interpolated[0].time_from_start.as_secs_f64(), 0.5);

        assert_eq!(interpolated[1].positions, vec![-1.0, 2.0]);
        assert!(interpolated[1].velocities.is_none());
        assert_approx_eq!(interpolated[1].time_from_start.as_secs_f64(), 1.0);
    }

    /// `new` rejects a limit count that differs from the joint count, and rejects a zero limit.
    #[test]
    fn joint_position_difference_limiter_new_error() {
        let wrapped_client = Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]));
        assert!(
            JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![3.0, 1.0, 2.0])
                .is_err()
        );
        assert!(JointPositionDifferenceLimiter::new(wrapped_client, vec![1.0, 0.0]).is_err());
    }
    /// Every point carries velocities, so the trajectory is not interpolated and should reach
    /// the wrapped client unchanged.
    #[test]
    fn joint_position_difference_limiter_send_joint_trajectory() {
        let wrapped_client = Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]));
        let client = JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, 2.0]);
        assert!(client.is_ok());
        let client = client.unwrap();
        assert_eq!(
            client.joint_names().len(),
            wrapped_client.joint_names().len()
        );
        for (c, w) in client
            .joint_names()
            .iter()
            .zip(wrapped_client.joint_names().iter())
        {
            assert_eq!(c, w);
        }

        let trajectory = vec![
            TrajectoryPoint {
                positions: vec![1.0, 2.0],
                velocities: Some(vec![3.0, 4.0]),
                time_from_start: std::time::Duration::from_secs_f64(4.0),
            },
            TrajectoryPoint {
                positions: vec![3.0, 6.0],
                velocities: Some(vec![3.0, 4.0]),
                time_from_start: std::time::Duration::from_secs_f64(8.0),
            },
        ];
        assert!(
            tokio_test::block_on(client.send_joint_trajectory(trajectory.clone()).unwrap()).is_ok()
        );
        for (c, w) in trajectory.iter().zip(
            wrapped_client
                .last_trajectory
                .lock()
                .unwrap()
                .clone()
                .iter(),
        ) {
            assert_eq!(c.positions, w.positions);
            assert_eq!(c.velocities, w.velocities);
            assert_eq!(c.time_from_start, w.time_from_start);
        }
        // Assumes the dummy client reports the last sent point as its current position.
        assert_eq!(
            trajectory.last().unwrap().positions,
            client.current_joint_positions().unwrap()
        );
    }
    /// With limits [1.0, 1.0], the move from [0.0, 1.0] to [-1.0, 2.0] fits in one step, so it
    /// is forwarded through `send_joint_positions` and no trajectory is recorded.
    #[test]
    fn joint_position_difference_limiter_send_joint_positions_no_interpolation() {
        let wrapped_client = Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]));
        let client = JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, 1.0]);
        assert!(client.is_ok());
        let client = client.unwrap();
        assert_eq!(
            client.joint_names().len(),
            wrapped_client.joint_names().len()
        );
        for (c, w) in client
            .joint_names()
            .iter()
            .zip(wrapped_client.joint_names().iter())
        {
            assert_eq!(c, w);
        }
        *wrapped_client.positions.lock().unwrap() = vec![0.0, 1.0];
        assert!(
            tokio_test::block_on(
                client
                    .send_joint_positions(vec![-1.0, 2.0], Duration::from_secs(1))
                    .unwrap()
            )
            .is_ok()
        );
        assert!(wrapped_client.last_trajectory.lock().unwrap().is_empty());
        assert_eq!(
            wrapped_client.current_joint_positions().unwrap(),
            vec![-1.0, 2.0]
        );
    }
    /// The -0.5 limit is stored as 0.5, so joint b's 1.0 move is split into a two-point
    /// trajectory sent through `send_joint_trajectory`.
    #[test]
    fn joint_position_difference_limiter_send_joint_positions_interpolated() {
        let wrapped_client = Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]));
        let client = JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, -0.5]);
        assert!(client.is_ok());
        let client = client.unwrap();
        assert_eq!(
            client.joint_names().len(),
            wrapped_client.joint_names().len()
        );
        for (c, w) in client
            .joint_names()
            .iter()
            .zip(wrapped_client.joint_names().iter())
        {
            assert_eq!(c, w);
        }
        *wrapped_client.positions.lock().unwrap() = vec![0.0, 1.0];
        assert!(
            tokio_test::block_on(
                client
                    .send_joint_positions(vec![-1.0, 2.0], Duration::from_secs(1))
                    .unwrap()
            )
            .is_ok()
        );
        let actual_trajectory = wrapped_client.last_trajectory.lock().unwrap().clone();
        assert_eq!(actual_trajectory.len(), 2);
        assert_eq!(actual_trajectory[0].positions, vec![-0.5, 1.5]);
        assert_eq!(actual_trajectory[1].positions, vec![-1.0, 2.0]);
        assert!(actual_trajectory[0].velocities.is_none());
        assert!(actual_trajectory[1].velocities.is_none());
        assert_approx_eq!(actual_trajectory[0].time_from_start.as_secs_f64(), 0.5);
        assert_approx_eq!(actual_trajectory[1].time_from_start.as_secs_f64(), 1.0);

        assert_eq!(
            wrapped_client.current_joint_positions().unwrap(),
            vec![-1.0, 2.0]
        );
    }

    /// Cases covered: an empty trajectory; velocities on every point; a non-zero final velocity;
    /// velocities only at the last point, all zero (the only case that interpolates).
    #[test]
    fn test_should_interpolate_joint_trajectory() {
        assert!(!should_interpolate_joint_trajectory(&[]));
        assert!(!should_interpolate_joint_trajectory(&[
            TrajectoryPoint {
                positions: vec![],
                velocities: Some(vec![0.0, 0.0]),
                time_from_start: std::time::Duration::from_secs(0),
            },
            TrajectoryPoint {
                positions: vec![],
                velocities: Some(vec![0.0, 0.0]),
                time_from_start: std::time::Duration::from_secs(0),
            }
        ]));
        assert!(!should_interpolate_joint_trajectory(&[
            TrajectoryPoint {
                positions: vec![],
                velocities: None,
                time_from_start: std::time::Duration::from_secs(0),
            },
            TrajectoryPoint {
                positions: vec![],
                velocities: Some(vec![0.0, 0.01]),
                time_from_start: std::time::Duration::from_secs(0),
            }
        ]));
        assert!(should_interpolate_joint_trajectory(&[
            TrajectoryPoint {
                positions: vec![],
                velocities: None,
                time_from_start: std::time::Duration::from_secs(0),
            },
            TrajectoryPoint {
                positions: vec![],
                velocities: Some(vec![0.0, 0.0]),
                time_from_start: std::time::Duration::from_secs(0),
            }
        ]));
    }

    /// Exercises `interpolate_joint_trajectory`: each unit move fits within the 1.0 limits, so
    /// both points pass through unchanged, including the last point's velocities.
    #[test]
    fn test_should_interpolate_joint_trajectory_no_interpolation() {
        let interpolated = interpolate_joint_trajectory(
            vec![0.0, 1.0],
            &[1.0, 1.0],
            vec![
                TrajectoryPoint {
                    positions: vec![-1.0, 2.0],
                    velocities: None,
                    time_from_start: std::time::Duration::from_secs(1),
                },
                TrajectoryPoint {
                    positions: vec![-2.0, 3.0],
                    velocities: Some(vec![0.0, 0.0]),
                    time_from_start: std::time::Duration::from_secs(2),
                },
            ],
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();

        assert_eq!(interpolated.len(), 2);

        assert_eq!(interpolated[0].positions, vec![-1.0, 2.0]);
        assert!(interpolated[0].velocities.is_none());
        assert_approx_eq!(interpolated[0].time_from_start.as_secs_f64(), 1.0);

        assert_eq!(interpolated[1].positions, vec![-2.0, 3.0]);
        assert!(interpolated[1].velocities.is_some());
        assert_approx_eq!(interpolated[1].time_from_start.as_secs_f64(), 2.0);
    }

    /// Exercises `interpolate_joint_trajectory`: the 0.5 limit on joint b splits each segment in
    /// two. Velocities stay only on the final point of the last segment.
    #[test]
    fn test_should_interpolate_joint_trajectory_interpolated() {
        let interpolated = interpolate_joint_trajectory(
            vec![0.0, 1.0],
            &[1.0, 0.5],
            vec![
                TrajectoryPoint {
                    positions: vec![-1.0, 2.0],
                    velocities: None,
                    time_from_start: std::time::Duration::from_secs(1),
                },
                TrajectoryPoint {
                    positions: vec![-2.0, 3.0],
                    velocities: Some(vec![0.0, 0.0]),
                    time_from_start: std::time::Duration::from_secs(2),
                },
            ],
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();

        assert_eq!(interpolated.len(), 4);

        assert_eq!(interpolated[0].positions, vec![-0.5, 1.5]);
        assert!(interpolated[0].velocities.is_none());
        assert_approx_eq!(interpolated[0].time_from_start.as_secs_f64(), 0.5);

        assert_eq!(interpolated[1].positions, vec![-1.0, 2.0]);
        assert!(interpolated[1].velocities.is_none());
        assert_approx_eq!(interpolated[1].time_from_start.as_secs_f64(), 1.0);

        assert_eq!(interpolated[2].positions, vec![-1.5, 2.5]);
        assert!(interpolated[2].velocities.is_none());
        assert_approx_eq!(interpolated[2].time_from_start.as_secs_f64(), 1.5);

        assert_eq!(interpolated[3].positions, vec![-2.0, 3.0]);
        assert!(interpolated[3].velocities.is_some());
        assert_approx_eq!(interpolated[3].time_from_start.as_secs_f64(), 2.0);
    }
}