// arci/clients/partial_joint_trajectory_client.rs

1use crate::{
2    error::Error,
3    traits::{JointTrajectoryClient, TrajectoryPoint},
4    waits::WaitFuture,
5};
6
/// A [`JointTrajectoryClient`] that exposes only a subset of the joints of a
/// wrapped ("shared") client.
///
/// Commands sent through this client are expanded to the full joint set of
/// the shared client, with joints outside the subset re-sent their current
/// position. Construct it with [`PartialJointTrajectoryClient::new`], which
/// validates the requested subset.
// The `C: JointTrajectoryClient` bound lives on the impl blocks that need it,
// not on the type definition, so the type itself stays unconstrained.
#[derive(Debug)]
pub struct PartialJointTrajectoryClient<C> {
    /// Joint names exposed by this client, in the order passed to `new`.
    joint_names: Vec<String>,
    /// The wrapped client that drives the full set of joints.
    shared_client: C,
    /// Snapshot of the shared client's joint names taken at construction.
    full_joint_names: Vec<String>,
}
16
17/// # To copy joint name and position between `from` and `to`
18///
19/// Copy position of same joint name.
20/// This function returns Ok() or Err().
21///
22/// # When this function through Error?
23///
24/// length of joint names and positions is difference.
25///
26/// # Sample code
27///
28/// ```
29/// use arci::copy_joint_positions;
30///
31/// let from_positions = vec![2.1_f64, 4.8, 1.0, 6.5];
32/// let from_joint_names = vec![
33/// String::from("part1"),
34/// String::from("part2"),
35/// String::from("part3"),
36/// String::from("part4"),
37/// ];
38///
39/// let mut to_positions = vec![3.3_f64, 8.1];
40/// let to_joint_names = vec![
41/// String::from("part4"),
42/// String::from("part1"),
43/// ];
44///
45/// copy_joint_positions(
46/// &from_joint_names,
47/// &from_positions,
48/// &to_joint_names,
49/// &mut to_positions,
50/// ).unwrap();
51/// ```
52pub fn copy_joint_positions(
53    from_joint_names: &[String],
54    from_positions: &[f64],
55    to_joint_names: &[String],
56    to_positions: &mut [f64],
57) -> Result<(), Error> {
58    if from_joint_names.len() != from_positions.len() || to_joint_names.len() != to_positions.len()
59    {
60        return Err(Error::CopyJointError(
61            from_joint_names.to_vec(),
62            from_positions.to_vec(),
63            to_joint_names.to_vec(),
64            to_positions.to_vec(),
65        ));
66    }
67    for (to_index, to_joint_name) in to_joint_names.iter().enumerate() {
68        if let Some(from_index) = from_joint_names.iter().position(|x| x == to_joint_name) {
69            to_positions[to_index] = from_positions[from_index];
70        }
71    }
72    Ok(())
73}
74
75impl<C> PartialJointTrajectoryClient<C>
76where
77    C: JointTrajectoryClient,
78{
79    /// # Generate Partial Joint Client
80    ///
81    /// This function check partial joint name and full joint name.
82    /// Only allow unique partial joint name and joint name contained full.
83    ///
84    /// # Important point
85    ///
86    /// Partial Joint name is changed to dictionary order.
87    ///
88    pub fn new(joint_names: Vec<String>, shared_client: C) -> Result<Self, Error> {
89        use std::collections::HashSet;
90
91        // check length between full and partial
92        let full_joint_names = shared_client.joint_names().to_vec();
93        if joint_names.len() > full_joint_names.len() {
94            return Err(Error::LengthMismatch {
95                model: full_joint_names.len(),
96                input: joint_names.len(),
97            });
98        }
99
100        // check redundant joint name of partial
101        let input_len = joint_names.len();
102        let unique_joint_names = joint_names.clone().into_iter().collect::<HashSet<String>>();
103
104        if unique_joint_names.len() != input_len {
105            return Err(Error::LengthMismatch {
106                model: unique_joint_names.len(),
107                input: input_len,
108            });
109        }
110
111        if !unique_joint_names
112            .iter()
113            .all(|joint_name| full_joint_names.iter().any(|x| x == joint_name))
114        {
115            return Err(Error::JointNamesMismatch {
116                partial: joint_names,
117                full: full_joint_names,
118            });
119        }
120
121        Ok(Self {
122            joint_names,
123            shared_client,
124            full_joint_names,
125        })
126    }
127}
128
129impl<C> JointTrajectoryClient for PartialJointTrajectoryClient<C>
130where
131    C: JointTrajectoryClient,
132{
133    fn joint_names(&self) -> Vec<String> {
134        self.joint_names.clone()
135    }
136
137    fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
138        let mut result = vec![0.0; self.joint_names.len()];
139        copy_joint_positions(
140            &self.full_joint_names,
141            &self.shared_client.current_joint_positions()?,
142            &self.joint_names(),
143            &mut result,
144        )?;
145        Ok(result)
146    }
147
148    fn send_joint_positions(
149        &self,
150        positions: Vec<f64>,
151        duration: std::time::Duration,
152    ) -> Result<WaitFuture, Error> {
153        let mut full_positions = self.shared_client.current_joint_positions()?;
154        copy_joint_positions(
155            &self.joint_names(),
156            &positions,
157            &self.full_joint_names,
158            &mut full_positions,
159        )?;
160        self.shared_client
161            .send_joint_positions(full_positions, duration)
162    }
163
164    fn send_joint_trajectory(&self, trajectory: Vec<TrajectoryPoint>) -> Result<WaitFuture, Error> {
165        let full_positions_base = self.shared_client.current_joint_positions()?;
166        let mut full_trajectory = vec![];
167        let full_dof = full_positions_base.len();
168        for point in trajectory {
169            let mut full_positions = full_positions_base.clone();
170            copy_joint_positions(
171                &self.joint_names(),
172                &point.positions,
173                &self.full_joint_names,
174                &mut full_positions,
175            )?;
176            let mut full_point = TrajectoryPoint::new(full_positions, point.time_from_start);
177            if let Some(partial_velocities) = &point.velocities {
178                let mut full_velocities = vec![0.0; full_dof];
179                copy_joint_positions(
180                    &self.joint_names(),
181                    partial_velocities,
182                    &self.full_joint_names,
183                    &mut full_velocities,
184                )?;
185                full_point.velocities = Some(full_velocities);
186            }
187            full_trajectory.push(full_point);
188        }
189        self.shared_client.send_joint_trajectory(full_trajectory)
190    }
191}
192
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use assert_approx_eq::assert_approx_eq;

    use super::*;

    /// Test double for the full (shared) client.
    ///
    /// `send_joint_positions` stores the positions as the new current state.
    /// `send_joint_trajectory` stores the last point's positions as the new
    /// state and records the whole trajectory in `last_trajectory`, so tests
    /// can inspect exactly what was forwarded.
    #[derive(Debug, Clone)]
    struct DummyFull {
        name: Vec<String>,
        pos: Arc<Mutex<Vec<f64>>>,
        last_trajectory: Arc<Mutex<Vec<TrajectoryPoint>>>,
    }
    impl JointTrajectoryClient for DummyFull {
        fn joint_names(&self) -> Vec<String> {
            self.name.clone()
        }

        fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
            Ok(self.pos.lock().unwrap().clone())
        }

        fn send_joint_positions(
            &self,
            positions: Vec<f64>,
            _duration: std::time::Duration,
        ) -> Result<WaitFuture, Error> {
            *self.pos.lock().unwrap() = positions;
            Ok(WaitFuture::ready())
        }

        fn send_joint_trajectory(
            &self,
            full_trajectory: Vec<TrajectoryPoint>,
        ) -> Result<WaitFuture, Error> {
            if let Some(last_point) = full_trajectory.last() {
                last_point
                    .positions
                    .clone_into(&mut self.pos.lock().unwrap());
            }
            *self.last_trajectory.lock().unwrap() = full_trajectory;
            Ok(WaitFuture::ready())
        }
    }

    #[test]
    fn test_partial_new() {
        let client = DummyFull {
            name: vec![
                String::from("part1"),
                String::from("high"),
                String::from("part2"),
            ],
            // One position per joint, matching `name`.
            pos: Arc::new(Mutex::new(vec![1.0_f64, 3.0, 5.0])),
            last_trajectory: Arc::new(Mutex::new(Vec::new())),
        };

        // more joints than the full client (Error)
        let joint_names = vec![
            String::from("low"),
            String::from("part2"),
            String::from("high"),
            String::from("part1"),
        ];
        let partial = PartialJointTrajectoryClient::new(joint_names, client.clone());
        assert!(partial.is_err());

        // same number of joints, one unknown (Error)
        let joint_names = vec![
            String::from("low"),
            String::from("part2"),
            String::from("part1"),
        ];
        let partial = PartialJointTrajectoryClient::new(joint_names, client.clone());
        assert!(partial.is_err());

        // duplicated joint name (Error)
        let joint_names = vec![String::from("part1"), String::from("part1")];
        let partial = PartialJointTrajectoryClient::new(joint_names, client.clone());
        assert!(partial.is_err());

        // same joints in another order (Ok); the given order is preserved
        let joint_names = vec![
            String::from("high"),
            String::from("part2"),
            String::from("part1"),
        ];
        let partial = PartialJointTrajectoryClient::new(joint_names, client.clone()).unwrap();

        assert_eq!(
            format!("{:?}", partial.joint_names),
            "[\"high\", \"part2\", \"part1\"]"
        );
        assert_eq!(
            format!("{client:?}"),
            format!("{:?}", partial.shared_client)
        );
        assert_eq!(
            format!("{:?}", partial.full_joint_names),
            "[\"part1\", \"high\", \"part2\"]"
        );

        // fewer joints, one unknown (Error)
        let joint_names = vec![String::from("low"), String::from("high")];
        let partial = PartialJointTrajectoryClient::new(joint_names, client.clone());
        assert!(partial.is_err());

        // fewer joints, all known (Ok)
        let joint_names = vec![String::from("part1"), String::from("high")];
        let partial = PartialJointTrajectoryClient::new(joint_names, client.clone()).unwrap();

        assert_eq!(
            format!("{:?}", partial.joint_names),
            "[\"part1\", \"high\"]"
        );
        assert_eq!(
            format!("{client:?}"),
            format!("{:?}", partial.shared_client)
        );
        assert_eq!(
            format!("{:?}", partial.full_joint_names),
            "[\"part1\", \"high\", \"part2\"]"
        );
    }

    #[test]
    fn test_fn_copy_joint_position_for_from_joint() {
        let from_positions = vec![2.1_f64, 4.8, 1.0, 6.5];
        let to_joint_names = vec![
            String::from("part1"),
            String::from("part2"),
            String::from("part3"),
            String::from("part4"),
        ];

        // random order pattern
        let mut to_positions = vec![3.3_f64, 8.1, 5.2, 0.8];
        let from_joint_names = vec![
            String::from("part4"),
            String::from("part1"),
            String::from("part3"),
            String::from("part2"),
        ];
        let correct = [4.8_f64, 6.5, 1.0, 2.1];
        let result = copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        );
        assert!(result.is_ok());
        println!("{to_positions:?}");
        to_positions
            .iter()
            .zip(correct.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));

        // names and positions of different lengths (Error)
        let mut to_positions = vec![3.3_f64, 8.1, 5.2, 0.8];
        let from_joint_names = vec![String::from("part4"), String::from("part1")];
        let result = copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        );
        assert!(result.is_err());

        // fewer source joints (Ok); unmatched targets keep their value
        let from_positions = vec![2.1_f64, 4.8];
        let mut to_positions = vec![3.3_f64, 8.1, 5.2, 0.8];
        let from_joint_names = vec![String::from("part4"), String::from("part1")];
        let correct = [4.8_f64, 8.1, 5.2, 2.1];
        let result = copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        );
        assert!(result.is_ok());
        println!("{to_positions:?}");
        to_positions
            .iter()
            .zip(correct.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));
    }

    #[test]
    fn test_fn_copy_joint_position_for_to_joint() {
        let from_joint_names = vec![
            String::from("part1"),
            String::from("part2"),
            String::from("part3"),
            String::from("part4"),
        ];
        let from_positions = vec![2.1_f64, 4.8, 1.0, 6.5];
        let to_joint_names = vec![
            String::from("part1"),
            String::from("part2"),
            String::from("part3"),
            String::from("part4"),
        ];

        // lexical order pattern
        let mut to_positions = vec![3.3_f64, 8.1, 5.2, 0.8];
        copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        )
        .unwrap();
        to_positions
            .iter()
            .zip(from_positions.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));
        println!("{to_positions:?}");

        // random order pattern
        let mut to_positions = vec![3.3_f64, 8.1, 5.2, 0.8];
        let to_joint_names = vec![
            String::from("part4"),
            String::from("part1"),
            String::from("part3"),
            String::from("part2"),
        ];
        let correct = [6.5_f64, 2.1, 1.0, 4.8];
        copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        )
        .unwrap();
        to_positions
            .iter()
            .zip(correct.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));
        println!("{to_positions:?}");

        // few joint pattern
        let mut to_positions = vec![3.3_f64, 8.1];
        let to_joint_names = vec![String::from("part4"), String::from("part1")];
        let correct = [6.5_f64, 2.1];
        copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        )
        .unwrap();
        to_positions
            .iter()
            .zip(correct.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));
        println!("{to_positions:?}");
    }

    #[test]
    fn test_partial_joint_name() {
        let client = DummyFull {
            name: vec![
                String::from("part1"),
                String::from("high"),
                String::from("part2"),
                String::from("low"),
            ],
            // One position per joint, matching `name`.
            pos: Arc::new(Mutex::new(vec![1.0_f64, 3.0, 5.0, 7.0])),
            last_trajectory: Arc::new(Mutex::new(Vec::new())),
        };
        let joint_names = vec![String::from("part1"), String::from("high")];
        let partial = PartialJointTrajectoryClient::new(joint_names, client).unwrap();

        assert_eq!(
            format!("{:?}", partial.joint_names()),
            "[\"part1\", \"high\"]"
        );
    }

    #[test]
    fn test_partial_current_pos() {
        // partial = full
        let client = DummyFull {
            name: vec![
                String::from("part1"),
                String::from("high"),
                String::from("part2"),
            ],
            pos: Arc::new(Mutex::new(vec![1.0_f64, 3.0, 2.4])),
            last_trajectory: Arc::new(Mutex::new(Vec::new())),
        };
        let joint_names = vec![
            String::from("part1"),
            String::from("high"),
            String::from("part2"),
        ];
        let correct = [1.0_f64, 3.0, 2.4];

        let partial = PartialJointTrajectoryClient::new(joint_names, client.clone()).unwrap();
        let current_pos = partial.current_joint_positions();
        assert!(current_pos.is_ok());
        let current_pos = current_pos.unwrap();
        assert_eq!(current_pos.len(), correct.len());

        current_pos
            .iter()
            .zip(correct.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));

        // partial < full
        let joint_names = vec![String::from("part1"), String::from("part2")];
        let correct = [1.0_f64, 2.4];

        let partial = PartialJointTrajectoryClient::new(joint_names, client).unwrap();
        let current_pos = partial.current_joint_positions();
        assert!(current_pos.is_ok());
        let current_pos = current_pos.unwrap();
        assert_eq!(current_pos.len(), correct.len());

        current_pos
            .iter()
            .zip(correct.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));
    }

    #[tokio::test]
    async fn test_partial_send_pos() {
        let client = DummyFull {
            name: vec![
                String::from("part1"),
                String::from("high"),
                String::from("part2"),
            ],
            pos: Arc::new(Mutex::new(vec![1.0_f64, 3.0_f64, 5.0])),
            last_trajectory: Arc::new(Mutex::new(Vec::new())),
        };
        let duration = std::time::Duration::from_secs(5);

        // partial = full
        let joint_names = vec![
            String::from("part1"),
            String::from("high"),
            String::from("part2"),
        ];
        let next_pos = vec![2.2_f64, 0.5, 1.7];

        let partial =
            PartialJointTrajectoryClient::new(joint_names.clone(), client.clone()).unwrap();

        let result = partial.send_joint_positions(next_pos.clone(), duration);
        assert!(result.is_ok());
        assert!(result.unwrap().await.is_ok());

        let current_pos = partial.current_joint_positions().unwrap();
        current_pos
            .iter()
            .zip(next_pos.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));

        // partial < full
        let joint_names = vec![String::from("part2"), String::from("part1")];
        let next_pos = vec![4.8_f64, 1.5];

        let partial =
            PartialJointTrajectoryClient::new(joint_names.clone(), client.clone()).unwrap();

        let result = partial.send_joint_positions(next_pos.clone(), duration);
        assert!(result.is_ok());
        assert!(result.unwrap().await.is_ok());

        let current_pos = partial.current_joint_positions().unwrap();
        println!("{current_pos:?}");
        current_pos
            .iter()
            .zip(next_pos.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));

        // The joint outside the partial set ("high") keeps its previous value.
        assert_approx_eq!(client.pos.lock().unwrap()[1], 0.5_f64);
    }

    #[tokio::test]
    async fn test_partial_trajectory() {
        let client = DummyFull {
            name: vec![
                String::from("part1"),
                String::from("high"),
                String::from("part2"),
                String::from("low"),
            ],
            pos: Arc::new(Mutex::new(vec![5.1_f64, 0.8, 2.4, 4.5])),
            last_trajectory: Arc::new(Mutex::new(Vec::new())),
        };

        // partial = full
        let joint_names = vec![
            String::from("part1"),
            String::from("high"),
            String::from("part2"),
            String::from("low"),
        ];
        let trajectories = vec![
            TrajectoryPoint::new(
                vec![1.0_f64, 3.0_f64, 2.2_f64, 4.0_f64],
                std::time::Duration::from_secs(1),
            ),
            TrajectoryPoint::new(
                vec![3.4_f64, 5.8_f64, 0.1_f64, 2.5_f64],
                std::time::Duration::from_secs(2),
            ),
        ];
        let correct = [3.4_f64, 5.8_f64, 0.1_f64, 2.5_f64];

        let partial =
            PartialJointTrajectoryClient::new(joint_names.clone(), client.clone()).unwrap();
        let result = partial.send_joint_trajectory(trajectories);
        assert!(result.is_ok());
        assert!(result.unwrap().await.is_ok());

        let current_pos = partial.current_joint_positions().unwrap();
        println!("{current_pos:?}");
        current_pos
            .iter()
            .zip(correct.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));

        // partial < full
        let joint_names = vec![
            String::from("low"),
            String::from("part2"),
            String::from("part1"),
        ];
        let trajectories = vec![
            TrajectoryPoint::new(
                vec![3.4_f64, 5.8_f64, 0.1_f64],
                std::time::Duration::from_secs(2),
            ),
            TrajectoryPoint::new(
                vec![1.0_f64, 3.0_f64, 2.2_f64],
                std::time::Duration::from_secs(1),
            ),
        ];
        let correct = [1.0_f64, 3.0_f64, 2.2_f64];

        let partial =
            PartialJointTrajectoryClient::new(joint_names.clone(), client.clone()).unwrap();
        let result = partial.send_joint_trajectory(trajectories);
        assert!(result.is_ok());
        assert!(result.unwrap().await.is_ok());

        let current_pos = partial.current_joint_positions().unwrap();
        println!("{current_pos:?}");
        current_pos
            .iter()
            .zip(correct.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));
    }

    #[tokio::test]
    async fn test_partial_trajectory_velocities() {
        let client = DummyFull {
            name: vec![
                String::from("part1"),
                String::from("high"),
                String::from("part2"),
            ],
            pos: Arc::new(Mutex::new(vec![1.0_f64, 2.0, 3.0])),
            last_trajectory: Arc::new(Mutex::new(Vec::new())),
        };
        let joint_names = vec![String::from("part2"), String::from("part1")];
        let partial = PartialJointTrajectoryClient::new(joint_names, client.clone()).unwrap();

        let mut point =
            TrajectoryPoint::new(vec![4.0_f64, 5.0], std::time::Duration::from_secs(1));
        point.velocities = Some(vec![0.4_f64, 0.5]);
        let result = partial.send_joint_trajectory(vec![point]);
        assert!(result.is_ok());
        assert!(result.unwrap().await.is_ok());

        // The forwarded point uses the full joint order (part1, high, part2):
        // "high" keeps its current position and is given zero velocity.
        let sent = client.last_trajectory.lock().unwrap();
        assert_eq!(sent.len(), 1);

        let expected_positions = [5.0_f64, 2.0, 4.0];
        assert_eq!(sent[0].positions.len(), expected_positions.len());
        sent[0]
            .positions
            .iter()
            .zip(expected_positions.iter())
            .for_each(|(pos, correct)| assert_approx_eq!(*pos, *correct));

        let expected_velocities = [0.5_f64, 0.0, 0.4];
        let velocities = sent[0]
            .velocities
            .as_ref()
            .expect("velocities of the partial point must be forwarded");
        assert_eq!(velocities.len(), expected_velocities.len());
        velocities
            .iter()
            .zip(expected_velocities.iter())
            .for_each(|(vel, correct)| assert_approx_eq!(*vel, *correct));
    }
}