1use std::{
2 fmt,
3 pin::Pin,
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use async_trait::async_trait;
9use auto_impl::auto_impl;
10use futures::{
11 future::{self, BoxFuture, Future, FutureExt},
12 stream::{Stream, TryStreamExt},
13};
14
15use crate::{error::Error, traits::JointTrajectoryClient};
16
/// A boxed future that resolves to `Result<(), Error>`.
///
/// The wrapped future is `Send + 'static`. The type is `#[must_use]`, so the compiler
/// warns when a value is silently ignored. Callers should either `.await` it to wait for
/// completion or explicitly discard it.
#[must_use = "You must explicitly choose whether to wait for the complete or do not wait"]
pub struct WaitFuture {
    // Type-erased, pinned inner future; polled by the `Future` impl below.
    future: BoxFuture<'static, Result<(), Error>>,
}
22
23impl WaitFuture {
24 pub fn new(future: impl Future<Output = Result<(), Error>> + Send + 'static) -> Self {
26 Self::from(future.boxed())
27 }
28
29 pub fn from_stream(stream: impl Stream<Item = Result<(), Error>> + Send + 'static) -> Self {
47 Self::new(async move {
48 futures::pin_mut!(stream);
49 while stream.try_next().await?.is_some() {}
50 Ok(())
51 })
52 }
53
54 pub fn ready() -> Self {
56 Self::new(future::ready(Ok(())))
57 }
58}
59
60impl Future for WaitFuture {
61 type Output = Result<(), Error>;
62
63 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64 self.future.as_mut().poll(cx)
65 }
66}
67
68impl From<BoxFuture<'static, Result<(), Error>>> for WaitFuture {
69 fn from(future: BoxFuture<'static, Result<(), Error>>) -> Self {
70 Self { future }
71 }
72}
73
74impl fmt::Debug for WaitFuture {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.debug_struct("WaitFuture").finish()
77 }
78}
79
/// A strategy that decides when a motion sent to a [`JointTrajectoryClient`] has
/// completed.
///
/// Implementors must be `Send + Sync`. `#[auto_impl(Box, Arc)]` additionally implements
/// this trait for `Box<T>` and `Arc<T>` whenever `T: CompleteCondition`.
#[async_trait]
#[auto_impl(Box, Arc)]
pub trait CompleteCondition: Send + Sync {
    /// Waits until `client` is judged to have reached `target_positions`.
    ///
    /// The implementations in this module poll `client.current_joint_positions()` and
    /// bound the wait by their own `timeout_sec` plus `duration_sec`.
    /// NOTE(review): `duration_sec` is presumably the planned duration of the commanded
    /// motion, in seconds — confirm against callers.
    ///
    /// Returns `Err` when completion could not be confirmed (e.g. on timeout or when the
    /// client fails).
    async fn wait(
        &self,
        client: &dyn JointTrajectoryClient,
        target_positions: &[f64],
        duration_sec: f64,
    ) -> Result<(), Error>;
}
90
/// Completion condition that compares the *sum* of absolute joint errors against a
/// single tolerance.
#[derive(Clone, Debug)]
pub struct TotalJointDiffCondition {
    /// Largest accepted value of `Σ |target - current|` over all joints.
    pub allowable_error: f64,
    /// Seconds to keep waiting in addition to the motion's own `duration_sec`.
    pub timeout_sec: f64,
}

impl TotalJointDiffCondition {
    /// Builds a condition from a total tolerance and an extra timeout in seconds.
    pub fn new(allowable_error: f64, timeout_sec: f64) -> Self {
        Self {
            timeout_sec,
            allowable_error,
        }
    }
}

/// Defaults to a total tolerance of `0.02` and a `10.0`-second timeout.
impl Default for TotalJointDiffCondition {
    fn default() -> Self {
        Self {
            allowable_error: 0.02,
            timeout_sec: 10.0,
        }
    }
}
111
112#[async_trait]
113impl CompleteCondition for TotalJointDiffCondition {
114 async fn wait(
115 &self,
116 client: &dyn JointTrajectoryClient,
117 target_positions: &[f64],
118 duration_sec: f64,
119 ) -> Result<(), Error> {
120 const CHECK_UNIT_SEC: f64 = 0.01;
121 let check_unit_duration: Duration = Duration::from_secs_f64(CHECK_UNIT_SEC);
122 let num_repeat: i32 = ((self.timeout_sec + duration_sec) / CHECK_UNIT_SEC) as i32;
123 for _j in 0..num_repeat {
124 let curs = client.current_joint_positions()?;
125 let sum_err: f64 = target_positions
126 .iter()
127 .zip(curs.iter())
128 .map(|(tar, cur)| (tar - cur).abs())
129 .sum();
130 if sum_err <= self.allowable_error {
131 return Ok(());
132 }
133 tokio::time::sleep(check_unit_duration).await;
134 }
135 Err(Error::TimeoutWithDiff {
136 target: target_positions.to_vec(),
137 current: client.current_joint_positions()?,
138 is_reached: vec![false],
139 })
140 }
141}
142
/// Completion condition that gives every joint its own tolerance.
#[derive(Clone, Debug)]
pub struct EachJointDiffCondition {
    /// Per-joint tolerance on `|target - current|`, in the same order as the target
    /// positions passed to `wait`.
    pub allowable_errors: Vec<f64>,
    /// Seconds to keep waiting in addition to the motion's own `duration_sec`.
    pub timeout_sec: f64,
}

impl EachJointDiffCondition {
    /// Builds a condition from per-joint tolerances and an extra timeout in seconds.
    pub fn new(allowable_errors: Vec<f64>, timeout_sec: f64) -> Self {
        Self {
            timeout_sec,
            allowable_errors,
        }
    }
}
157
158#[async_trait]
159impl CompleteCondition for EachJointDiffCondition {
160 async fn wait(
161 &self,
162 client: &dyn JointTrajectoryClient,
163 target_positions: &[f64],
164 duration_sec: f64,
165 ) -> Result<(), Error> {
166 if target_positions.len() != self.allowable_errors.len() {
167 eprintln!("wait_until_each_error_condition condition size mismatch");
168 return Err(Error::LengthMismatch {
169 model: target_positions.len(),
170 input: self.allowable_errors.len(),
171 });
172 }
173 let dof = target_positions.len();
174 let mut is_reached = vec![false; dof];
175 const CHECK_UNIT_SEC: f64 = 0.01;
176 let check_unit_duration: Duration = Duration::from_secs_f64(CHECK_UNIT_SEC);
177 let num_repeat: i32 = ((self.timeout_sec + duration_sec) / CHECK_UNIT_SEC) as i32;
178
179 for _j in 0..num_repeat {
180 for i in 0..dof {
181 let cur = client.current_joint_positions()?[i];
182 let tar = target_positions[i];
183 if !is_reached[i] {
184 is_reached[i] = (tar - cur).abs() < self.allowable_errors[i];
185 }
186 }
187 if !is_reached.contains(&false) {
188 return Ok(());
189 }
190 tokio::time::sleep(check_unit_duration).await;
191 }
192 Err(Error::TimeoutWithDiff {
193 target: target_positions.to_vec(),
194 current: client.current_joint_positions()?,
195 is_reached,
196 })
197 }
198}