1use std::{f64, ops::RangeInclusive};
2
3use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use tracing::{debug, warn};
6use urdf_rs::JointType;
7
8use crate::{
9 error::Error,
10 traits::{JointTrajectoryClient, TrajectoryPoint},
11 waits::WaitFuture,
12};
13
14#[derive(Debug, Clone)]
15pub struct JointPositionLimiter<C>
16where
17 C: JointTrajectoryClient,
18{
19 client: C,
20 limits: Vec<JointPositionLimit>,
21 strategy: JointPositionLimiterStrategy,
22}
23
24impl<C> JointPositionLimiter<C>
25where
26 C: JointTrajectoryClient,
27{
28 #[track_caller]
34 pub fn new(client: C, limits: Vec<JointPositionLimit>) -> Self {
35 Self::new_with_strategy(client, limits, Default::default())
36 }
37
38 #[track_caller]
39 pub fn new_with_strategy(
40 client: C,
41 limits: Vec<JointPositionLimit>,
42 strategy: JointPositionLimiterStrategy,
43 ) -> Self {
44 assert!(client.joint_names().len() == limits.len());
45 Self {
46 client,
47 limits,
48 strategy,
49 }
50 }
51
52 pub fn from_urdf(client: C, joints: &[urdf_rs::Joint]) -> Result<Self, Error> {
54 Self::from_urdf_with_strategy(client, joints, Default::default())
55 }
56
57 pub fn from_urdf_with_strategy(
58 client: C,
59 joints: &[urdf_rs::Joint],
60 strategy: JointPositionLimiterStrategy,
61 ) -> Result<Self, Error> {
62 let mut limits = Vec::new();
63 for joint_name in client.joint_names() {
64 if let Some(i) = joints.iter().position(|j| j.name == *joint_name) {
65 let joint = &joints[i];
66 let limit = if JointType::Continuous == joint.joint_type {
67 JointPositionLimit::none()
69 } else {
70 (joint.limit.lower..=joint.limit.upper).into()
71 };
72 limits.push(limit);
73 } else {
74 return Err(Error::NoJoint(joint_name));
75 }
76 }
77
78 Ok(Self {
79 client,
80 limits,
81 strategy,
82 })
83 }
84
85 pub fn set_strategy(&mut self, strategy: JointPositionLimiterStrategy) {
86 self.strategy = strategy;
87 }
88
89 fn check_joint_position(&self, positions: &mut Vec<f64>) -> Result<(), Error> {
90 assert_eq!(positions.len(), self.limits.len());
91 for (i, limit, position) in self
92 .limits
93 .iter()
94 .zip(positions)
95 .enumerate()
96 .filter_map(|(i, (l, p))| l.range().map(|l| (i, l, p)))
97 {
98 if limit.contains(position) {
99 continue;
100 }
101 match self.strategy {
102 JointPositionLimiterStrategy::Error => {
103 return Err(Error::OutOfLimit {
104 name: self.client.joint_names()[i].clone(),
105 position: *position,
106 limit,
107 });
108 }
109 JointPositionLimiterStrategy::Clamp => {
110 debug!(
111 "Out of limit: joint={}, position={position}, limit={limit:?}",
112 self.client.joint_names()[i],
113 );
114 }
115 JointPositionLimiterStrategy::ClampWithWarn => {
116 warn!(
117 "Out of limit: joint={}, position={position}, limit={limit:?}",
118 self.client.joint_names()[i],
119 );
120 }
121 }
122 *position = position.clamp(*limit.start(), *limit.end());
123 }
124 Ok(())
125 }
126}
127
128impl<C> JointTrajectoryClient for JointPositionLimiter<C>
129where
130 C: JointTrajectoryClient,
131{
132 fn joint_names(&self) -> Vec<String> {
133 self.client.joint_names()
134 }
135
136 fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
137 self.client.current_joint_positions()
138 }
139
140 fn send_joint_positions(
141 &self,
142 mut positions: Vec<f64>,
143 duration: std::time::Duration,
144 ) -> Result<WaitFuture, Error> {
145 self.check_joint_position(&mut positions)?;
146 self.client.send_joint_positions(positions, duration)
147 }
148
149 fn send_joint_trajectory(
150 &self,
151 mut trajectory: Vec<TrajectoryPoint>,
152 ) -> Result<WaitFuture, Error> {
153 for tp in &mut trajectory {
154 self.check_joint_position(&mut tp.positions)?;
155 }
156 self.client.send_joint_trajectory(trajectory)
157 }
158}
159
/// An optional inclusive position range `lower..=upper` for a single joint.
///
/// The inner `None` (the `Default`, also returned by
/// [`JointPositionLimit::none`]) means the joint is not limited.
#[derive(Debug, Clone, Copy, Default)]
pub struct JointPositionLimit(Option<JointPositionLimitInner>);

/// The bounds of a limited [`JointPositionLimit`].
///
/// Deserialization rejects any field other than `lower` and `upper`
/// (`deny_unknown_fields`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
struct JointPositionLimitInner {
    // Inclusive lower bound.
    lower: f64,
    // Inclusive upper bound.
    upper: f64,
}
169
170impl JointPositionLimit {
171 pub fn new(lower: f64, upper: f64) -> Self {
172 Self(Some(JointPositionLimitInner { lower, upper }))
173 }
174
175 pub fn none() -> Self {
176 Self::default()
177 }
178
179 pub fn is_none(&self) -> bool {
180 self.0.is_none()
181 }
182
183 pub fn range(&self) -> Option<RangeInclusive<f64>> {
184 self.0.map(|l| l.lower..=l.upper)
185 }
186
187 pub fn lower(&self) -> Option<f64> {
188 self.0.map(|l| l.lower)
189 }
190
191 pub fn upper(&self) -> Option<f64> {
192 self.0.map(|l| l.upper)
193 }
194}
195
196impl From<RangeInclusive<f64>> for JointPositionLimit {
197 fn from(r: RangeInclusive<f64>) -> Self {
198 Self::new(*r.start(), *r.end())
199 }
200}
201
202impl From<urdf_rs::JointLimit> for JointPositionLimit {
203 fn from(l: urdf_rs::JointLimit) -> Self {
204 Self::new(l.lower, l.upper)
205 }
206}
207
208impl<'de> Deserialize<'de> for JointPositionLimit {
209 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
210 where
211 D: Deserializer<'de>,
212 {
213 #[derive(Deserialize)]
214 #[serde(untagged, deny_unknown_fields)]
215 enum De {
216 Limit(JointPositionLimitInner),
217 Empty {},
218 }
219 Ok(match De::deserialize(deserializer)? {
220 De::Empty {} => Self::none(),
221 De::Limit(l) => Self(Some(l)),
222 })
223 }
224}
225
226impl Serialize for JointPositionLimit {
227 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
228 where
229 S: Serializer,
230 {
231 #[derive(Serialize)]
232 #[serde(untagged)]
233 enum Ser {
234 Limit(JointPositionLimitInner),
235 Empty {},
236 }
237
238 match self.0 {
239 None => Ser::Empty {}.serialize(serializer),
240 Some(l) => Ser::Limit(l).serialize(serializer),
241 }
242 }
243}
244
impl JsonSchema for JointPositionLimit {
    /// Name reported to schemars for this type.
    fn schema_name() -> String {
        "JointPositionLimit".into()
    }

    /// Describes the value as an object with optional numeric `lower` and
    /// `upper` fields.
    ///
    /// NOTE(review): this schema lets a document give only one of
    /// `lower`/`upper`. The `Deserialize` impl for this type accepts only both
    /// or neither. Confirm this looser schema is intentional.
    fn json_schema(gen: &mut SchemaGenerator) -> Schema {
        // The fields are never read; this type exists only for its derived schema.
        #[allow(dead_code)]
        #[derive(JsonSchema)]
        struct JointPositionLimitRepr {
            lower: Option<f64>,
            upper: Option<f64>,
        }

        JointPositionLimitRepr::json_schema(gen)
    }
}
/// What [`JointPositionLimiter`] does with a commanded position that lies
/// outside its joint's limit.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, JsonSchema, Serialize, Deserialize)]
#[non_exhaustive]
pub enum JointPositionLimiterStrategy {
    /// Clamp the position into the limit and log it at debug level (default).
    #[default]
    Clamp,
    /// Clamp the position into the limit and log it at warn level.
    ClampWithWarn,
    /// Reject the command with an out-of-limit error instead of sending it.
    Error,
}
274
// Tests for `JointPositionLimiter`, run against `DummyJointTrajectoryClient`,
// plus serde round-trips for `JointPositionLimit`.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use assert_approx_eq::assert_approx_eq;

    use super::*;
    use crate::DummyJointTrajectoryClient;

    const SECOND: Duration = Duration::from_secs(1);

    // Two limits for a one-joint client must trip the size assertion.
    #[test]
    #[should_panic]
    fn mismatch_size() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        JointPositionLimiter::new(client, vec![(1.0..=2.0).into(), (2.0..=3.0).into()]);
    }

    // Joint names are passed through from the wrapped client.
    #[test]
    fn joint_names() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned(), "b".to_owned()]);
        let limiter =
            JointPositionLimiter::new(client, vec![(1.0..=2.0).into(), (2.0..=3.0).into()]);
        let joint_names = limiter.joint_names();
        assert_eq!(joint_names.len(), 2);
        assert_eq!(joint_names[0], "a");
        assert_eq!(joint_names[1], "b");
    }

    // In-range positions, including both bounds, are expected to be
    // forwarded unchanged under every strategy.
    #[tokio::test]
    async fn send_joint_positions_none_limited() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let mut client = JointPositionLimiter::new(client, vec![(1.0..=2.0).into()]);

        for strategy in [
            JointPositionLimiterStrategy::Clamp,
            JointPositionLimiterStrategy::ClampWithWarn,
            JointPositionLimiterStrategy::Error,
        ] {
            client.set_strategy(strategy);

            client
                .send_joint_positions(vec![1.0], SECOND)
                .unwrap()
                .await
                .unwrap();
            assert_approx_eq!(client.current_joint_positions().unwrap()[0], 1.0);

            client
                .send_joint_positions(vec![2.0], SECOND)
                .unwrap()
                .await
                .unwrap();
            assert_approx_eq!(client.current_joint_positions().unwrap()[0], 2.0);
        }
    }

    // With the default (clamping) strategy, out-of-range positions are
    // clamped to the nearest bound.
    #[tokio::test]
    async fn send_joint_positions_limited_rounded() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let client = JointPositionLimiter::new(client, vec![(1.0..=2.0).into()]);

        client
            .send_joint_positions(vec![0.0], SECOND)
            .unwrap()
            .await
            .unwrap();
        assert_approx_eq!(client.current_joint_positions().unwrap()[0], 1.0);

        client
            .send_joint_positions(vec![3.0], SECOND)
            .unwrap()
            .await
            .unwrap();
        assert_approx_eq!(client.current_joint_positions().unwrap()[0], 2.0);
    }

    // With the `Error` strategy, out-of-range positions are rejected, and the
    // error carries the offending (unclamped) position.
    #[tokio::test]
    async fn send_joint_positions_limited_error() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let client = JointPositionLimiter::new_with_strategy(
            client,
            vec![(1.0..=2.0).into()],
            JointPositionLimiterStrategy::Error,
        );

        let e = client
            .send_joint_positions(vec![0.0], SECOND)
            .err()
            .unwrap();
        assert_error(e, 0.0);

        let e = client
            .send_joint_positions(vec![3.0], SECOND)
            .err()
            .unwrap();
        assert_error(e, 3.0);
    }

    // Trajectories whose points are all in range are accepted under every
    // strategy; the final point is checked.
    #[tokio::test]
    async fn send_joint_trajectory_none_limited() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let mut client = JointPositionLimiter::new(client, vec![(1.0..=2.0).into()]);

        for strategy in [
            JointPositionLimiterStrategy::Clamp,
            JointPositionLimiterStrategy::ClampWithWarn,
            JointPositionLimiterStrategy::Error,
        ] {
            client.set_strategy(strategy);

            client
                .send_joint_trajectory(vec![
                    TrajectoryPoint::new(vec![1.0], SECOND * 2),
                    TrajectoryPoint::new(vec![1.5], SECOND * 3),
                ])
                .unwrap()
                .await
                .unwrap();
            assert_approx_eq!(client.current_joint_positions().unwrap()[0], 1.5);

            client
                .send_joint_trajectory(vec![
                    TrajectoryPoint::new(vec![1.7], SECOND * 2),
                    TrajectoryPoint::new(vec![2.0], SECOND * 3),
                ])
                .unwrap()
                .await
                .unwrap();
            assert_approx_eq!(client.current_joint_positions().unwrap()[0], 2.0);
        }
    }

    // With the default strategy, every trajectory point is clamped.
    #[tokio::test]
    async fn send_joint_trajectory_limited_rounded() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let client = JointPositionLimiter::new(client, vec![(1.0..=2.0).into()]);

        client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![0.0], SECOND * 2),
                TrajectoryPoint::new(vec![0.5], SECOND * 3),
            ])
            .unwrap()
            .await
            .unwrap();
        assert_approx_eq!(client.current_joint_positions().unwrap()[0], 1.0);

        client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![2.5], SECOND * 2),
                TrajectoryPoint::new(vec![3.0], SECOND * 3),
            ])
            .unwrap()
            .await
            .unwrap();
        assert_approx_eq!(client.current_joint_positions().unwrap()[0], 2.0);
    }

    // With the `Error` strategy, the error reports the first out-of-range
    // point of the trajectory, whichever point that is.
    #[tokio::test]
    async fn send_joint_trajectory_limited_error() {
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let client = JointPositionLimiter::new_with_strategy(
            client,
            vec![(1.0..=2.0).into()],
            JointPositionLimiterStrategy::Error,
        );

        let e = client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![0.0], SECOND * 2),
                TrajectoryPoint::new(vec![0.5], SECOND * 3),
            ])
            .err()
            .unwrap();
        assert_error(e, 0.0);

        let e = client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![1.0], SECOND * 2),
                TrajectoryPoint::new(vec![0.5], SECOND * 3),
            ])
            .err()
            .unwrap();
        assert_error(e, 0.5);

        let e = client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![2.5], SECOND * 2),
                TrajectoryPoint::new(vec![3.0], SECOND * 3),
            ])
            .err()
            .unwrap();
        assert_error(e, 2.5);

        let e = client
            .send_joint_trajectory(vec![
                TrajectoryPoint::new(vec![2.0], SECOND * 2),
                TrajectoryPoint::new(vec![3.0], SECOND * 3),
            ])
            .err()
            .unwrap();
        assert_error(e, 3.0);
    }

    // Asserts that `e` is an `OutOfLimit` error reporting `position`.
    fn assert_error(e: Error, position: f64) {
        match e {
            Error::OutOfLimit { position: p, .. } => assert_approx_eq!(p, position),
            _ => panic!("{e:?}"),
        }
    }

    // Limits are read from the URDF, and a client joint that is missing from
    // the URDF is reported as `NoJoint`.
    #[test]
    fn from_urdf() {
        let s = r#"
            <robot name="robot">
                <joint name="a" type="revolute">
                    <origin xyz="0.0 0.0 0.0" />
                    <parent link="b" />
                    <child link="c" />
                    <axis xyz="0 1 0" />
                    <limit lower="-2" upper="1.0" effort="0" velocity="1.0"/>
                </joint>
            </robot>
        "#;
        let urdf_robot = urdf_rs::read_from_string(s).unwrap();
        let client = DummyJointTrajectoryClient::new(vec!["a".to_owned()]);
        let limiter = JointPositionLimiter::from_urdf(client, &urdf_robot.joints).unwrap();
        assert_approx_eq!(limiter.limits[0].lower().unwrap(), -2.0);
        assert_approx_eq!(limiter.limits[0].upper().unwrap(), 1.0);

        let urdf_robot = urdf_rs::read_from_string(s).unwrap();
        let client = DummyJointTrajectoryClient::new(vec!["unknown".to_owned()]);
        let e = JointPositionLimiter::from_urdf(client, &urdf_robot.joints)
            .err()
            .unwrap();
        assert!(matches!(e, Error::NoJoint(..)));
    }

    // `{ lower, upper }` and `{}` both parse. Serializing writes `{}` back as
    // an empty `[[limits]]` table.
    #[test]
    fn serde_joint_position_limit() {
        #[derive(Serialize, Deserialize)]
        struct T {
            limits: Vec<JointPositionLimit>,
        }

        let l: T = toml::from_str("limits = [{ lower = 0.0, upper = 1.0 }]").unwrap();
        assert_approx_eq!(l.limits[0].lower().unwrap(), 0.0);
        assert_approx_eq!(l.limits[0].upper().unwrap(), 1.0);

        let l: T = toml::from_str("limits = [{}]").unwrap();
        assert!(l.limits[0].is_none());

        let l: T = toml::from_str(
            "limits = [\
                { lower = 0.0, upper = 1.0 },\
                {},\
                { lower = 1.0, upper = 2.0 }\
            ]",
        )
        .unwrap();
        assert_approx_eq!(l.limits[0].lower().unwrap(), 0.0);
        assert_approx_eq!(l.limits[0].upper().unwrap(), 1.0);
        assert!(l.limits[1].is_none());
        assert_approx_eq!(l.limits[2].lower().unwrap(), 1.0);
        assert_approx_eq!(l.limits[2].upper().unwrap(), 2.0);

        assert_eq!(
            toml::to_string(&l).unwrap(),
            "[[limits]]\n\
             lower = 0.0\n\
             upper = 1.0\n\
             \n\
             [[limits]]\n\
             \n\
             [[limits]]\n\
             lower = 1.0\n\
             upper = 2.0\n\
            "
        );
    }
}