openrr_client/clients/chain_wrapper.rs

1use std::sync::Arc;
2
3use arci::{JointTrajectoryClient, WaitFuture};
4
5use crate::utils::find_nodes;
6
7pub struct ChainWrapper {
8    joint_names: Vec<String>,
9    full_chain: Arc<k::Chain<f64>>,
10    nodes: Vec<k::Node<f64>>,
11}
12
13impl ChainWrapper {
14    pub fn new(joint_names: Vec<String>, full_chain: Arc<k::Chain<f64>>) -> Self {
15        let nodes = find_nodes(&joint_names, full_chain.as_ref()).unwrap();
16        Self {
17            joint_names,
18            full_chain,
19            nodes,
20        }
21    }
22}
23
24impl JointTrajectoryClient for ChainWrapper {
25    fn joint_names(&self) -> Vec<String> {
26        self.joint_names.clone()
27    }
28
29    fn current_joint_positions(&self) -> Result<Vec<f64>, arci::Error> {
30        self.full_chain.update_transforms();
31        let mut positions = vec![0.0; self.joint_names.len()];
32        for (index, node) in self.nodes.iter().enumerate() {
33            positions[index] = node
34                .joint_position()
35                .ok_or_else(|| anyhow::anyhow!("No joint_position for joint={node}"))?;
36        }
37        Ok(positions)
38    }
39
40    fn send_joint_positions(
41        &self,
42        positions: Vec<f64>,
43        _duration: std::time::Duration,
44    ) -> Result<WaitFuture, arci::Error> {
45        for (index, node) in self.nodes.iter().enumerate() {
46            node.set_joint_position_clamped(positions[index]);
47        }
48        self.full_chain.update_transforms();
49        Ok(WaitFuture::ready())
50    }
51
52    fn send_joint_trajectory(
53        &self,
54        trajectory: Vec<arci::TrajectoryPoint>,
55    ) -> Result<WaitFuture, arci::Error> {
56        if let Some(last_point) = trajectory.last() {
57            self.send_joint_positions(last_point.positions.clone(), last_point.time_from_start)
58        } else {
59            Ok(WaitFuture::ready())
60        }
61    }
62}
63
#[cfg(test)]
mod test {
    use std::time::Duration;

    use arci::{TrajectoryPoint, Vector3};
    use assert_approx_eq::assert_approx_eq;
    use k::{Joint, JointType, Node};

    use super::*;

    /// Builds a wrapper over a one-node chain whose only joint, `"joint"`, is a
    /// linear joint along the Y axis.
    fn single_linear_joint_wrapper() -> ChainWrapper {
        let joint = Joint::new(
            "joint",
            JointType::Linear {
                axis: Vector3::y_axis(),
            },
        );
        let chain = k::Chain::from_nodes(vec![Node::new(joint)]);
        ChainWrapper::new(vec!["joint".to_owned()], Arc::new(chain))
    }

    #[test]
    fn test_chain_wrapper() {
        let wrapper = single_linear_joint_wrapper();

        assert_eq!(wrapper.joint_names(), vec!["joint".to_owned()]);

        // A direct position command is visible immediately through the getter.
        let wait = wrapper
            .send_joint_positions(vec![1.0], Duration::from_secs(1))
            .unwrap();
        drop(wait);
        assert_approx_eq!(wrapper.current_joint_positions().unwrap()[0], 1.0);

        // A trajectory ends up at its last point.
        let trajectory = vec![TrajectoryPoint {
            positions: vec![1.5],
            velocities: Some(vec![1.0]),
            time_from_start: Duration::from_secs(1),
        }];
        drop(wrapper.send_joint_trajectory(trajectory).unwrap());
        assert_approx_eq!(wrapper.current_joint_positions().unwrap()[0], 1.5);
    }
}
106}