arci/
waits.rs

1use std::{
2    fmt,
3    pin::Pin,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use async_trait::async_trait;
9use auto_impl::auto_impl;
10use futures::{
11    future::{self, BoxFuture, Future, FutureExt},
12    stream::{Stream, TryStreamExt},
13};
14
15use crate::{error::Error, traits::JointTrajectoryClient};
16
/// Waits until the underlying future is complete.
///
/// Wraps a type-erased, `Send`, `'static` future that resolves to
/// `Result<(), Error>`. The type is `#[must_use]`, so callers should either
/// `.await` it or explicitly discard it to opt out of waiting.
#[must_use = "You must explicitly choose whether to wait for the complete or do not wait"]
pub struct WaitFuture {
    // Type-erased inner future (`BoxFuture` is `Pin<Box<dyn Future + Send + 'a>>`).
    future: BoxFuture<'static, Result<(), Error>>,
}
22
23impl WaitFuture {
24    /// Waits until the `future` is complete.
25    pub fn new(future: impl Future<Output = Result<(), Error>> + Send + 'static) -> Self {
26        Self::from(future.boxed())
27    }
28
29    /// Waits until the `stream` is complete.
30    ///
31    /// # Example
32    ///
33    /// ```
34    /// # #[tokio::main]
35    /// # async fn main() -> Result<(), arci::Error> {
36    /// use arci::WaitFuture;
37    /// use futures::stream::FuturesOrdered;
38    ///
39    /// let mut waits = FuturesOrdered::new();
40    /// waits.push(WaitFuture::ready());
41    /// waits.push(WaitFuture::ready());
42    /// WaitFuture::from_stream(waits).await?;
43    /// # Ok(())
44    /// # }
45    /// ```
46    pub fn from_stream(stream: impl Stream<Item = Result<(), Error>> + Send + 'static) -> Self {
47        Self::new(async move {
48            futures::pin_mut!(stream);
49            while stream.try_next().await?.is_some() {}
50            Ok(())
51        })
52    }
53
54    /// Creates a new `WaitFuture` which immediately complete.
55    pub fn ready() -> Self {
56        Self::new(future::ready(Ok(())))
57    }
58}
59
60impl Future for WaitFuture {
61    type Output = Result<(), Error>;
62
63    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64        self.future.as_mut().poll(cx)
65    }
66}
67
68impl From<BoxFuture<'static, Result<(), Error>>> for WaitFuture {
69    fn from(future: BoxFuture<'static, Result<(), Error>>) -> Self {
70        Self { future }
71    }
72}
73
74impl fmt::Debug for WaitFuture {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_struct("WaitFuture").finish()
77    }
78}
79
/// A strategy for waiting until a [`JointTrajectoryClient`] reaches a set of
/// target joint positions.
///
/// `#[auto_impl(Box, Arc)]` additionally implements this trait for `Box<T>` and
/// `Arc<T>` wrapping an implementor.
#[async_trait]
#[auto_impl(Box, Arc)]
pub trait CompleteCondition: Send + Sync {
    /// Waits until `client` is considered to have reached `target_positions`.
    ///
    /// `duration_sec` is presumably the commanded motion's duration in seconds
    /// (TODO confirm with callers). The implementations in this module add it
    /// to their own `timeout_sec` to form the total waiting budget.
    ///
    /// # Errors
    ///
    /// Returns an error when the condition is not met in time or when reading
    /// the client's state fails. The exact variants depend on the implementation.
    async fn wait(
        &self,
        client: &dyn JointTrajectoryClient,
        target_positions: &[f64],
        duration_sec: f64,
    ) -> Result<(), Error>;
}
90
/// Completion condition based on the *sum* of absolute joint position errors.
///
/// The motion counts as complete once `Σ |target - current|` over all joints is
/// at or below `allowable_error`.
#[derive(Debug, Clone)]
pub struct TotalJointDiffCondition {
    /// Inclusive upper bound on the summed absolute joint error.
    pub allowable_error: f64,
    /// Seconds allowed on top of the motion's own duration before timing out.
    pub timeout_sec: f64,
}
96
97impl TotalJointDiffCondition {
98    pub fn new(allowable_error: f64, timeout_sec: f64) -> Self {
99        Self {
100            allowable_error,
101            timeout_sec,
102        }
103    }
104}
105
106impl Default for TotalJointDiffCondition {
107    fn default() -> Self {
108        Self::new(0.02, 10.0)
109    }
110}
111
112#[async_trait]
113impl CompleteCondition for TotalJointDiffCondition {
114    async fn wait(
115        &self,
116        client: &dyn JointTrajectoryClient,
117        target_positions: &[f64],
118        duration_sec: f64,
119    ) -> Result<(), Error> {
120        const CHECK_UNIT_SEC: f64 = 0.01;
121        let check_unit_duration: Duration = Duration::from_secs_f64(CHECK_UNIT_SEC);
122        let num_repeat: i32 = ((self.timeout_sec + duration_sec) / CHECK_UNIT_SEC) as i32;
123        for _j in 0..num_repeat {
124            let curs = client.current_joint_positions()?;
125            let sum_err: f64 = target_positions
126                .iter()
127                .zip(curs.iter())
128                .map(|(tar, cur)| (tar - cur).abs())
129                .sum();
130            if sum_err <= self.allowable_error {
131                return Ok(());
132            }
133            tokio::time::sleep(check_unit_duration).await;
134        }
135        Err(Error::TimeoutWithDiff {
136            target: target_positions.to_vec(),
137            current: client.current_joint_positions()?,
138            is_reached: vec![false],
139        })
140    }
141}
142
/// Completion condition that checks every joint against its own tolerance.
///
/// A joint counts as reached once `|target - current|` is strictly below its
/// entry in `allowable_errors`.
#[derive(Debug, Clone)]
pub struct EachJointDiffCondition {
    /// Per-joint tolerances; must have one entry per target position.
    pub allowable_errors: Vec<f64>,
    /// Seconds allowed on top of the motion's own duration before timing out.
    pub timeout_sec: f64,
}
148
149impl EachJointDiffCondition {
150    pub fn new(allowable_errors: Vec<f64>, timeout_sec: f64) -> Self {
151        Self {
152            allowable_errors,
153            timeout_sec,
154        }
155    }
156}
157
158#[async_trait]
159impl CompleteCondition for EachJointDiffCondition {
160    async fn wait(
161        &self,
162        client: &dyn JointTrajectoryClient,
163        target_positions: &[f64],
164        duration_sec: f64,
165    ) -> Result<(), Error> {
166        if target_positions.len() != self.allowable_errors.len() {
167            eprintln!("wait_until_each_error_condition condition size mismatch");
168            return Err(Error::LengthMismatch {
169                model: target_positions.len(),
170                input: self.allowable_errors.len(),
171            });
172        }
173        let dof = target_positions.len();
174        let mut is_reached = vec![false; dof];
175        const CHECK_UNIT_SEC: f64 = 0.01;
176        let check_unit_duration: Duration = Duration::from_secs_f64(CHECK_UNIT_SEC);
177        let num_repeat: i32 = ((self.timeout_sec + duration_sec) / CHECK_UNIT_SEC) as i32;
178
179        for _j in 0..num_repeat {
180            for i in 0..dof {
181                let cur = client.current_joint_positions()?[i];
182                let tar = target_positions[i];
183                if !is_reached[i] {
184                    is_reached[i] = (tar - cur).abs() < self.allowable_errors[i];
185                }
186            }
187            if !is_reached.contains(&false) {
188                return Ok(());
189            }
190            tokio::time::sleep(check_unit_duration).await;
191        }
192        Err(Error::TimeoutWithDiff {
193            target: target_positions.to_vec(),
194            current: client.current_joint_positions()?,
195            is_reached,
196        })
197    }
198}