arci_ros2/
ros2_control.rs

1use std::{
2    sync::{
3        atomic::{AtomicBool, Ordering},
4        Arc, RwLock,
5    },
6    time::Duration,
7};
8
9use arci::*;
10use futures::stream::StreamExt;
11use r2r::{
12    builtin_interfaces::{msg as builtin_msg, msg::Time},
13    control_msgs::{action::FollowJointTrajectory, msg::JointTrajectoryControllerState},
14    std_msgs::msg::Header,
15    trajectory_msgs::msg as trajectory_msg,
16};
17use serde::{Deserialize, Serialize};
18
19use crate::{utils, Node};
20
21/// `arci::JointTrajectoryClient` implementation for ROS2.
22pub struct Ros2ControlClient {
23    action_client: r2r::ActionClient<FollowJointTrajectory::Action>,
24    // keep not to be dropped
25    _node: Node,
26    joint_names: Vec<String>,
27    joint_state: Arc<RwLock<JointTrajectoryControllerState>>,
28}
29
30impl Ros2ControlClient {
31    /// Creates a new `Ros2ControlClient` from control_msgs/FollowJointTrajectory action name.
32    #[track_caller]
33    pub fn new(node: Node, action_name: &str) -> Result<Self, Error> {
34        // http://wiki.ros.org/joint_trajectory_controller
35        let action_client = node
36            .r2r()
37            .create_action_client::<FollowJointTrajectory::Action>(&format!(
38                "{action_name}/follow_joint_trajectory"
39            ))
40            .map_err(anyhow::Error::from)?;
41
42        let state_topic = format!("{action_name}/state");
43        let mut state_subscriber = node
44            .r2r()
45            .subscribe::<JointTrajectoryControllerState>(&state_topic, r2r::QosProfile::default())
46            .map_err(anyhow::Error::from)?;
47        let Some(joint_state) = utils::subscribe_one(&mut state_subscriber, Duration::from_secs(1))
48        else {
49            return Err(Error::Connection {
50                message: format!("Failed to get joint_state from {state_topic}"),
51            });
52        };
53        let joint_names = joint_state.joint_names.clone();
54        let joint_state = Arc::new(RwLock::new(joint_state));
55        utils::subscribe_thread(state_subscriber, joint_state.clone(), |state| state);
56
57        Ok(Self {
58            action_client,
59            _node: node,
60            joint_names,
61            joint_state,
62        })
63    }
64}
65
66impl JointTrajectoryClient for Ros2ControlClient {
67    fn joint_names(&self) -> Vec<String> {
68        self.joint_names.clone()
69    }
70
71    fn current_joint_positions(&self) -> Result<Vec<f64>, arci::Error> {
72        let joints = self.joint_state.read().unwrap();
73        Ok(self
74            .joint_names
75            .iter()
76            .map(|name| {
77                joints.actual.positions[joints.joint_names.iter().position(|n| n == name).unwrap()]
78            })
79            .collect())
80    }
81
82    fn send_joint_positions(
83        &self,
84        positions: Vec<f64>,
85        duration: Duration,
86    ) -> Result<WaitFuture, arci::Error> {
87        self.send_joint_trajectory(vec![TrajectoryPoint {
88            positions,
89            velocities: None,
90            time_from_start: duration,
91        }])
92    }
93
94    fn send_joint_trajectory(
95        &self,
96        trajectory: Vec<TrajectoryPoint>,
97    ) -> Result<WaitFuture, arci::Error> {
98        let action_client = self.action_client.clone();
99        let is_available = r2r::Node::is_available(&self.action_client).unwrap();
100        let (sender, receiver) = tokio::sync::oneshot::channel();
101        let joint_names = self.joint_names.clone();
102        tokio::spawn(async move {
103            let is_done = Arc::new(AtomicBool::new(false));
104            let is_done_clone = is_done.clone();
105            tokio::spawn(async move {
106                let mut clock = r2r::Clock::create(r2r::ClockType::RosTime).unwrap();
107                let now = clock.get_now().unwrap();
108                let goal = FollowJointTrajectory::Goal {
109                    trajectory: trajectory_msg::JointTrajectory {
110                        joint_names,
111                        points: trajectory
112                            .into_iter()
113                            .map(|tp| trajectory_msg::JointTrajectoryPoint {
114                                velocities: tp
115                                    .velocities
116                                    .unwrap_or_else(|| vec![0.0; tp.positions.len()]),
117                                positions: tp.positions,
118                                time_from_start: builtin_msg::Duration {
119                                    sec: tp
120                                        .time_from_start
121                                        .as_secs()
122                                        .try_into()
123                                        .unwrap_or(i32::MAX),
124                                    nanosec: tp.time_from_start.subsec_nanos(),
125                                },
126                                ..Default::default()
127                            })
128                            .collect(),
129                        header: Header {
130                            stamp: Time {
131                                sec: now.as_secs() as i32,
132                                nanosec: now.subsec_nanos(),
133                            },
134                            ..Default::default()
135                        },
136                    },
137                    ..Default::default()
138                };
139                is_available.await.unwrap();
140                let send_goal_request = action_client.send_goal_request(goal).unwrap();
141                let (_goal, result, feedback) = send_goal_request.await.unwrap();
142                tokio::spawn(async move { feedback.for_each(|_| std::future::ready(())).await });
143                result.await.unwrap(); // TODO: handle goal state
144                is_done.store(true, Ordering::Relaxed);
145            });
146            utils::wait(is_done_clone).await;
147            // TODO: "canceled" should be an error?
148            let _ = sender.send(());
149        });
150        let wait =
151            WaitFuture::new(
152                async move { receiver.await.map_err(|e| arci::Error::Other(e.into())) },
153            );
154        Ok(wait)
155    }
156}
157
158/// Configuration for `Ros2ControlClient`.
159#[derive(Debug, Clone, Serialize, Deserialize)]
160#[serde(deny_unknown_fields)]
161pub struct Ros2ControlConfig {
162    /// Action name for control_msgs/FollowJointTrajectory.
163    pub action_name: String,
164    /// Names of joints.
165    #[serde(default)]
166    pub joint_names: Vec<String>,
167}