openrr_client/clients/collision_check_client.rs

1use std::{path::Path, sync::Arc};
2
3use arci::{Error, JointTrajectoryClient, TrajectoryPoint, WaitFuture};
4use openrr_planner::{
5    collision::create_self_collision_checker, SelfCollisionChecker, SelfCollisionCheckerConfig,
6};
7
8pub struct CollisionCheckClient<T>
9where
10    T: JointTrajectoryClient,
11{
12    pub client: T,
13    pub collision_checker: Arc<SelfCollisionChecker<f64>>,
14}
15
16impl<T> CollisionCheckClient<T>
17where
18    T: JointTrajectoryClient,
19{
20    pub fn new(client: T, collision_checker: Arc<SelfCollisionChecker<f64>>) -> Self {
21        Self {
22            client,
23            collision_checker,
24        }
25    }
26}
27
28impl<T> JointTrajectoryClient for CollisionCheckClient<T>
29where
30    T: JointTrajectoryClient,
31{
32    fn joint_names(&self) -> Vec<String> {
33        self.client.joint_names()
34    }
35
36    fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
37        self.client.current_joint_positions()
38    }
39
40    fn send_joint_positions(
41        &self,
42        positions: Vec<f64>,
43        duration: std::time::Duration,
44    ) -> Result<WaitFuture, Error> {
45        self.collision_checker
46            .check_partial_joint_positions(
47                self.joint_names().as_slice(),
48                &self.current_joint_positions()?,
49                &positions,
50                duration,
51            )
52            .map_err(|e| Error::Other(e.into()))?;
53        self.client.send_joint_positions(positions, duration)
54    }
55
56    fn send_joint_trajectory(&self, trajectory: Vec<TrajectoryPoint>) -> Result<WaitFuture, Error> {
57        let position_trajectory = trajectory
58            .iter()
59            .map(|point| {
60                openrr_planner::TrajectoryPoint::new(point.positions.clone(), vec![], vec![])
61            })
62            .collect::<Vec<_>>();
63        self.collision_checker
64            .check_partial_joint_trajectory(self.joint_names().as_slice(), &position_trajectory)
65            .map_err(|e| Error::Other(e.into()))?;
66        self.client.send_joint_trajectory(trajectory)
67    }
68}
69
70pub fn create_collision_check_client<P: AsRef<Path>>(
71    urdf_path: P,
72    self_collision_check_pairs: &[String],
73    config: &SelfCollisionCheckerConfig,
74    client: Arc<dyn JointTrajectoryClient>,
75    reference_robot: Arc<k::Chain<f64>>,
76) -> CollisionCheckClient<Arc<dyn JointTrajectoryClient>> {
77    let collision_checker = Arc::new(create_self_collision_checker(
78        urdf_path,
79        self_collision_check_pairs,
80        config,
81        reference_robot,
82    ));
83
84    CollisionCheckClient::new(client, collision_checker)
85}
86
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    const SAMPLE_URDF: &str = "sample.urdf";

    /// Loads `sample.urdf` into a kinematic chain.
    fn load_sample_robot() -> Arc<k::Chain<f64>> {
        let urdf_robot = urdf_rs::read_file(Path::new(SAMPLE_URDF)).unwrap();
        Arc::new(k::Chain::<f64>::from(&urdf_robot))
    }

    /// Wraps `client` in a collision-checking client that watches the
    /// `root` / `l_shoulder_roll` pair of the sample robot.
    fn wrap_with_checker(
        client: arci::DummyJointTrajectoryClient,
        robot: Arc<k::Chain<f64>>,
    ) -> CollisionCheckClient<Arc<dyn JointTrajectoryClient>> {
        create_collision_check_client(
            Path::new(SAMPLE_URDF),
            &["root:l_shoulder_roll".into()],
            &SelfCollisionCheckerConfig::default(),
            Arc::new(client),
            robot,
        )
    }

    #[tokio::test]
    async fn test_create_collision_check_client() {
        let robot = load_sample_robot();
        let all_joints: Vec<String> = robot
            .iter_joints()
            .map(|joint| joint.name.clone())
            .collect();
        let dummy = arci::DummyJointTrajectoryClient::new(all_joints);
        dummy
            .send_joint_positions(vec![0.0; 8], Duration::from_secs(0))
            .unwrap()
            .await
            .unwrap();

        let checked = wrap_with_checker(dummy, robot);
        assert_eq!(checked.current_joint_positions().unwrap(), vec![0.0; 8]);

        // The all-zero pose is accepted...
        assert!(checked
            .send_joint_positions(vec![0.0; 8], Duration::from_secs(1))
            .is_ok());
        // ...while moving the first joint to 1.57 is rejected by the checker.
        assert!(checked
            .send_joint_positions(
                vec![1.57, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                Duration::from_secs(1),
            )
            .is_err());
    }

    #[tokio::test]
    async fn test_create_collision_check_client_for_partial_joints() {
        let robot = load_sample_robot();
        let first_two_joints: Vec<String> = robot
            .iter_joints()
            .take(2)
            .map(|joint| joint.name.clone())
            .collect();
        let dummy = arci::DummyJointTrajectoryClient::new(first_two_joints);
        dummy
            .send_joint_positions(vec![0.0; 2], Duration::from_secs(0))
            .unwrap()
            .await
            .unwrap();

        let checked = wrap_with_checker(dummy, robot);
        assert_eq!(checked.current_joint_positions().unwrap(), vec![0.0; 2]);

        assert!(checked
            .send_joint_positions(vec![0.0; 2], Duration::from_secs(1))
            .is_ok());
        assert!(checked
            .send_joint_positions(vec![1.57, 0.0], Duration::from_secs(1))
            .is_err());

        let safe_trajectory = vec![TrajectoryPoint::new(vec![0.0; 2], Duration::from_secs(1))];
        assert!(checked.send_joint_trajectory(safe_trajectory).is_ok());

        let colliding_trajectory =
            vec![TrajectoryPoint::new(vec![1.57, 0.0], Duration::from_secs(2))];
        assert!(checked.send_joint_trajectory(colliding_trajectory).is_err());
    }
}