use std::{
fmt,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use async_trait::async_trait;
use auto_impl::auto_impl;
use futures::{
future::{self, BoxFuture, Future, FutureExt},
stream::{Stream, TryStreamExt},
};
use crate::{error::Error, traits::JointTrajectoryClient};
#[must_use = "You must explicitly choose whether to wait for the complete or do not wait"]
pub struct WaitFuture {
future: BoxFuture<'static, Result<(), Error>>,
}
impl WaitFuture {
pub fn new(future: impl Future<Output = Result<(), Error>> + Send + 'static) -> Self {
Self::from(future.boxed())
}
pub fn from_stream(stream: impl Stream<Item = Result<(), Error>> + Send + 'static) -> Self {
Self::new(async move {
futures::pin_mut!(stream);
while stream.try_next().await?.is_some() {}
Ok(())
})
}
pub fn ready() -> Self {
Self::new(future::ready(Ok(())))
}
}
impl Future for WaitFuture {
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.future.as_mut().poll(cx)
}
}
impl From<BoxFuture<'static, Result<(), Error>>> for WaitFuture {
fn from(future: BoxFuture<'static, Result<(), Error>>) -> Self {
Self { future }
}
}
impl fmt::Debug for WaitFuture {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WaitFuture").finish()
}
}
#[async_trait]
#[auto_impl(Box, Arc)]
pub trait CompleteCondition: Send + Sync {
async fn wait(
&self,
client: &dyn JointTrajectoryClient,
target_positions: &[f64],
duration_sec: f64,
) -> Result<(), Error>;
}
#[derive(Clone, Debug)]
pub struct TotalJointDiffCondition {
pub allowable_error: f64,
pub timeout_sec: f64,
}
impl TotalJointDiffCondition {
pub fn new(allowable_error: f64, timeout_sec: f64) -> Self {
Self {
allowable_error,
timeout_sec,
}
}
}
impl Default for TotalJointDiffCondition {
fn default() -> Self {
Self::new(0.02, 10.0)
}
}
#[async_trait]
impl CompleteCondition for TotalJointDiffCondition {
async fn wait(
&self,
client: &dyn JointTrajectoryClient,
target_positions: &[f64],
duration_sec: f64,
) -> Result<(), Error> {
const CHECK_UNIT_SEC: f64 = 0.01;
let check_unit_duration: Duration = Duration::from_secs_f64(CHECK_UNIT_SEC);
let num_repeat: i32 = ((self.timeout_sec + duration_sec) / CHECK_UNIT_SEC) as i32;
for _j in 0..num_repeat {
let curs = client.current_joint_positions()?;
let sum_err: f64 = target_positions
.iter()
.zip(curs.iter())
.map(|(tar, cur)| (tar - cur).abs())
.sum();
if sum_err <= self.allowable_error {
return Ok(());
}
tokio::time::sleep(check_unit_duration).await;
}
Err(Error::TimeoutWithDiff {
target: target_positions.to_vec(),
current: client.current_joint_positions()?,
is_reached: vec![false],
})
}
}
#[derive(Clone, Debug)]
pub struct EachJointDiffCondition {
pub allowable_errors: Vec<f64>,
pub timeout_sec: f64,
}
impl EachJointDiffCondition {
pub fn new(allowable_errors: Vec<f64>, timeout_sec: f64) -> Self {
Self {
allowable_errors,
timeout_sec,
}
}
}
#[async_trait]
impl CompleteCondition for EachJointDiffCondition {
async fn wait(
&self,
client: &dyn JointTrajectoryClient,
target_positions: &[f64],
duration_sec: f64,
) -> Result<(), Error> {
if target_positions.len() != self.allowable_errors.len() {
eprintln!("wait_until_each_error_condition condition size mismatch");
return Err(Error::LengthMismatch {
model: target_positions.len(),
input: self.allowable_errors.len(),
});
}
let dof = target_positions.len();
let mut is_reached = vec![false; dof];
const CHECK_UNIT_SEC: f64 = 0.01;
let check_unit_duration: Duration = Duration::from_secs_f64(CHECK_UNIT_SEC);
let num_repeat: i32 = ((self.timeout_sec + duration_sec) / CHECK_UNIT_SEC) as i32;
for _j in 0..num_repeat {
for i in 0..dof {
let cur = client.current_joint_positions()?[i];
let tar = target_positions[i];
if !is_reached[i] {
is_reached[i] = (tar - cur).abs() < self.allowable_errors[i];
}
}
if !is_reached.contains(&false) {
return Ok(());
}
tokio::time::sleep(check_unit_duration).await;
}
Err(Error::TimeoutWithDiff {
target: target_positions.to_vec(),
current: client.current_joint_positions()?,
is_reached,
})
}
}