1use crate::{
2 error::Error,
3 traits::{JointTrajectoryClient, TrajectoryPoint},
4 waits::WaitFuture,
5};
6
7#[derive(Debug)]
8pub struct PartialJointTrajectoryClient<C>
9where
10 C: JointTrajectoryClient,
11{
12 joint_names: Vec<String>,
13 shared_client: C,
14 full_joint_names: Vec<String>,
15}
16
17pub fn copy_joint_positions(
53 from_joint_names: &[String],
54 from_positions: &[f64],
55 to_joint_names: &[String],
56 to_positions: &mut [f64],
57) -> Result<(), Error> {
58 if from_joint_names.len() != from_positions.len() || to_joint_names.len() != to_positions.len()
59 {
60 return Err(Error::CopyJointError(
61 from_joint_names.to_vec(),
62 from_positions.to_vec(),
63 to_joint_names.to_vec(),
64 to_positions.to_vec(),
65 ));
66 }
67 for (to_index, to_joint_name) in to_joint_names.iter().enumerate() {
68 if let Some(from_index) = from_joint_names.iter().position(|x| x == to_joint_name) {
69 to_positions[to_index] = from_positions[from_index];
70 }
71 }
72 Ok(())
73}
74
75impl<C> PartialJointTrajectoryClient<C>
76where
77 C: JointTrajectoryClient,
78{
79 pub fn new(joint_names: Vec<String>, shared_client: C) -> Result<Self, Error> {
89 use std::collections::HashSet;
90
91 let full_joint_names = shared_client.joint_names().to_vec();
93 if joint_names.len() > full_joint_names.len() {
94 return Err(Error::LengthMismatch {
95 model: full_joint_names.len(),
96 input: joint_names.len(),
97 });
98 }
99
100 let input_len = joint_names.len();
102 let unique_joint_names = joint_names.clone().into_iter().collect::<HashSet<String>>();
103
104 if unique_joint_names.len() != input_len {
105 return Err(Error::LengthMismatch {
106 model: unique_joint_names.len(),
107 input: input_len,
108 });
109 }
110
111 if !unique_joint_names
112 .iter()
113 .all(|joint_name| full_joint_names.iter().any(|x| x == joint_name))
114 {
115 return Err(Error::JointNamesMismatch {
116 partial: joint_names,
117 full: full_joint_names,
118 });
119 }
120
121 Ok(Self {
122 joint_names,
123 shared_client,
124 full_joint_names,
125 })
126 }
127}
128
129impl<C> JointTrajectoryClient for PartialJointTrajectoryClient<C>
130where
131 C: JointTrajectoryClient,
132{
133 fn joint_names(&self) -> Vec<String> {
134 self.joint_names.clone()
135 }
136
137 fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
138 let mut result = vec![0.0; self.joint_names.len()];
139 copy_joint_positions(
140 &self.full_joint_names,
141 &self.shared_client.current_joint_positions()?,
142 &self.joint_names(),
143 &mut result,
144 )?;
145 Ok(result)
146 }
147
148 fn send_joint_positions(
149 &self,
150 positions: Vec<f64>,
151 duration: std::time::Duration,
152 ) -> Result<WaitFuture, Error> {
153 let mut full_positions = self.shared_client.current_joint_positions()?;
154 copy_joint_positions(
155 &self.joint_names(),
156 &positions,
157 &self.full_joint_names,
158 &mut full_positions,
159 )?;
160 self.shared_client
161 .send_joint_positions(full_positions, duration)
162 }
163
164 fn send_joint_trajectory(&self, trajectory: Vec<TrajectoryPoint>) -> Result<WaitFuture, Error> {
165 let full_positions_base = self.shared_client.current_joint_positions()?;
166 let mut full_trajectory = vec![];
167 let full_dof = full_positions_base.len();
168 for point in trajectory {
169 let mut full_positions = full_positions_base.clone();
170 copy_joint_positions(
171 &self.joint_names(),
172 &point.positions,
173 &self.full_joint_names,
174 &mut full_positions,
175 )?;
176 let mut full_point = TrajectoryPoint::new(full_positions, point.time_from_start);
177 if let Some(partial_velocities) = &point.velocities {
178 let mut full_velocities = vec![0.0; full_dof];
179 copy_joint_positions(
180 &self.joint_names(),
181 partial_velocities,
182 &self.full_joint_names,
183 &mut full_velocities,
184 )?;
185 full_point.velocities = Some(full_velocities);
186 }
187 full_trajectory.push(full_point);
188 }
189 self.shared_client.send_joint_trajectory(full_trajectory)
190 }
191}
192
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use assert_approx_eq::assert_approx_eq;

    use super::*;

    /// In-memory client used as the "full" robot: it remembers the last
    /// commanded positions and the last trajectory it received.
    #[derive(Debug, Clone)]
    struct DummyFull {
        name: Vec<String>,
        pos: Arc<Mutex<Vec<f64>>>,
        last_trajectory: Arc<Mutex<Vec<TrajectoryPoint>>>,
    }

    impl JointTrajectoryClient for DummyFull {
        fn joint_names(&self) -> Vec<String> {
            self.name.clone()
        }

        fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
            let current = self.pos.lock().unwrap();
            Ok(current.clone())
        }

        fn send_joint_positions(
            &self,
            positions: Vec<f64>,
            _duration: std::time::Duration,
        ) -> Result<WaitFuture, Error> {
            let mut current = self.pos.lock().unwrap();
            *current = positions;
            Ok(WaitFuture::ready())
        }

        fn send_joint_trajectory(
            &self,
            full_trajectory: Vec<TrajectoryPoint>,
        ) -> Result<WaitFuture, Error> {
            // The final trajectory point becomes the new "current" position.
            if let Some(final_point) = full_trajectory.last() {
                let mut current = self.pos.lock().unwrap();
                *current = final_point.positions.clone();
            }
            *self.last_trajectory.lock().unwrap() = full_trajectory;
            Ok(WaitFuture::ready())
        }
    }

    /// Builds an owned joint-name list from string literals.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|name| String::from(*name)).collect()
    }

    /// Builds a `DummyFull` with the given joint names and initial positions.
    fn dummy(joints: &[&str], positions: &[f64]) -> DummyFull {
        DummyFull {
            name: names(joints),
            pos: Arc::new(Mutex::new(positions.to_vec())),
            last_trajectory: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Compares two sequences pairwise (up to the shorter length).
    fn assert_all_approx(actual: &[f64], expected: &[f64]) {
        for (value, wanted) in actual.iter().zip(expected) {
            assert_approx_eq!(*value, *wanted);
        }
    }

    #[test]
    fn test_partial_new() {
        let client = dummy(&["part1", "high", "part2"], &[1.0, 3.0]);
        let full_names_debug = "[\"part1\", \"high\", \"part2\"]";

        // More joints requested than the full client has.
        let too_many = names(&["low", "part2", "high", "part1"]);
        assert!(PartialJointTrajectoryClient::new(too_many, client.clone()).is_err());

        // Same length as the full list, but with an unknown joint.
        let unknown_same_len = names(&["low", "part2", "part1"]);
        assert!(PartialJointTrajectoryClient::new(unknown_same_len, client.clone()).is_err());

        // Every joint, in a different order.
        let reordered = names(&["high", "part2", "part1"]);
        let partial = PartialJointTrajectoryClient::new(reordered, client.clone()).unwrap();
        assert_eq!(
            format!("{:?}", partial.joint_names),
            "[\"high\", \"part2\", \"part1\"]"
        );
        assert_eq!(
            format!("{client:?}"),
            format!("{:?}", partial.shared_client)
        );
        assert_eq!(format!("{:?}", partial.full_joint_names), full_names_debug);

        // Shorter list containing an unknown joint.
        let unknown_short = names(&["low", "high"]);
        assert!(PartialJointTrajectoryClient::new(unknown_short, client.clone()).is_err());

        // Proper subset.
        let subset = names(&["part1", "high"]);
        let partial = PartialJointTrajectoryClient::new(subset, client.clone()).unwrap();
        assert_eq!(
            format!("{:?}", partial.joint_names),
            "[\"part1\", \"high\"]"
        );
        assert_eq!(
            format!("{client:?}"),
            format!("{:?}", partial.shared_client)
        );
        assert_eq!(format!("{:?}", partial.full_joint_names), full_names_debug);
    }

    #[test]
    fn test_fn_copy_joint_position_for_from_joint() {
        let initial_target = [3.3_f64, 8.1, 5.2, 0.8];
        let to_joint_names = names(&["part1", "part2", "part3", "part4"]);
        let from_positions = vec![2.1_f64, 4.8, 1.0, 6.5];

        // Source covers every target joint, in a different order.
        let from_joint_names = names(&["part4", "part1", "part3", "part2"]);
        let mut to_positions = initial_target.to_vec();
        let result = copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        );
        assert!(result.is_ok());
        println!("{to_positions:?}");
        assert_all_approx(&to_positions, &[4.8, 6.5, 1.0, 2.1]);

        // Two source names but four source values.
        let from_joint_names = names(&["part4", "part1"]);
        let mut to_positions = initial_target.to_vec();
        let result = copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        );
        assert!(result.is_err());

        // Source covers only part of the target.
        let from_positions = vec![2.1_f64, 4.8];
        let mut to_positions = initial_target.to_vec();
        let result = copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        );
        assert!(result.is_ok());
        println!("{to_positions:?}");
        assert_all_approx(&to_positions, &[4.8, 8.1, 5.2, 2.1]);
    }

    #[test]
    fn test_fn_copy_joint_position_for_to_joint() {
        let from_joint_names = names(&["part1", "part2", "part3", "part4"]);
        let from_positions = vec![2.1_f64, 4.8, 1.0, 6.5];

        // Target uses the same names in the same order.
        let to_joint_names = names(&["part1", "part2", "part3", "part4"]);
        let mut to_positions = vec![3.3_f64, 8.1, 5.2, 0.8];
        copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        )
        .unwrap();
        assert_all_approx(&to_positions, &from_positions);
        println!("{to_positions:?}");

        // Target uses the same names in a different order.
        let to_joint_names = names(&["part4", "part1", "part3", "part2"]);
        let mut to_positions = vec![3.3_f64, 8.1, 5.2, 0.8];
        copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        )
        .unwrap();
        assert_all_approx(&to_positions, &[6.5, 2.1, 1.0, 4.8]);
        println!("{to_positions:?}");

        // Target is a subset of the source.
        let to_joint_names = names(&["part4", "part1"]);
        let mut to_positions = vec![3.3_f64, 8.1];
        copy_joint_positions(
            &from_joint_names,
            &from_positions,
            &to_joint_names,
            &mut to_positions,
        )
        .unwrap();
        assert_all_approx(&to_positions, &[6.5, 2.1]);
        println!("{to_positions:?}");
    }

    #[test]
    fn test_partial_joint_name() {
        let client = dummy(&["part1", "high", "part2", "low"], &[1.0, 3.0]);
        let partial =
            PartialJointTrajectoryClient::new(names(&["part1", "high"]), client).unwrap();

        assert_eq!(
            format!("{:?}", partial.joint_names()),
            "[\"part1\", \"high\"]"
        );
    }

    #[test]
    fn test_partial_current_pos() {
        let client = dummy(&["part1", "high", "part2"], &[1.0, 3.0, 2.4]);

        // Every joint of the full client.
        let all_joints = names(&["part1", "high", "part2"]);
        let partial = PartialJointTrajectoryClient::new(all_joints, client.clone()).unwrap();
        let current_pos = partial.current_joint_positions();
        assert!(current_pos.is_ok());
        assert_all_approx(&current_pos.unwrap(), &[1.0, 3.0, 2.4]);

        // A subset that skips the middle joint.
        let subset = names(&["part1", "part2"]);
        let partial = PartialJointTrajectoryClient::new(subset, client).unwrap();
        let current_pos = partial.current_joint_positions();
        assert!(current_pos.is_ok());
        assert_all_approx(&current_pos.unwrap(), &[1.0, 2.4]);
    }

    #[tokio::test]
    async fn test_partial_send_pos() {
        let client = dummy(&["part1", "high", "part2"], &[1.0, 3.0, 5.0]);
        let duration = Duration::from_secs(5);

        // Command every joint.
        let next_pos = vec![2.2_f64, 0.5, 1.7];
        let partial = PartialJointTrajectoryClient::new(
            names(&["part1", "high", "part2"]),
            client.clone(),
        )
        .unwrap();

        let result = partial.send_joint_positions(next_pos.clone(), duration);
        assert!(result.is_ok());
        assert!(result.unwrap().await.is_ok());

        let current_pos = partial.current_joint_positions().unwrap();
        assert_all_approx(&current_pos, &next_pos);

        // Command a reordered subset.
        let next_pos = vec![4.8_f64, 1.5];
        let partial =
            PartialJointTrajectoryClient::new(names(&["part2", "part1"]), client.clone()).unwrap();

        let result = partial.send_joint_positions(next_pos.clone(), duration);
        assert!(result.is_ok());
        assert!(result.unwrap().await.is_ok());

        let current_pos = partial.current_joint_positions().unwrap();
        println!("{current_pos:?}");
        assert_all_approx(&current_pos, &next_pos);
    }

    #[tokio::test]
    async fn test_partial_trajectory() {
        let client = dummy(&["part1", "high", "part2", "low"], &[5.1, 0.8, 2.4, 4.5]);

        // Trajectory over every joint.
        let trajectories = vec![
            TrajectoryPoint::new(vec![1.0_f64, 3.0, 2.2, 4.0], Duration::from_secs(1)),
            TrajectoryPoint::new(vec![3.4_f64, 5.8, 0.1, 2.5], Duration::from_secs(2)),
        ];
        let partial = PartialJointTrajectoryClient::new(
            names(&["part1", "high", "part2", "low"]),
            client.clone(),
        )
        .unwrap();
        let result = partial.send_joint_trajectory(trajectories);
        assert!(result.is_ok());
        assert!(result.unwrap().await.is_ok());

        let current_pos = partial.current_joint_positions().unwrap();
        println!("{current_pos:?}");
        assert_all_approx(&current_pos, &[3.4, 5.8, 0.1, 2.5]);

        // Trajectory over a reordered subset.
        let trajectories = vec![
            TrajectoryPoint::new(vec![3.4_f64, 5.8, 0.1], Duration::from_secs(2)),
            TrajectoryPoint::new(vec![1.0_f64, 3.0, 2.2], Duration::from_secs(1)),
        ];
        let partial = PartialJointTrajectoryClient::new(
            names(&["low", "part2", "part1"]),
            client.clone(),
        )
        .unwrap();
        let result = partial.send_joint_trajectory(trajectories);
        assert!(result.is_ok());
        assert!(result.unwrap().await.is_ok());

        let current_pos = partial.current_joint_positions().unwrap();
        println!("{current_pos:?}");
        assert_all_approx(&current_pos, &[1.0, 3.0, 2.2]);
    }
}