arci/clients/joint_position_limiter.rs

1use std::{f64, ops::RangeInclusive};
2
3use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use tracing::{debug, warn};
6use urdf_rs::JointType;
7
8use crate::{
9    error::Error,
10    traits::{JointTrajectoryClient, TrajectoryPoint},
11    waits::WaitFuture,
12};
13
14#[derive(Debug, Clone)]
15pub struct JointPositionLimiter<C>
16where
17    C: JointTrajectoryClient,
18{
19    client: C,
20    limits: Vec<JointPositionLimit>,
21    strategy: JointPositionLimiterStrategy,
22}
23
24impl<C> JointPositionLimiter<C>
25where
26    C: JointTrajectoryClient,
27{
28    /// Creates a new `JointPositionLimiter` with the given position limits.
29    ///
30    /// # Panics
31    ///
32    /// Panics if the lengths of `limits` and joints that `client` handles are different.
33    #[track_caller]
34    pub fn new(client: C, limits: Vec<JointPositionLimit>) -> Self {
35        Self::new_with_strategy(client, limits, Default::default())
36    }
37
38    #[track_caller]
39    pub fn new_with_strategy(
40        client: C,
41        limits: Vec<JointPositionLimit>,
42        strategy: JointPositionLimiterStrategy,
43    ) -> Self {
44        assert!(client.joint_names().len() == limits.len());
45        Self {
46            client,
47            limits,
48            strategy,
49        }
50    }
51
52    /// Creates a new `JointPositionLimiter` with the position limits defined in URDF.
53    pub fn from_urdf(client: C, joints: &[urdf_rs::Joint]) -> Result<Self, Error> {
54        Self::from_urdf_with_strategy(client, joints, Default::default())
55    }
56
57    pub fn from_urdf_with_strategy(
58        client: C,
59        joints: &[urdf_rs::Joint],
60        strategy: JointPositionLimiterStrategy,
61    ) -> Result<Self, Error> {
62        let mut limits = Vec::new();
63        for joint_name in client.joint_names() {
64            if let Some(i) = joints.iter().position(|j| j.name == *joint_name) {
65                let joint = &joints[i];
66                let limit = if JointType::Continuous == joint.joint_type {
67                    // Continuous joint has no limit.
68                    JointPositionLimit::none()
69                } else {
70                    (joint.limit.lower..=joint.limit.upper).into()
71                };
72                limits.push(limit);
73            } else {
74                return Err(Error::NoJoint(joint_name));
75            }
76        }
77
78        Ok(Self {
79            client,
80            limits,
81            strategy,
82        })
83    }
84
85    pub fn set_strategy(&mut self, strategy: JointPositionLimiterStrategy) {
86        self.strategy = strategy;
87    }
88
89    fn check_joint_position(&self, positions: &mut Vec<f64>) -> Result<(), Error> {
90        assert_eq!(positions.len(), self.limits.len());
91        for (i, limit, position) in self
92            .limits
93            .iter()
94            .zip(positions)
95            .enumerate()
96            .filter_map(|(i, (l, p))| l.range().map(|l| (i, l, p)))
97        {
98            if limit.contains(position) {
99                continue;
100            }
101            match self.strategy {
102                JointPositionLimiterStrategy::Error => {
103                    return Err(Error::OutOfLimit {
104                        name: self.client.joint_names()[i].clone(),
105                        position: *position,
106                        limit,
107                    });
108                }
109                JointPositionLimiterStrategy::Clamp => {
110                    debug!(
111                        "Out of limit: joint={}, position={position}, limit={limit:?}",
112                        self.client.joint_names()[i],
113                    );
114                }
115                JointPositionLimiterStrategy::ClampWithWarn => {
116                    warn!(
117                        "Out of limit: joint={}, position={position}, limit={limit:?}",
118                        self.client.joint_names()[i],
119                    );
120                }
121            }
122            *position = position.clamp(*limit.start(), *limit.end());
123        }
124        Ok(())
125    }
126}
127
128impl<C> JointTrajectoryClient for JointPositionLimiter<C>
129where
130    C: JointTrajectoryClient,
131{
132    fn joint_names(&self) -> Vec<String> {
133        self.client.joint_names()
134    }
135
136    fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
137        self.client.current_joint_positions()
138    }
139
140    fn send_joint_positions(
141        &self,
142        mut positions: Vec<f64>,
143        duration: std::time::Duration,
144    ) -> Result<WaitFuture, Error> {
145        self.check_joint_position(&mut positions)?;
146        self.client.send_joint_positions(positions, duration)
147    }
148
149    fn send_joint_trajectory(
150        &self,
151        mut trajectory: Vec<TrajectoryPoint>,
152    ) -> Result<WaitFuture, Error> {
153        for tp in &mut trajectory {
154            self.check_joint_position(&mut tp.positions)?;
155        }
156        self.client.send_joint_trajectory(trajectory)
157    }
158}
159
160#[derive(Debug, Clone, Copy, Default)]
161pub struct JointPositionLimit(Option<JointPositionLimitInner>);
162
163#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
164#[serde(deny_unknown_fields)]
165struct JointPositionLimitInner {
166    lower: f64,
167    upper: f64,
168}
169
170impl JointPositionLimit {
171    pub fn new(lower: f64, upper: f64) -> Self {
172        Self(Some(JointPositionLimitInner { lower, upper }))
173    }
174
175    pub fn none() -> Self {
176        Self::default()
177    }
178
179    pub fn is_none(&self) -> bool {
180        self.0.is_none()
181    }
182
183    pub fn range(&self) -> Option<RangeInclusive<f64>> {
184        self.0.map(|l| l.lower..=l.upper)
185    }
186
187    pub fn lower(&self) -> Option<f64> {
188        self.0.map(|l| l.lower)
189    }
190
191    pub fn upper(&self) -> Option<f64> {
192        self.0.map(|l| l.upper)
193    }
194}
195
196impl From<RangeInclusive<f64>> for JointPositionLimit {
197    fn from(r: RangeInclusive<f64>) -> Self {
198        Self::new(*r.start(), *r.end())
199    }
200}
201
202impl From<urdf_rs::JointLimit> for JointPositionLimit {
203    fn from(l: urdf_rs::JointLimit) -> Self {
204        Self::new(l.lower, l.upper)
205    }
206}
207
208impl<'de> Deserialize<'de> for JointPositionLimit {
209    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
210    where
211        D: Deserializer<'de>,
212    {
213        #[derive(Deserialize)]
214        #[serde(untagged, deny_unknown_fields)]
215        enum De {
216            Limit(JointPositionLimitInner),
217            Empty {},
218        }
219        Ok(match De::deserialize(deserializer)? {
220            De::Empty {} => Self::none(),
221            De::Limit(l) => Self(Some(l)),
222        })
223    }
224}
225
226impl Serialize for JointPositionLimit {
227    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
228    where
229        S: Serializer,
230    {
231        #[derive(Serialize)]
232        #[serde(untagged)]
233        enum Ser {
234            Limit(JointPositionLimitInner),
235            Empty {},
236        }
237
238        match self.0 {
239            None => Ser::Empty {}.serialize(serializer),
240            Some(l) => Ser::Limit(l).serialize(serializer),
241        }
242    }
243}
244
impl JsonSchema for JointPositionLimit {
    /// Name under which this type's schema is registered.
    fn schema_name() -> String {
        "JointPositionLimit".into()
    }

    /// Describes `JointPositionLimit` as an object with optional numeric `lower` and
    /// `upper` fields.
    ///
    /// This covers both serialized forms, `{ lower, upper }` and the empty `{}` used
    /// for "no limit". It is looser than the `Deserialize` impl, which expects either
    /// both fields or neither.
    fn json_schema(gen: &mut SchemaGenerator) -> Schema {
        // As workaround for https://github.com/tamasfe/taplo/issues/57,
        // use struct with option value instead of enum.
        // NOTE(review): the struct below uses `#[derive(JsonSchema)]`, so `///` doc
        // comments on it or its fields would likely end up in the generated schema.
        // Keep comments here as `//` unless that is intended.
        #[allow(dead_code)]
        #[derive(JsonSchema)]
        struct JointPositionLimitRepr {
            lower: Option<f64>,
            upper: Option<f64>,
        }

        JointPositionLimitRepr::json_schema(gen)
    }
}
// Strategy used by `JointPositionLimiter` when a commanded position is outside its
// joint's limit. `Clamp` is the default.
// NOTE(review): this type derives `JsonSchema`, so `///` doc comments here likely
// feed the generated schema. This note is a plain `//` comment and the variant docs
// are left unchanged so the schema output stays the same.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, JsonSchema, Serialize, Deserialize)]
// Downstream crates must include a wildcard arm when matching on this enum.
#[non_exhaustive]
pub enum JointPositionLimiterStrategy {
    /// If the position is out of the limit, handle it as the same value as the limit.
    #[default]
    Clamp,
    /// If the position is out of the limit, handle it as the same value as the limit with warning.
    ClampWithWarn,
    /// If the position is out of the limit, return an error.
    Error,
}
274
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use assert_approx_eq::assert_approx_eq;

    use super::*;
    use crate::DummyJointTrajectoryClient;

    // One-second duration used for the commands sent in these tests.
    const SECOND: Duration = Duration::from_secs(1);

    // A limit count that differs from the client's joint count must panic at construction.
    #[test]
    #[should_panic]
    fn mismatch_size() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        JointPositionLimiter::new(client, vec![(1.0..=2.0).into(), (2.0..=3.0).into()]);
    }

    // The limiter reports the wrapped client's joint names unchanged and in order.
    #[test]
    fn joint_names() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned(), "b".to_owned()]);
        let limiter =
            JointPositionLimiter::new(client, vec![(1.0..=2.0).into(), (2.0..=3.0).into()]);
        let joint_names = limiter.joint_names();
        assert_eq!(joint_names.len(), 2);
        assert_eq!(joint_names[0], "a");
        assert_eq!(joint_names[1], "b");
    }

    // Positions exactly on the limit bounds (1.0 and 2.0) are within the inclusive
    // range, so they pass through unchanged under every strategy, including `Error`.
    #[tokio::test]
    async fn send_joint_positions_none_limited() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let mut client = JointPositionLimiter::new(client, vec![(1.0..=2.0).into()]);

        for strategy in [
            JointPositionLimiterStrategy::Clamp,
            JointPositionLimiterStrategy::ClampWithWarn,
            JointPositionLimiterStrategy::Error,
        ] {
            client.set_strategy(strategy);

            client
                .send_joint_positions(vec![1.0], SECOND)
                .unwrap()
                .await
                .unwrap();
            assert_approx_eq!(client.current_joint_positions().unwrap()[0], 1.0);

            client
                .send_joint_positions(vec![2.0], SECOND)
                .unwrap()
                .await
                .unwrap();
            assert_approx_eq!(client.current_joint_positions().unwrap()[0], 2.0);
        }
    }

    // With the default strategy, out-of-range positions are clamped to the nearest
    // bound: 0.0 -> 1.0 and 3.0 -> 2.0.
    #[tokio::test]
    async fn send_joint_positions_limited_rounded() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let client = JointPositionLimiter::new(client, vec![(1.0..=2.0).into()]);

        client
            .send_joint_positions(vec![0.0], SECOND)
            .unwrap()
            .await
            .unwrap();
        assert_approx_eq!(client.current_joint_positions().unwrap()[0], 1.0);

        client
            .send_joint_positions(vec![3.0], SECOND)
            .unwrap()
            .await
            .unwrap();
        assert_approx_eq!(client.current_joint_positions().unwrap()[0], 2.0);
    }

    // With the `Error` strategy, out-of-range positions are rejected, and the error
    // reports the offending position.
    #[tokio::test]
    async fn send_joint_positions_limited_error() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let client = JointPositionLimiter::new_with_strategy(
            client,
            vec![(1.0..=2.0).into()],
            JointPositionLimiterStrategy::Error,
        );

        let e = client
            .send_joint_positions(vec![0.0], SECOND)
            .err()
            .unwrap();
        assert_error(e, 0.0);

        let e = client
            .send_joint_positions(vec![3.0], SECOND)
            .err()
            .unwrap();
        assert_error(e, 3.0);
    }

    // Trajectories whose points are all within limits are forwarded unchanged under
    // every strategy. The final position equals the last point.
    #[tokio::test]
    async fn send_joint_trajectory_none_limited() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let mut client = JointPositionLimiter::new(client, vec![(1.0..=2.0).into()]);

        for strategy in [
            JointPositionLimiterStrategy::Clamp,
            JointPositionLimiterStrategy::ClampWithWarn,
            JointPositionLimiterStrategy::Error,
        ] {
            client.set_strategy(strategy);

            client
                .send_joint_trajectory(vec![
                    TrajectoryPoint::new(vec![1.0], SECOND * 2),
                    TrajectoryPoint::new(vec![1.5], SECOND * 3),
                ])
                .unwrap()
                .await
                .unwrap();
            assert_approx_eq!(client.current_joint_positions().unwrap()[0], 1.5);

            client
                .send_joint_trajectory(vec![
                    TrajectoryPoint::new(vec![1.7], SECOND * 2),
                    TrajectoryPoint::new(vec![2.0], SECOND * 3),
                ])
                .unwrap()
                .await
                .unwrap();
            assert_approx_eq!(client.current_joint_positions().unwrap()[0], 2.0);
        }
    }

    // With the default strategy, every trajectory point is clamped, so the final
    // position is the clamped last point: 0.5 -> 1.0 and 3.0 -> 2.0.
    #[tokio::test]
    async fn send_joint_trajectory_limited_rounded() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let client = JointPositionLimiter::new(client, vec![(1.0..=2.0).into()]);

        client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![0.0], SECOND * 2),
                TrajectoryPoint::new(vec![0.5], SECOND * 3),
            ])
            .unwrap()
            .await
            .unwrap();
        assert_approx_eq!(client.current_joint_positions().unwrap()[0], 1.0);

        client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![2.5], SECOND * 2),
                TrajectoryPoint::new(vec![3.0], SECOND * 3),
            ])
            .unwrap()
            .await
            .unwrap();
        assert_approx_eq!(client.current_joint_positions().unwrap()[0], 2.0);
    }

    // With the `Error` strategy, the first out-of-range point in trajectory order is
    // reported, even when an earlier point is in range.
    #[tokio::test]
    async fn send_joint_trajectory_limited_error() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let client = JointPositionLimiter::new_with_strategy(
            client,
            vec![(1.0..=2.0).into()],
            JointPositionLimiterStrategy::Error,
        );

        let e = client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![0.0], SECOND * 2),
                TrajectoryPoint::new(vec![0.5], SECOND * 3),
            ])
            .err()
            .unwrap();
        assert_error(e, 0.0);

        let e = client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![1.0], SECOND * 2),
                TrajectoryPoint::new(vec![0.5], SECOND * 3),
            ])
            .err()
            .unwrap();
        assert_error(e, 0.5);

        let e = client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![2.5], SECOND * 2),
                TrajectoryPoint::new(vec![3.0], SECOND * 3),
            ])
            .err()
            .unwrap();
        assert_error(e, 2.5);

        let e = client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![2.0], SECOND * 2),
                TrajectoryPoint::new(vec![3.0], SECOND * 3),
            ])
            .err()
            .unwrap();
        assert_error(e, 3.0);
    }

    // Asserts that `e` is `Error::OutOfLimit` reporting `position`. Panics with the
    // error's debug output for any other error.
    fn assert_error(e: Error, position: f64) {
        match e {
            Error::OutOfLimit { position: p, .. } => assert_approx_eq!(p, position),
            _ => panic!("{e:?}"),
        }
    }

    // Limits are read from the `<limit>` element of the URDF joint with the same name.
    // A client joint that is missing from the URDF yields `Error::NoJoint`.
    #[test]
    fn from_urdf() {
        let s = r#"
            <robot name="robot">
                <joint name="a" type="revolute">
                    <origin xyz="0.0 0.0 0.0" />
                    <parent link="b" />
                    <child link="c" />
                    <axis xyz="0 1 0" />
                    <limit lower="-2" upper="1.0" effort="0" velocity="1.0"/>
                </joint>
            </robot>
        "#;
        let urdf_robot = urdf_rs::read_from_string(s).unwrap();
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let limiter = JointPositionLimiter::from_urdf(client, &urdf_robot.joints).unwrap();
        assert_approx_eq!(limiter.limits[0].lower().unwrap(), -2.0);
        assert_approx_eq!(limiter.limits[0].upper().unwrap(), 1.0);

        // joint name mismatch
        let urdf_robot = urdf_rs::read_from_string(s).unwrap();
        let client = DummyJointTrajectoryClient::new(vec!["unknown".to_owned()]);
        let e = JointPositionLimiter::from_urdf(client, &urdf_robot.joints)
            .err()
            .unwrap();
        assert!(matches!(e, Error::NoJoint(..)));
    }

    // TOML round trip of `JointPositionLimit`. A `{ lower, upper }` table is a limit,
    // `{}` is no limit, and both forms can be mixed in one array.
    #[test]
    fn serde_joint_position_limit() {
        #[derive(Serialize, Deserialize)]
        struct T {
            limits: Vec<JointPositionLimit>,
        }

        let l: T = toml::from_str("limits = [{ lower = 0.0, upper = 1.0 }]").unwrap();
        assert_approx_eq!(l.limits[0].lower().unwrap(), 0.0);
        assert_approx_eq!(l.limits[0].upper().unwrap(), 1.0);

        let l: T = toml::from_str("limits = [{}]").unwrap();
        assert!(l.limits[0].is_none());

        let l: T = toml::from_str(
            "limits = [\
                { lower = 0.0, upper = 1.0 },\
                {},\
                { lower = 1.0, upper = 2.0 }\
            ]",
        )
        .unwrap();
        assert_approx_eq!(l.limits[0].lower().unwrap(), 0.0);
        assert_approx_eq!(l.limits[0].upper().unwrap(), 1.0);
        assert!(l.limits[1].is_none());
        assert_approx_eq!(l.limits[2].lower().unwrap(), 1.0);
        assert_approx_eq!(l.limits[2].upper().unwrap(), 2.0);

        // TODO: We want to serialize to inline table: https://github.com/alexcrichton/toml-rs/issues/265
        assert_eq!(
            toml::to_string(&l).unwrap(),
            "[[limits]]\n\
             lower = 0.0\n\
             upper = 1.0\n\
             \n\
             [[limits]]\n\
             \n\
             [[limits]]\n\
             lower = 1.0\n\
             upper = 2.0\n\
            "
        );
    }
}