openrr_client/clients/collision_check_client.rs

1use std::{path::Path, sync::Arc};
2
3use arci::{Error, JointTrajectoryClient, TrajectoryPoint, WaitFuture};
4use openrr_planner::{
5    SelfCollisionChecker, SelfCollisionCheckerConfig, collision::create_self_collision_checker,
6};
7
8pub struct CollisionCheckClient<T>
9where
10    T: JointTrajectoryClient,
11{
12    pub client: T,
13    pub collision_checker: Arc<SelfCollisionChecker<f64>>,
14}
15
16impl<T> CollisionCheckClient<T>
17where
18    T: JointTrajectoryClient,
19{
20    pub fn new(client: T, collision_checker: Arc<SelfCollisionChecker<f64>>) -> Self {
21        Self {
22            client,
23            collision_checker,
24        }
25    }
26}
27
28impl<T> JointTrajectoryClient for CollisionCheckClient<T>
29where
30    T: JointTrajectoryClient,
31{
32    fn joint_names(&self) -> Vec<String> {
33        self.client.joint_names()
34    }
35
36    fn current_joint_positions(&self) -> Result<Vec<f64>, Error> {
37        self.client.current_joint_positions()
38    }
39
40    fn send_joint_positions(
41        &self,
42        positions: Vec<f64>,
43        duration: std::time::Duration,
44    ) -> Result<WaitFuture, Error> {
45        self.collision_checker
46            .check_partial_joint_positions(
47                self.joint_names().as_slice(),
48                &self.current_joint_positions()?,
49                &positions,
50                duration,
51            )
52            .map_err(|e| Error::Other(e.into()))?;
53        self.client.send_joint_positions(positions, duration)
54    }
55
56    fn send_joint_trajectory(&self, trajectory: Vec<TrajectoryPoint>) -> Result<WaitFuture, Error> {
57        let position_trajectory = trajectory
58            .iter()
59            .map(|point| {
60                openrr_planner::TrajectoryPoint::new(point.positions.clone(), vec![], vec![])
61            })
62            .collect::<Vec<_>>();
63        self.collision_checker
64            .check_partial_joint_trajectory(self.joint_names().as_slice(), &position_trajectory)
65            .map_err(|e| Error::Other(e.into()))?;
66        self.client.send_joint_trajectory(trajectory)
67    }
68}
69
70pub fn create_collision_check_client<P: AsRef<Path>>(
71    urdf_path: P,
72    self_collision_check_pairs: &[String],
73    config: &SelfCollisionCheckerConfig,
74    client: Arc<dyn JointTrajectoryClient>,
75    reference_robot: Arc<k::Chain<f64>>,
76) -> CollisionCheckClient<Arc<dyn JointTrajectoryClient>> {
77    let collision_checker = Arc::new(create_self_collision_checker(
78        urdf_path,
79        self_collision_check_pairs,
80        config,
81        reference_robot,
82    ));
83
84    CollisionCheckClient::new(client, collision_checker)
85}
86
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Link pair that the tests expect to collide when the first joint is set to 1.57.
    const CHECK_PAIR: &str = "root:l_shoulder_roll";

    /// Loads `sample.urdf` and builds a dummy client over the first `joint_limit` joints
    /// of the model (all of them for `None`).
    ///
    /// The dummy is then moved to `initial` and wrapped in a collision-checking client
    /// built from the same model.
    async fn setup(
        joint_limit: Option<usize>,
        initial: Vec<f64>,
    ) -> CollisionCheckClient<Arc<dyn JointTrajectoryClient>> {
        let urdf_path = Path::new("sample.urdf");
        let urdf_robot = urdf_rs::read_file(urdf_path).unwrap();
        let robot = Arc::new(k::Chain::<f64>::from(&urdf_robot));
        let dummy = arci::DummyJointTrajectoryClient::new(
            robot
                .iter_joints()
                .take(joint_limit.unwrap_or(usize::MAX))
                .map(|joint| joint.name.clone())
                .collect(),
        );
        dummy
            .send_joint_positions(initial, Duration::new(0, 0))
            .unwrap()
            .await
            .unwrap();

        create_collision_check_client(
            urdf_path,
            &[CHECK_PAIR.to_owned()],
            &SelfCollisionCheckerConfig::default(),
            Arc::new(dummy),
            robot,
        )
    }

    #[tokio::test]
    async fn test_create_collision_check_client() {
        let checked = setup(None, vec![0.0; 8]).await;

        assert_eq!(checked.current_joint_positions().unwrap(), vec![0.0; 8]);

        let one_second = Duration::from_secs(1);
        assert!(checked.send_joint_positions(vec![0.0; 8], one_second).is_ok());

        let mut colliding = vec![0.0; 8];
        colliding[0] = 1.57;
        assert!(checked.send_joint_positions(colliding, one_second).is_err());
    }

    #[tokio::test]
    async fn test_create_collision_check_client_for_partial_joints() {
        let checked = setup(Some(2), vec![0.0; 2]).await;

        assert_eq!(checked.current_joint_positions().unwrap(), vec![0.0; 2]);

        let one_second = Duration::from_secs(1);
        assert!(checked.send_joint_positions(vec![0.0; 2], one_second).is_ok());
        assert!(checked.send_joint_positions(vec![1.57, 0.0], one_second).is_err());

        let safe = vec![TrajectoryPoint::new(vec![0.0; 2], one_second)];
        assert!(checked.send_joint_trajectory(safe).is_ok());

        let colliding = vec![TrajectoryPoint::new(vec![1.57, 0.0], Duration::from_secs(2))];
        assert!(checked.send_joint_trajectory(colliding).is_err());
    }
}