1use std::{collections::HashMap, sync::Mutex, time::Duration};
2
3use arci::{
4 gamepad::{Button, GamepadEvent},
5 JointTrajectoryClient, Speaker,
6};
7use async_trait::async_trait;
8use openrr_client::JointsPose;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11
12use crate::ControlMode;
13
14struct JointsPoseSenderInner {
15 joints_poses: Vec<JointsPose>,
16 submode: String,
17 pose_index: usize,
18 is_trigger_holding: bool,
19 is_sending: bool,
20}
21
22impl JointsPoseSenderInner {
23 fn new(joints_poses: Vec<JointsPose>) -> Self {
24 Self {
25 submode: format!(
26 " {} {}",
27 joints_poses[0].client_name, joints_poses[0].pose_name
28 ),
29 joints_poses,
30 pose_index: 0,
31 is_trigger_holding: false,
32 is_sending: false,
33 }
34 }
35
36 fn handle_event(&mut self, event: arci::gamepad::GamepadEvent) -> Option<&str> {
37 match event {
38 GamepadEvent::ButtonPressed(Button::East) => {
39 self.pose_index = (self.pose_index + 1) % self.joints_poses.len();
40 let joints_pose = &self.joints_poses[self.pose_index];
41 self.submode = format!(" {} {}", joints_pose.client_name, joints_pose.pose_name);
42 return Some(&self.submode);
43 }
44 GamepadEvent::ButtonPressed(Button::RightTrigger2) => {
45 self.is_trigger_holding = true;
46 }
47 GamepadEvent::ButtonReleased(Button::RightTrigger2) => {
48 self.is_trigger_holding = false;
49 self.is_sending = false;
50 }
51 GamepadEvent::ButtonPressed(Button::West) => {
52 self.is_sending = true;
53 }
54 GamepadEvent::ButtonReleased(Button::West) => {
55 self.is_sending = false;
56 }
57 GamepadEvent::Disconnected => {
58 self.is_trigger_holding = false;
59 self.is_sending = false;
60 }
61 _ => {}
62 }
63 None
64 }
65
66 fn get_target_name_positions(&self) -> (String, Vec<f64>) {
67 let joints_pose = &self.joints_poses[self.pose_index];
68 (
69 joints_pose.client_name.to_owned(),
70 joints_pose.positions.to_owned(),
71 )
72 }
73}
74pub struct JointsPoseSender<S, J>
75where
76 S: Speaker,
77 J: JointTrajectoryClient,
78{
79 mode: String,
80 joint_trajectory_clients: HashMap<String, J>,
81 speaker: S,
82 duration: Duration,
83 inner: Mutex<JointsPoseSenderInner>,
84}
85
86impl<S, J> JointsPoseSender<S, J>
87where
88 S: Speaker,
89 J: JointTrajectoryClient,
90{
91 pub fn new(
92 mode: String,
93 joints_poses: Vec<JointsPose>,
94 joint_trajectory_clients: HashMap<String, J>,
95 speaker: S,
96 duration: Duration,
97 ) -> Self {
98 Self {
99 mode,
100 joint_trajectory_clients,
101 speaker,
102 duration,
103 inner: Mutex::new(JointsPoseSenderInner::new(joints_poses)),
104 }
105 }
106
107 pub fn new_from_config(
108 config: JointsPoseSenderConfig,
109 joints_poses: Vec<JointsPose>,
110 joint_trajectory_clients: HashMap<String, J>,
111 speaker: S,
112 ) -> Self {
113 Self::new(
114 config.mode,
115 joints_poses,
116 joint_trajectory_clients,
117 speaker,
118 Duration::from_secs_f64(config.duration_secs),
119 )
120 }
121}
122
123#[async_trait]
124impl<S, J> ControlMode for JointsPoseSender<S, J>
125where
126 S: Speaker,
127 J: JointTrajectoryClient,
128{
129 fn handle_event(&self, event: arci::gamepad::GamepadEvent) {
130 if let Some(submode) = self.inner.lock().unwrap().handle_event(event) {
131 drop(
133 self.speaker
134 .speak(&format!("{}{submode}", self.mode))
135 .unwrap(),
136 );
137 }
138 }
139
140 async fn proc(&self) {
141 let inner = self.inner.lock().unwrap();
142 let (name, target) = inner.get_target_name_positions();
143 let client = self.joint_trajectory_clients.get(&name).unwrap();
144 drop(
145 client
146 .send_joint_positions(
147 if inner.is_sending && inner.is_trigger_holding {
148 target
149 } else {
150 client.current_joint_positions().unwrap()
151 },
152 self.duration,
153 )
154 .unwrap(),
155 );
156 }
157
158 fn mode(&self) -> &str {
159 &self.mode
160 }
161
162 fn submode(&self) -> String {
163 self.inner.lock().unwrap().submode.to_owned()
164 }
165}
166
167#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
168#[serde(deny_unknown_fields)]
169pub struct JointsPoseSenderConfig {
170 #[serde(default = "default_mode")]
171 pub mode: String,
172 #[serde(default = "default_duration_secs")]
173 pub duration_secs: f64,
174}
175
/// Serde default for `JointsPoseSenderConfig::mode`.
fn default_mode() -> String {
    String::from("pose")
}
179
/// Serde default for `JointsPoseSenderConfig::duration_secs`, in seconds.
fn default_duration_secs() -> f64 {
    const DEFAULT_DURATION_SECS: f64 = 2.0;
    DEFAULT_DURATION_SECS
}
183
#[cfg(test)]
mod test {
    use arci::DummyJointTrajectoryClient;
    use assert_approx_eq::*;
    use openrr_client::PrintSpeaker;

    use super::*;

    #[test]
    fn test_joints_pose_sender_inner() {
        let joints_poses = vec![
            JointsPose {
                pose_name: String::from("pose0"),
                client_name: String::from("client0"),
                positions: vec![1.2, 3.4, 5.6],
            },
            JointsPose {
                pose_name: String::from("pose1"),
                client_name: String::from("client1"),
                positions: vec![7.8, 9.1, 2.3],
            },
        ];
        let mut inner = JointsPoseSenderInner::new(joints_poses);

        // The first pose is selected initially.
        assert_eq!(inner.submode, String::from(" client0 pose0"));
        let (name, positions) = inner.get_target_name_positions();
        assert_eq!(name, String::from("client0"));
        assert_approx_eq!(positions[0], 1.2f64);
        assert_approx_eq!(positions[1], 3.4f64);
        assert_approx_eq!(positions[2], 5.6f64);

        // East selects the next pose and reports the new submode.
        assert_eq!(
            inner.handle_event(GamepadEvent::ButtonPressed(Button::East)),
            Some(" client1 pose1")
        );
        assert_eq!(inner.submode, String::from(" client1 pose1"));
        assert_eq!(inner.pose_index, 1);
        assert_eq!(inner.get_target_name_positions().0, String::from("client1"));

        // East on the last pose wraps around to the first one.
        assert_eq!(
            inner.handle_event(GamepadEvent::ButtonPressed(Button::East)),
            Some(" client0 pose0")
        );
        assert_eq!(inner.pose_index, 0);

        // Other button events only update flags and report nothing.
        assert_eq!(
            inner.handle_event(GamepadEvent::ButtonPressed(Button::RightTrigger2)),
            None
        );
        assert!(inner.is_trigger_holding);
        assert!(!inner.is_sending);

        assert_eq!(
            inner.handle_event(GamepadEvent::ButtonReleased(Button::RightTrigger2)),
            None
        );
        assert!(!inner.is_trigger_holding);
        assert!(!inner.is_sending);

        assert_eq!(
            inner.handle_event(GamepadEvent::ButtonPressed(Button::West)),
            None
        );
        assert!(inner.is_sending);

        assert_eq!(
            inner.handle_event(GamepadEvent::ButtonReleased(Button::West)),
            None
        );
        assert!(!inner.is_sending);

        // Releasing the trigger cancels sending even while West is still held.
        inner.handle_event(GamepadEvent::ButtonPressed(Button::RightTrigger2));
        inner.handle_event(GamepadEvent::ButtonPressed(Button::West));
        assert!(inner.is_trigger_holding);
        assert!(inner.is_sending);
        inner.handle_event(GamepadEvent::ButtonReleased(Button::RightTrigger2));
        assert!(!inner.is_trigger_holding);
        assert!(!inner.is_sending);

        // Disconnecting clears both flags.
        inner.handle_event(GamepadEvent::ButtonPressed(Button::RightTrigger2));
        inner.handle_event(GamepadEvent::ButtonPressed(Button::West));
        assert_eq!(inner.handle_event(GamepadEvent::Disconnected), None);
        assert!(!inner.is_trigger_holding);
        assert!(!inner.is_sending);
    }

    #[test]
    fn test_joints_pose_sender() {
        let mode = String::from("test_mode");
        let joints_poses = vec![JointsPose {
            pose_name: String::from("pose0"),
            client_name: String::from("client0"),
            positions: vec![1.2, 3.4, 5.6],
        }];
        let joint_names = vec![String::from("test_joint1")];
        let joint_trajectory_clients = HashMap::from([(
            String::from("client0"),
            DummyJointTrajectoryClient::new(joint_names.clone()),
        )]);
        let duration = Duration::from_millis(5);

        let joints_pose_sender = JointsPoseSender::new(
            mode.clone(),
            joints_poses.clone(),
            joint_trajectory_clients,
            PrintSpeaker::new(),
            duration,
        );

        assert_eq!(joints_pose_sender.mode, mode);
        assert_eq!(
            format!("{:?}", DummyJointTrajectoryClient::new(joint_names)),
            format!(
                "{:?}",
                joints_pose_sender.joint_trajectory_clients["client0"]
            )
        );
        assert_eq!(joints_pose_sender.duration, duration);
        assert_eq!(
            format!(
                "{:?}",
                joints_pose_sender.inner.lock().unwrap().joints_poses
            ),
            format!("{joints_poses:?}")
        );

        assert_eq!(joints_pose_sender.mode(), mode);
        assert_eq!(joints_pose_sender.submode(), " client0 pose0");
    }

    /// Builds a sender with a single pose (`client0`, target `[1.2, 3.4, 5.6]`)
    /// and a fresh three-joint dummy client, with the given button flags.
    fn sender_for_proc_test(
        is_trigger_holding: bool,
        is_sending: bool,
    ) -> JointsPoseSender<PrintSpeaker, DummyJointTrajectoryClient> {
        let joint_names = vec![
            String::from("joint1"),
            String::from("joint2"),
            String::from("joint3"),
        ];
        JointsPoseSender {
            mode: String::from("test_mode"),
            joint_trajectory_clients: HashMap::from([(
                String::from("client0"),
                DummyJointTrajectoryClient::new(joint_names),
            )]),
            speaker: PrintSpeaker::new(),
            duration: Duration::from_millis(5),
            inner: Mutex::new(JointsPoseSenderInner {
                joints_poses: vec![JointsPose {
                    pose_name: String::from("pose0"),
                    client_name: String::from("client0"),
                    positions: vec![1.2, 3.4, 5.6],
                }],
                submode: String::from("submode"),
                pose_index: 0,
                is_trigger_holding,
                is_sending,
            }),
        }
    }

    #[tokio::test]
    async fn test_joints_pose_sender_proc() {
        // Unless both the trigger and West are held, proc re-sends the client's
        // current positions, so the dummy client stays at its initial zeros.
        for (is_trigger_holding, is_sending) in [(false, false), (true, false), (false, true)] {
            let joints_pose_sender = sender_for_proc_test(is_trigger_holding, is_sending);
            joints_pose_sender.proc().await;

            let current_state = joints_pose_trajectory_client_current_state(&joints_pose_sender);
            assert_eq!(current_state.len(), 3);
            for position in current_state {
                assert_approx_eq!(position, 0.0);
            }
        }

        // With both held, proc sends the selected pose's target positions.
        let joints_pose_sender = sender_for_proc_test(true, true);
        joints_pose_sender.proc().await;

        let current_state = joints_pose_trajectory_client_current_state(&joints_pose_sender);
        assert_approx_eq!(current_state[0], 1.2);
        assert_approx_eq!(current_state[1], 3.4);
        assert_approx_eq!(current_state[2], 5.6);
    }

    /// Reads the current joint positions of the client targeted by the selected
    /// pose. It sends nothing itself, so the assertions observe only what
    /// `proc` did.
    fn joints_pose_trajectory_client_current_state(
        joints_pose_sender: &JointsPoseSender<PrintSpeaker, DummyJointTrajectoryClient>,
    ) -> Vec<f64> {
        let (name, _) = joints_pose_sender
            .inner
            .lock()
            .unwrap()
            .get_target_name_positions();
        joints_pose_sender.joint_trajectory_clients[&name]
            .current_joint_positions()
            .unwrap()
    }

    #[test]
    fn test_default_mode() {
        assert_eq!(default_mode(), "pose".to_string());
    }

    #[test]
    fn test_default_duration_secs() {
        let def = default_duration_secs();
        assert_approx_eq!(def, 2.0f64);
    }
}