arci/clients/
joint_velocity_limiter.rs

1use tracing::debug;
2
3use crate::{
4    error::Error,
5    traits::{JointTrajectoryClient, TrajectoryPoint},
6    waits::WaitFuture,
7};
8
9/// JointVelocityLimiter limits the duration to make all joints velocities lower than the given
10/// velocities limits at each TrajectoryPoint.
11///
12/// It does not change TrajectoryPoint velocities.
13/// The duration for a TrajectoryPoint\[i\] is set to
14/// ```Text
15/// duration[i] = max(limited_duration_i[j=0], ...,  limited_duration_i[j=J-1], input_duration[i])
16/// where
17///  j : joint_index (0 <= j < J),
18///  limited_duration_i[j] =
19///   abs(TrajectoryPoint[i].positions[j]  - TrajectoryPoint[i-1].positions[j]) / velocity_limits[j]
20/// ```
21#[derive(Debug)]
22pub struct JointVelocityLimiter<C>
23where
24    C: JointTrajectoryClient,
25{
26    client: C,
27    velocity_limits: Vec<f64>,
28}
29
30impl<C> JointVelocityLimiter<C>
31where
32    C: JointTrajectoryClient,
33{
34    /// Creates a new `JointVelocityLimiter` with the given velocity limits.
35    ///
36    /// # Panics
37    ///
38    /// Panics if the lengths of `velocity_limits` and joints that `client` handles are different.
39    #[track_caller]
40    pub fn new(client: C, velocity_limits: Vec<f64>) -> Self {
41        assert!(client.joint_names().len() == velocity_limits.len());
42        Self {
43            client,
44            velocity_limits,
45        }
46    }
47
48    /// Creates a new `JointVelocityLimiter` with the velocity limits defined in URDF.
49    pub fn from_urdf(client: C, joints: &[urdf_rs::Joint]) -> Result<Self, Error> {
50        let mut velocity_limits = Vec::new();
51        for joint_name in client.joint_names() {
52            if let Some(i) = joints.iter().position(|j| j.name == *joint_name) {
53                let limit = joints[i].limit.velocity;
54                velocity_limits.push(limit);
55            } else {
56                return Err(Error::NoJoint(joint_name));
57            }
58        }
59
60        Ok(Self {
61            client,
62            velocity_limits,
63        })
64    }
65}
66
67impl<C> JointTrajectoryClient for JointVelocityLimiter<C>
68where
69    C: JointTrajectoryClient,
70{
71    fn joint_names(&self) -> Vec<String> {
72        self.client.joint_names()
73    }
74
75    fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
76        self.client.current_joint_positions()
77    }
78
79    fn send_joint_positions(
80        &self,
81        positions: Vec<f64>,
82        duration: std::time::Duration,
83    ) -> Result<WaitFuture, Error> {
84        self.send_joint_trajectory(vec![TrajectoryPoint {
85            positions,
86            velocities: None,
87            time_from_start: duration,
88        }])
89    }
90
91    fn send_joint_trajectory(&self, trajectory: Vec<TrajectoryPoint>) -> Result<WaitFuture, Error> {
92        let mut prev_positions = self.current_joint_positions()?;
93
94        let mut limited_trajectory = vec![];
95        let mut limited_duration_from_start = std::time::Duration::from_secs(0);
96        let mut original_duration_from_start = std::time::Duration::from_secs(0);
97        for (sequence_index, original_trajectory_point) in trajectory.iter().enumerate() {
98            let mut limited_duration_from_prev = std::time::Duration::from_secs(0);
99            let mut dominant_joint_index = 0;
100            for (joint_index, prev_position) in prev_positions.iter().enumerate() {
101                let single_duration = std::time::Duration::from_secs_f64(
102                    (prev_position - original_trajectory_point.positions[joint_index]).abs()
103                        / self.velocity_limits[joint_index],
104                );
105                limited_duration_from_prev = if single_duration > limited_duration_from_prev {
106                    dominant_joint_index = joint_index;
107                    single_duration
108                } else {
109                    limited_duration_from_prev
110                }
111            }
112            let original_duration_from_prev =
113                original_trajectory_point.time_from_start - original_duration_from_start;
114            original_duration_from_start = original_trajectory_point.time_from_start;
115
116            let use_limited = limited_duration_from_prev > original_duration_from_prev;
117            let selected_duration = if use_limited {
118                limited_duration_from_prev
119            } else {
120                original_duration_from_prev
121            };
122            limited_duration_from_start += selected_duration;
123            limited_trajectory.push(TrajectoryPoint {
124                positions: original_trajectory_point.positions.clone(),
125                velocities: original_trajectory_point.velocities.clone(),
126                time_from_start: limited_duration_from_start,
127            });
128            prev_positions.clone_from(&original_trajectory_point.positions);
129            debug!(
130                "Sequence{sequence_index} dominant joint_index {dominant_joint_index} duration limited : {limited_duration_from_prev:?}{} original : {original_duration_from_prev:?}{}",
131                if use_limited { "(O)" } else { "" },
132                if use_limited { "" } else { "(O)" }
133            );
134        }
135
136        debug!("OriginalTrajectory {trajectory:?}");
137        debug!("LimitedTrajectory {limited_trajectory:?}");
138
139        self.client.send_joint_trajectory(limited_trajectory)
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use assert_approx_eq::assert_approx_eq;

    use super::*;
    use crate::DummyJointTrajectoryClient;

    /// Builds a shared dummy client that handles the two joints `a` and `b`.
    fn two_joint_client() -> Arc<DummyJointTrajectoryClient> {
        Arc::new(DummyJointTrajectoryClient::new(vec![
            "a".to_owned(),
            "b".to_owned(),
        ]))
    }

    /// Asserts that the first two entries of `actual` are approximately equal to `expected`.
    fn assert_first_two_approx_eq(actual: &[f64], expected: [f64; 2]) {
        assert_approx_eq!(actual[0], expected[0]);
        assert_approx_eq!(actual[1], expected[1]);
    }

    /// Constructing a limiter with more limits than joints must panic.
    #[test]
    #[should_panic]
    fn mismatch_size() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let _ = JointVelocityLimiter::new(client, vec![1.0, 2.0]);
    }

    /// The limiter reports the joint names of the wrapped client.
    #[test]
    fn joint_names() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned(), "b".to_owned()]);
        let limiter = JointVelocityLimiter::new(client, vec![1.0, 2.0]);
        assert_eq!(limiter.joint_names(), ["a", "b"]);
    }

    /// Sends `[1.0, 2.0]` with a requested duration of 4 s through a limiter configured with
    /// `limits`, then checks the point that reached the wrapped client.
    fn test_send_joint_positions(limits: Vec<f64>, expected_duration_secs: f64) {
        let client = two_joint_client();
        let limiter = JointVelocityLimiter::new(client.clone(), limits);

        let wait = limiter
            .send_joint_positions(vec![1.0, 2.0], Duration::from_secs_f64(4.0))
            .unwrap();
        assert!(tokio_test::block_on(wait).is_ok());

        let joint_positions = limiter.current_joint_positions().unwrap();
        assert_eq!(joint_positions.len(), 2);
        assert_first_two_approx_eq(&joint_positions, [1.0, 2.0]);

        let trajectory = client.last_trajectory.lock().unwrap();
        assert_eq!(trajectory.len(), 1);
        let point = &trajectory[0];
        assert_eq!(point.positions.len(), 2);
        assert_first_two_approx_eq(&point.positions, [1.0, 2.0]);
        assert!(point.velocities.is_none());
        assert_approx_eq!(point.time_from_start.as_secs_f64(), expected_duration_secs);
    }

    #[test]
    fn send_joint_positions_none_limited() {
        test_send_joint_positions(vec![1.0, 2.0], 4.0);
    }

    #[test]
    fn send_joint_positions_limited() {
        let limit_cases = [
            // joint0 is over limit
            [0.1, 2.0],
            // joint1 is over limit
            [1.0, 0.2],
            // joint0/1 are over limit, joint0 is dominant
            [0.1, 0.6],
            // joint0/1 are over limit, joint1 is dominant
            [0.3, 0.2],
        ];
        for limits in limit_cases {
            test_send_joint_positions(limits.to_vec(), 10.0);
        }
    }

    /// Sends a fixed two-point trajectory (requested at 4 s and 8 s) through a limiter configured
    /// with `limits`, then checks the points that reached the wrapped client.
    fn test_send_joint_trajectory(limits: Vec<f64>, expected_durations_secs: [f64; 2]) {
        let client = two_joint_client();
        let limiter = JointVelocityLimiter::new(client.clone(), limits);

        let input = vec![
            TrajectoryPoint {
                positions: vec![1.0, 2.0],
                velocities: Some(vec![3.0, 4.0]),
                time_from_start: Duration::from_secs_f64(4.0),
            },
            TrajectoryPoint {
                positions: vec![3.0, 6.0],
                velocities: Some(vec![3.0, 4.0]),
                time_from_start: Duration::from_secs_f64(8.0),
            },
        ];
        let wait = limiter.send_joint_trajectory(input).unwrap();
        assert!(tokio_test::block_on(wait).is_ok());

        let joint_positions = limiter.current_joint_positions().unwrap();
        assert_eq!(joint_positions.len(), 2);
        assert_first_two_approx_eq(&joint_positions, [3.0, 6.0]);

        let trajectory = client.last_trajectory.lock().unwrap();
        assert_eq!(trajectory.len(), 2);

        // Positions and velocities must pass through unchanged.
        let expected_positions = [[1.0, 2.0], [3.0, 6.0]];
        for (point, positions) in trajectory.iter().zip(expected_positions) {
            assert_eq!(point.positions.len(), 2);
            assert_first_two_approx_eq(&point.positions, positions);
            assert!(point.velocities.is_some());
            assert_first_two_approx_eq(point.velocities.as_ref().unwrap(), [3.0, 4.0]);
        }

        // Only the timing is allowed to change.
        for (point, expected_secs) in trajectory.iter().zip(expected_durations_secs) {
            assert_approx_eq!(point.time_from_start.as_secs_f64(), expected_secs);
        }
    }

    #[test]
    fn send_joint_trajectory_none_limited() {
        test_send_joint_trajectory(vec![1.0, 2.0], [4.0, 8.0]);
    }

    #[test]
    fn send_joint_trajectory_limited() {
        let cases = [
            // joint0 is over limit
            ([0.1, 2.0], [10.0, 30.0]),
            // joint1 is over limit
            ([1.0, 0.2], [10.0, 30.0]),
            // joint0/1 are over limit, joint0 is dominant
            ([0.1, 0.6], [10.0, 30.0]),
            // joint0/1 are over limit, joint1 is dominant
            ([0.3, 0.2], [10.0, 30.0]),
            // joint0 / point1 is over limit
            ([0.3, 2.0], [4.0, 4.0 + 2.0 / 0.3]),
            // joint1 / point1 is over limit
            ([1.0, 0.8], [4.0, 4.0 + 4.0 / 0.8]),
        ];
        for (limits, expected_durations_secs) in cases {
            test_send_joint_trajectory(limits.to_vec(), expected_durations_secs);
        }
    }

    #[test]
    fn from_urdf() {
        let s = r#"
            <robot name="robot">
                <joint name="a" type="revolute">
                    <origin xyz="0.0 0.0 0.0" />
                    <parent link="b" />
                    <child link="c" />
                    <axis xyz="0 1 0" />
                    <limit lower="-2" upper="1.0" effort="0" velocity="1.0"/>
                </joint>
            </robot>
        "#;

        // The limit of the joint named like the client's joint is picked up.
        let urdf_robot = urdf_rs::read_from_string(s).unwrap();
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let limiter = JointVelocityLimiter::from_urdf(client, &urdf_robot.joints).unwrap();
        assert_approx_eq!(limiter.velocity_limits[0], 1.0);

        // joint name mismatch
        let urdf_robot = urdf_rs::read_from_string(s).unwrap();
        let client = DummyJointTrajectoryClient::new(vec!["unknown".to_owned()]);
        let result = JointVelocityLimiter::from_urdf(client, &urdf_robot.joints);
        assert!(matches!(result, Err(Error::NoJoint(..))));
    }
}