arci/clients/joint_position_difference_limiter.rs

1use std::time::Duration;
2
3use crate::{error::Error, traits::JointTrajectoryClient, TrajectoryPoint, WaitFuture};
4
/// Velocities whose absolute value is at most this are treated as zero when deciding whether a
/// trajectory that ends at rest may be interpolated.
const ZERO_VELOCITY_THRESHOLD: f64 = 1e-6;
6
7/// JointPositionDifferenceLimiter limits the difference of position between trajectory points and
8///  trajectory points are interpolated linearly to satisfy the limits
9///  in JointTrajectoryClient::send_joint_positions.
10/// In send_joint_trajectory, if no velocities is specified or zero velocities is specified at the
11///  last point, trajectory points are interpolated, otherwise simply input trajectory is forwarded to client.
12
13#[derive(Debug)]
14pub struct JointPositionDifferenceLimiter<C>
15where
16    C: JointTrajectoryClient,
17{
18    client: C,
19    position_difference_limits: Vec<f64>,
20}
21
22impl<C> JointPositionDifferenceLimiter<C>
23where
24    C: JointTrajectoryClient,
25{
26    /// Create a new `JointPositionDifferenceLimiter` with the given position difference limits.
27    pub fn new(client: C, mut position_difference_limits: Vec<f64>) -> Result<Self, Error> {
28        if client.joint_names().len() == position_difference_limits.len() {
29            let mut is_valid = true;
30            position_difference_limits.iter_mut().for_each(|f| {
31                let f_abs = f.abs();
32                if f_abs < f64::MIN_POSITIVE {
33                    is_valid = false;
34                }
35                *f = f_abs
36            });
37            if !is_valid {
38                return Err(Error::Other(anyhow::format_err!(
39                    "Too small position difference limit",
40                )));
41            }
42            Ok(Self {
43                client,
44                position_difference_limits,
45            })
46        } else {
47            Err(Error::LengthMismatch {
48                model: client.joint_names().len(),
49                input: position_difference_limits.len(),
50            })
51        }
52    }
53}
54
55impl<C> JointTrajectoryClient for JointPositionDifferenceLimiter<C>
56where
57    C: JointTrajectoryClient,
58{
59    fn joint_names(&self) -> Vec<String> {
60        self.client.joint_names()
61    }
62
63    fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
64        self.client.current_joint_positions()
65    }
66
67    fn send_joint_positions(
68        &self,
69        positions: Vec<f64>,
70        duration: Duration,
71    ) -> Result<WaitFuture, Error> {
72        let current = self.client.current_joint_positions()?;
73        if current.len() != positions.len() {
74            return Err(Error::LengthMismatch {
75                model: positions.len(),
76                input: current.len(),
77            });
78        }
79        match interpolate(
80            current,
81            &self.position_difference_limits,
82            &positions,
83            &Duration::from_secs(0),
84            &duration,
85        )? {
86            Some(trajectory) => self.client.send_joint_trajectory(trajectory),
87            None => self.client.send_joint_positions(positions, duration),
88        }
89    }
90
91    /// If no velocities is specified or zero velocities is specified at the last point,
92    /// trajectory points are interpolated, otherwise simply input trajectory is forwarded to client.
93    fn send_joint_trajectory(&self, trajectory: Vec<TrajectoryPoint>) -> Result<WaitFuture, Error> {
94        let fixed_trajectory = if should_interpolate_joint_trajectory(&trajectory) {
95            interpolate_joint_trajectory(
96                self.client.current_joint_positions()?,
97                &self.position_difference_limits,
98                trajectory,
99            )?
100        } else {
101            trajectory
102        };
103        self.client.send_joint_trajectory(fixed_trajectory)
104    }
105}
106
107fn interpolate(
108    mut current: Vec<f64>,
109    position_difference_limits: &[f64],
110    positions: &[f64],
111    first_time_from_start: &Duration,
112    last_time_from_start: &Duration,
113) -> Result<Option<Vec<TrajectoryPoint>>, Error> {
114    let mut max_diff_step: f64 = 0.0;
115    let mut diff = vec![0.0; current.len()];
116    for (i, p) in current.iter().enumerate() {
117        diff[i] = positions[i] - p;
118        let step = diff[i].abs() / position_difference_limits[i].abs();
119        if step.is_infinite() {
120            return Err(Error::Other(anyhow::format_err!(
121                "Invalid position difference limits {} for joint {i} ",
122                position_difference_limits[i],
123            )));
124        }
125        max_diff_step = max_diff_step.max(step);
126    }
127    let max_diff_step = max_diff_step.ceil() as usize;
128    Ok(if max_diff_step <= 1 {
129        None
130    } else {
131        diff.iter_mut().for_each(|d| *d /= max_diff_step as f64);
132        let step_duration = Duration::from_secs_f64(
133            (last_time_from_start.as_secs_f64() - first_time_from_start.as_secs_f64())
134                / max_diff_step as f64,
135        );
136        let mut trajectory = vec![];
137        for i in 1..max_diff_step {
138            current
139                .iter_mut()
140                .enumerate()
141                .for_each(|(i, c)| *c += diff[i]);
142            trajectory.push(TrajectoryPoint {
143                positions: current.to_owned(),
144                velocities: None,
145                time_from_start: *first_time_from_start + step_duration * i as u32,
146            })
147        }
148        trajectory.push(TrajectoryPoint {
149            positions: positions.to_vec(),
150            velocities: None,
151            time_from_start: *last_time_from_start,
152        });
153        Some(trajectory)
154    })
155}
156
157fn should_interpolate_joint_trajectory(trajectory: &[TrajectoryPoint]) -> bool {
158    if trajectory.is_empty() {
159        return false;
160    };
161    match trajectory.iter().position(|p| p.velocities.is_some()) {
162        Some(first_index_of_valid_velocity) => {
163            let last_index = trajectory.len() - 1;
164            if first_index_of_valid_velocity != last_index {
165                false
166            } else {
167                !trajectory[last_index]
168                    .velocities
169                    .as_ref()
170                    .unwrap()
171                    .iter()
172                    .any(|x| x.abs() > ZERO_VELOCITY_THRESHOLD)
173            }
174        }
175        None => true,
176    }
177}
178
179fn interpolate_joint_trajectory(
180    current: Vec<f64>,
181    position_difference_limits: &[f64],
182    trajectory: Vec<TrajectoryPoint>,
183) -> Result<Vec<TrajectoryPoint>, Error> {
184    let mut fixed_trajectory = vec![];
185    let mut previous_joint_positions = current;
186    let mut previous_time_from_start = Duration::from_secs(0);
187
188    for p in trajectory {
189        let target = p.positions.clone();
190        let velocity = p.velocities.clone();
191        let time_from_start = p.time_from_start;
192        fixed_trajectory.extend(
193            match interpolate(
194                previous_joint_positions,
195                position_difference_limits,
196                &target,
197                &previous_time_from_start,
198                &time_from_start,
199            )? {
200                Some(mut interpolated) => {
201                    interpolated.last_mut().unwrap().velocities = velocity;
202                    interpolated
203                }
204                None => vec![p],
205            },
206        );
207
208        previous_joint_positions = target;
209        previous_time_from_start = time_from_start;
210    }
211    Ok(fixed_trajectory)
212}
213
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use assert_approx_eq::assert_approx_eq;

    use super::{
        interpolate, interpolate_joint_trajectory, should_interpolate_joint_trajectory,
        JointPositionDifferenceLimiter, JointTrajectoryClient, TrajectoryPoint,
    };
    use crate::DummyJointTrajectoryClient;

    /// Builds a shared dummy client with the joints `a` and `b`. It is shared so that a test
    /// can inspect it after handing a clone to the limiter.
    fn two_joint_client() -> Arc<DummyJointTrajectoryClient> {
        Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]))
    }

    /// Builds a point with empty positions at t = 0, for tests that only look at velocities.
    fn velocity_only_point(velocities: Option<Vec<f64>>) -> TrajectoryPoint {
        TrajectoryPoint {
            positions: vec![],
            velocities,
            time_from_start: Duration::from_secs(0),
        }
    }

    #[test]
    fn interpolate_no_interpolation() {
        let interpolated = interpolate(
            vec![0.0, 1.0],
            &[1.0, 1.0],
            &[-1.0, 2.0],
            &Duration::from_secs(0),
            &Duration::from_secs(1),
        );
        assert!(interpolated.is_ok());
        assert!(interpolated.unwrap().is_none());
        // A zero limit for a joint that has to move would need infinitely many steps.
        assert!(interpolate(
            vec![0.0, 1.0],
            &[0.0, 0.0],
            &[-1.0, 2.0],
            &Duration::from_secs(0),
            &Duration::from_secs(1),
        )
        .is_err());
    }

    #[test]
    fn interpolate_interpolated() {
        let interpolated = interpolate(
            vec![0.0, 1.0],
            &[1.0, 0.5],
            &[-1.0, 2.0],
            &Duration::from_secs(0),
            &Duration::from_secs(1),
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();
        assert!(interpolated.is_some());
        let interpolated = interpolated.unwrap();
        assert_eq!(interpolated.len(), 2);
        assert_eq!(interpolated[0].positions, vec![-0.5, 1.5]);
        assert!(interpolated[0].velocities.is_none());
        assert_approx_eq!(interpolated[0].time_from_start.as_secs_f64(), 0.5);

        assert_eq!(interpolated[1].positions, vec![-1.0, 2.0]);
        assert!(interpolated[1].velocities.is_none());
        assert_approx_eq!(interpolated[1].time_from_start.as_secs_f64(), 1.0);
    }

    #[test]
    fn joint_position_difference_limiter_new_error() {
        let wrapped_client = two_joint_client();
        // Wrong number of limits.
        assert!(
            JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![3.0, 1.0, 2.0])
                .is_err()
        );
        // Zero limit.
        assert!(JointPositionDifferenceLimiter::new(wrapped_client, vec![1.0, 0.0]).is_err());
    }

    #[test]
    fn joint_position_difference_limiter_send_joint_trajectory() {
        let wrapped_client = two_joint_client();
        let client =
            JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, 2.0]).unwrap();
        assert_eq!(client.joint_names(), wrapped_client.joint_names());

        // Non-zero velocities on every point: the trajectory must be forwarded unchanged.
        let trajectory = vec![
            TrajectoryPoint {
                positions: vec![1.0, 2.0],
                velocities: Some(vec![3.0, 4.0]),
                time_from_start: Duration::from_secs_f64(4.0),
            },
            TrajectoryPoint {
                positions: vec![3.0, 6.0],
                velocities: Some(vec![3.0, 4.0]),
                time_from_start: Duration::from_secs_f64(8.0),
            },
        ];
        assert!(
            tokio_test::block_on(client.send_joint_trajectory(trajectory.clone()).unwrap()).is_ok()
        );
        let forwarded = wrapped_client.last_trajectory.lock().unwrap().clone();
        // `zip` below stops at the shorter side, so check the lengths first or a truncated
        // (or empty) forwarded trajectory would pass unnoticed.
        assert_eq!(forwarded.len(), trajectory.len());
        for (expected, actual) in trajectory.iter().zip(forwarded.iter()) {
            assert_eq!(expected.positions, actual.positions);
            assert_eq!(expected.velocities, actual.velocities);
            assert_eq!(expected.time_from_start, actual.time_from_start);
        }
        assert_eq!(
            trajectory.last().unwrap().positions,
            client.current_joint_positions().unwrap()
        );
    }

    #[test]
    fn joint_position_difference_limiter_send_joint_positions_no_interpolation() {
        let wrapped_client = two_joint_client();
        let client =
            JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, 1.0]).unwrap();
        assert_eq!(client.joint_names(), wrapped_client.joint_names());

        *wrapped_client.positions.lock().unwrap() = vec![0.0, 1.0];
        assert!(tokio_test::block_on(
            client
                .send_joint_positions(vec![-1.0, 2.0], Duration::from_secs(1))
                .unwrap()
        )
        .is_ok());
        // Within limits: forwarded as a plain position command, not a trajectory.
        assert!(wrapped_client.last_trajectory.lock().unwrap().is_empty());
        assert_eq!(
            wrapped_client.current_joint_positions().unwrap(),
            vec![-1.0, 2.0]
        );
    }

    #[test]
    fn joint_position_difference_limiter_send_joint_positions_interpolated() {
        let wrapped_client = two_joint_client();
        // The negative limit is stored as its absolute value.
        let client =
            JointPositionDifferenceLimiter::new(wrapped_client.clone(), vec![1.0, -0.5]).unwrap();
        assert_eq!(client.joint_names(), wrapped_client.joint_names());

        *wrapped_client.positions.lock().unwrap() = vec![0.0, 1.0];
        assert!(tokio_test::block_on(
            client
                .send_joint_positions(vec![-1.0, 2.0], Duration::from_secs(1))
                .unwrap()
        )
        .is_ok());
        let actual_trajectory = wrapped_client.last_trajectory.lock().unwrap().clone();
        assert_eq!(actual_trajectory.len(), 2);
        assert_eq!(actual_trajectory[0].positions, vec![-0.5, 1.5]);
        assert_eq!(actual_trajectory[1].positions, vec![-1.0, 2.0]);
        assert!(actual_trajectory[0].velocities.is_none());
        assert!(actual_trajectory[1].velocities.is_none());
        assert_approx_eq!(actual_trajectory[0].time_from_start.as_secs_f64(), 0.5);
        assert_approx_eq!(actual_trajectory[1].time_from_start.as_secs_f64(), 1.0);

        assert_eq!(
            wrapped_client.current_joint_positions().unwrap(),
            vec![-1.0, 2.0]
        );
    }

    #[test]
    fn test_should_interpolate_joint_trajectory() {
        assert!(!should_interpolate_joint_trajectory(&[]));
        // No velocities anywhere.
        assert!(should_interpolate_joint_trajectory(&[
            velocity_only_point(None),
            velocity_only_point(None),
        ]));
        // Velocities on a point other than the last.
        assert!(!should_interpolate_joint_trajectory(&[
            velocity_only_point(Some(vec![0.0, 0.0])),
            velocity_only_point(Some(vec![0.0, 0.0])),
        ]));
        // Non-zero velocity at the last point.
        assert!(!should_interpolate_joint_trajectory(&[
            velocity_only_point(None),
            velocity_only_point(Some(vec![0.0, 0.01])),
        ]));
        // Zero velocity only at the last point.
        assert!(should_interpolate_joint_trajectory(&[
            velocity_only_point(None),
            velocity_only_point(Some(vec![0.0, 0.0])),
        ]));
    }

    #[test]
    fn interpolate_joint_trajectory_no_interpolation() {
        let interpolated = interpolate_joint_trajectory(
            vec![0.0, 1.0],
            &[1.0, 1.0],
            vec![
                TrajectoryPoint {
                    positions: vec![-1.0, 2.0],
                    velocities: None,
                    time_from_start: Duration::from_secs(1),
                },
                TrajectoryPoint {
                    positions: vec![-2.0, 3.0],
                    velocities: Some(vec![0.0, 0.0]),
                    time_from_start: Duration::from_secs(2),
                },
            ],
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();

        assert_eq!(interpolated.len(), 2);

        assert_eq!(interpolated[0].positions, vec![-1.0, 2.0]);
        assert!(interpolated[0].velocities.is_none());
        assert_approx_eq!(interpolated[0].time_from_start.as_secs_f64(), 1.0);

        assert_eq!(interpolated[1].positions, vec![-2.0, 3.0]);
        assert!(interpolated[1].velocities.is_some());
        assert_approx_eq!(interpolated[1].time_from_start.as_secs_f64(), 2.0);
    }

    #[test]
    fn interpolate_joint_trajectory_interpolated() {
        let interpolated = interpolate_joint_trajectory(
            vec![0.0, 1.0],
            &[1.0, 0.5],
            vec![
                TrajectoryPoint {
                    positions: vec![-1.0, 2.0],
                    velocities: None,
                    time_from_start: Duration::from_secs(1),
                },
                TrajectoryPoint {
                    positions: vec![-2.0, 3.0],
                    velocities: Some(vec![0.0, 0.0]),
                    time_from_start: Duration::from_secs(2),
                },
            ],
        );
        assert!(interpolated.is_ok());
        let interpolated = interpolated.unwrap();

        assert_eq!(interpolated.len(), 4);

        assert_eq!(interpolated[0].positions, vec![-0.5, 1.5]);
        assert!(interpolated[0].velocities.is_none());
        assert_approx_eq!(interpolated[0].time_from_start.as_secs_f64(), 0.5);

        assert_eq!(interpolated[1].positions, vec![-1.0, 2.0]);
        assert!(interpolated[1].velocities.is_none());
        assert_approx_eq!(interpolated[1].time_from_start.as_secs_f64(), 1.0);

        assert_eq!(interpolated[2].positions, vec![-1.5, 2.5]);
        assert!(interpolated[2].velocities.is_none());
        assert_approx_eq!(interpolated[2].time_from_start.as_secs_f64(), 1.5);

        // The original point's velocities are carried over to the end of its segment.
        assert_eq!(interpolated[3].positions, vec![-2.0, 3.0]);
        assert!(interpolated[3].velocities.is_some());
        assert_approx_eq!(interpolated[3].time_from_start.as_secs_f64(), 2.0);
    }
}