flume/
async.rs

1//! Futures and other types that allow asynchronous interaction with channels.
2
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll, Waker},
7    any::Any,
8    ops::Deref,
9};
10use std::fmt::{Debug, Formatter};
11use crate::*;
12use futures_core::{stream::{Stream, FusedStream}, future::FusedFuture};
13use futures_sink::Sink;
14use spin1::Mutex as Spinlock;
15
16struct AsyncSignal {
17    waker: Spinlock<Waker>,
18    woken: AtomicBool,
19    stream: bool,
20}
21
22impl AsyncSignal {
23    fn new(cx: &Context, stream: bool) -> Self {
24        AsyncSignal {
25            waker: Spinlock::new(cx.waker().clone()),
26            woken: AtomicBool::new(false),
27            stream,
28        }
29    }
30}
31
32impl Signal for AsyncSignal {
33    fn fire(&self) -> bool {
34        self.woken.store(true, Ordering::SeqCst);
35        self.waker.lock().wake_by_ref();
36        self.stream
37    }
38
39    fn as_any(&self) -> &(dyn Any + 'static) { self }
40    fn as_ptr(&self) -> *const () { self as *const _ as *const () }
41}
42
43impl<T> Hook<T, AsyncSignal> {
44    // Update the hook to point to the given Waker.
45    // Returns whether the hook has been previously awakened
46    fn update_waker(&self, cx_waker: &Waker) -> bool {
47        let mut waker = self.1.waker.lock();
48        let woken = self.1.woken.load(Ordering::SeqCst);
49        if !waker.will_wake(cx_waker) {
50            *waker = cx_waker.clone();
51
52            // Avoid the edge case where the waker was woken just before the wakers were
53            // swapped.
54            if woken {
55                cx_waker.wake_by_ref();
56            }
57        }
58        woken
59    }
60}
61
/// Either an owned value or a borrow of one. It dereferences to the value in both cases.
///
/// This lets the futures, sinks and streams in this module hold either an owned or a borrowed
/// `Sender`/`Receiver`.
#[derive(Clone)]
enum OwnedOrRef<'a, T> {
    /// The value is owned (used by the `into_*` constructors).
    Owned(T),
    /// The value is borrowed for `'a`.
    Ref(&'a T),
}

impl<'a, T> Deref for OwnedOrRef<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        match self {
            OwnedOrRef::Owned(owned) => owned,
            OwnedOrRef::Ref(borrowed) => *borrowed,
        }
    }
}
78
79impl<T> Sender<T> {
80    /// Asynchronously send a value into the channel, returning an error if all receivers have been
81    /// dropped. If the channel is bounded and is full, the returned future will yield to the async
82    /// runtime.
83    ///
84    /// In the current implementation, the returned future will not yield to the async runtime if the
85    /// channel is unbounded. This may change in later versions.
86    pub fn send_async(&self, item: T) -> SendFut<T> {
87        SendFut {
88            sender: OwnedOrRef::Ref(&self),
89            hook: Some(SendState::NotYetSent(item)),
90        }
91    }
92
93    /// Convert this sender into a future that asynchronously sends a single message into the channel,
94    /// returning an error if all receivers have been dropped. If the channel is bounded and is full,
95    /// this future will yield to the async runtime.
96    ///
97    /// In the current implementation, the returned future will not yield to the async runtime if the
98    /// channel is unbounded. This may change in later versions.
99    pub fn into_send_async<'a>(self, item: T) -> SendFut<'a, T> {
100        SendFut {
101            sender: OwnedOrRef::Owned(self),
102            hook: Some(SendState::NotYetSent(item)),
103        }
104    }
105
106    /// Create an asynchronous sink that uses this sender to asynchronously send messages into the
107    /// channel. The sender will continue to be usable after the sink has been dropped.
108    ///
109    /// In the current implementation, the returned sink will not yield to the async runtime if the
110    /// channel is unbounded. This may change in later versions.
111    pub fn sink(&self) -> SendSink<'_, T> {
112        SendSink(SendFut {
113            sender: OwnedOrRef::Ref(&self),
114            hook: None,
115        })
116    }
117
118    /// Convert this sender into a sink that allows asynchronously sending messages into the channel.
119    ///
120    /// In the current implementation, the returned sink will not yield to the async runtime if the
121    /// channel is unbounded. This may change in later versions.
122    pub fn into_sink<'a>(self) -> SendSink<'a, T> {
123        SendSink(SendFut {
124            sender: OwnedOrRef::Owned(self),
125            hook: None,
126        })
127    }
128}
129
130enum SendState<T> {
131    NotYetSent(T),
132    QueuedItem(Arc<Hook<T, AsyncSignal>>),
133}
134
135/// A future that sends a value into a channel.
136///
137/// Can be created via [`Sender::send_async`] or [`Sender::into_send_async`].
138#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
139pub struct SendFut<'a, T> {
140    sender: OwnedOrRef<'a, Sender<T>>,
141    // Only none after dropping
142    hook: Option<SendState<T>>,
143}
144
145impl<'a, T> Debug for SendFut<'a, T> {
146    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
147        f.debug_struct("SendFut").finish()
148    }
149}
150
151impl<T> std::marker::Unpin for SendFut<'_, T> {}
152
153impl<'a, T> SendFut<'a, T> {
154    /// Reset the hook, clearing it and removing it from the waiting sender's queue. This is called
155    /// on drop and just before `start_send` in the `Sink` implementation.
156    fn reset_hook(&mut self) {
157        if let Some(SendState::QueuedItem(hook)) = self.hook.take() {
158            let hook: Arc<Hook<T, dyn Signal>> = hook;
159            wait_lock(&self.sender.shared.chan).sending
160                .as_mut()
161                .unwrap().1
162                .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
163        }
164    }
165
166    /// See [`Sender::is_disconnected`].
167    pub fn is_disconnected(&self) -> bool {
168        self.sender.is_disconnected()
169    }
170
171    /// See [`Sender::is_empty`].
172    pub fn is_empty(&self) -> bool {
173        self.sender.is_empty()
174    }
175
176    /// See [`Sender::is_full`].
177    pub fn is_full(&self) -> bool {
178        self.sender.is_full()
179    }
180
181    /// See [`Sender::len`].
182    pub fn len(&self) -> usize {
183        self.sender.len()
184    }
185
186    /// See [`Sender::capacity`].
187    pub fn capacity(&self) -> Option<usize> {
188        self.sender.capacity()
189    }
190}
191
192impl<'a, T> Drop for SendFut<'a, T> {
193    fn drop(&mut self) {
194        self.reset_hook()
195    }
196}
197
198
199impl<'a, T> Future for SendFut<'a, T> {
200    type Output = Result<(), SendError<T>>;
201
202    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
203        if let Some(SendState::QueuedItem(hook)) = self.hook.as_ref() {
204            if hook.is_empty() {
205                Poll::Ready(Ok(()))
206            } else if self.sender.shared.is_disconnected() {
207                let item = hook.try_take();
208                self.hook = None;
209                match item {
210                    Some(item) => Poll::Ready(Err(SendError(item))),
211                    None => Poll::Ready(Ok(())),
212                }
213            } else {
214                hook.update_waker(cx.waker());
215                Poll::Pending
216            }
217        } else if let Some(SendState::NotYetSent(item)) = self.hook.take() {
218            let this = self.get_mut();
219            let (shared, this_hook) = (&this.sender.shared, &mut this.hook);
220
221            shared.send(
222                // item
223                item,
224                // should_block
225                true,
226                // make_signal
227                |msg| Hook::slot(Some(msg), AsyncSignal::new(cx, false)),
228                // do_block
229                |hook| {
230                    *this_hook = Some(SendState::QueuedItem(hook));
231                    Poll::Pending
232                },
233            )
234                .map(|r| r.map_err(|err| match err {
235                    TrySendTimeoutError::Disconnected(msg) => SendError(msg),
236                    _ => unreachable!(),
237                }))
238        } else { // Nothing to do
239            Poll::Ready(Ok(()))
240        }
241    }
242}
243
244impl<'a, T> FusedFuture for SendFut<'a, T> {
245    fn is_terminated(&self) -> bool {
246        self.sender.shared.is_disconnected()
247    }
248}
249
250/// A sink that allows sending values into a channel.
251///
252/// Can be created via [`Sender::sink`] or [`Sender::into_sink`].
253pub struct SendSink<'a, T>(SendFut<'a, T>);
254
255impl<'a, T> SendSink<'a, T> {
256    /// Returns a clone of a sending half of the channel of this sink.
257    pub fn sender(&self) -> &Sender<T> {
258        &self.0.sender
259    }
260
261    /// See [`Sender::is_disconnected`].
262    pub fn is_disconnected(&self) -> bool {
263        self.0.is_disconnected()
264    }
265
266    /// See [`Sender::is_empty`].
267    pub fn is_empty(&self) -> bool {
268        self.0.is_empty()
269    }
270
271    /// See [`Sender::is_full`].
272    pub fn is_full(&self) -> bool {
273        self.0.is_full()
274    }
275
276    /// See [`Sender::len`].
277    pub fn len(&self) -> usize {
278        self.0.len()
279    }
280
281    /// See [`Sender::capacity`].
282    pub fn capacity(&self) -> Option<usize> {
283        self.0.capacity()
284    }
285
286    /// Returns whether the SendSinks are belong to the same channel.
287    pub fn same_channel(&self, other: &Self) -> bool {
288        self.sender().same_channel(other.sender())
289    }
290}
291
292impl<'a, T> Debug for SendSink<'a, T> {
293    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
294        f.debug_struct("SendSink").finish()
295    }
296}
297
298impl<'a, T> Sink<T> for SendSink<'a, T> {
299    type Error = SendError<T>;
300
301    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
302        Pin::new(&mut self.0).poll(cx)
303    }
304
305    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
306        self.0.reset_hook();
307        self.0.hook = Some(SendState::NotYetSent(item));
308
309        Ok(())
310    }
311
312    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
313        Pin::new(&mut self.0).poll(cx) // TODO: A different strategy here?
314    }
315
316    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
317        Pin::new(&mut self.0).poll(cx) // TODO: A different strategy here?
318    }
319}
320
321impl<'a, T> Clone for SendSink<'a, T> {
322    fn clone(&self) -> SendSink<'a, T> {
323        SendSink(SendFut {
324            sender: self.0.sender.clone(),
325            hook: None,
326        })
327    }
328}
329
330impl<T> Receiver<T> {
331    /// Asynchronously receive a value from the channel, returning an error if all senders have been
332    /// dropped. If the channel is empty, the returned future will yield to the async runtime.
333    pub fn recv_async(&self) -> RecvFut<'_, T> {
334        RecvFut::new(OwnedOrRef::Ref(self))
335    }
336
337    /// Convert this receiver into a future that asynchronously receives a single message from the
338    /// channel, returning an error if all senders have been dropped. If the channel is empty, this
339    /// future will yield to the async runtime.
340    pub fn into_recv_async<'a>(self) -> RecvFut<'a, T> {
341        RecvFut::new(OwnedOrRef::Owned(self))
342    }
343
344    /// Create an asynchronous stream that uses this receiver to asynchronously receive messages
345    /// from the channel. The receiver will continue to be usable after the stream has been dropped.
346    pub fn stream(&self) -> RecvStream<'_, T> {
347        RecvStream(RecvFut::new(OwnedOrRef::Ref(self)))
348    }
349
350    /// Convert this receiver into a stream that allows asynchronously receiving messages from the channel.
351    pub fn into_stream<'a>(self) -> RecvStream<'a, T> {
352        RecvStream(RecvFut::new(OwnedOrRef::Owned(self)))
353    }
354}
355
356/// A future which allows asynchronously receiving a message.
357///
358/// Can be created via [`Receiver::recv_async`] or [`Receiver::into_recv_async`].
359#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
360pub struct RecvFut<'a, T> {
361    receiver: OwnedOrRef<'a, Receiver<T>>,
362    hook: Option<Arc<Hook<T, AsyncSignal>>>,
363}
364
365impl<'a, T> RecvFut<'a, T> {
366    fn new(receiver: OwnedOrRef<'a, Receiver<T>>) -> Self {
367        Self {
368            receiver,
369            hook: None,
370        }
371    }
372
373    /// Reset the hook, clearing it and removing it from the waiting receivers queue and waking
374    /// another receiver if this receiver has been woken, so as not to cause any missed wakeups.
375    /// This is called on drop and after a new item is received in `Stream::poll_next`.
376    fn reset_hook(&mut self) {
377        if let Some(hook) = self.hook.take() {
378            let hook: Arc<Hook<T, dyn Signal>> = hook;
379            let mut chan = wait_lock(&self.receiver.shared.chan);
380            // We'd like to use `Arc::ptr_eq` here but it doesn't seem to work consistently with wide pointers?
381            chan.waiting.retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
382            if hook.signal().as_any().downcast_ref::<AsyncSignal>().unwrap().woken.load(Ordering::SeqCst) {
383                // If this signal has been fired, but we're being dropped (and so not listening to it),
384                // pass the signal on to another receiver
385                chan.try_wake_receiver_if_pending();
386            }
387        }
388    }
389
390    fn poll_inner(
391        self: Pin<&mut Self>,
392        cx: &mut Context,
393        stream: bool,
394    ) -> Poll<Result<T, RecvError>> {
395        if self.hook.is_some() {
396            match self.receiver.shared.recv_sync(None) {
397                Ok(msg) => return Poll::Ready(Ok(msg)),
398                Err(TryRecvTimeoutError::Disconnected) => {
399                    return Poll::Ready(Err(RecvError::Disconnected))
400                }
401                _ => (),
402            }
403
404            let hook = self.hook.as_ref().map(Arc::clone).unwrap();
405            if hook.update_waker(cx.waker()) {
406                // If the previous hook was awakened, we need to insert it back to the
407                // queue, otherwise, it remains valid.
408                wait_lock(&self.receiver.shared.chan)
409                    .waiting
410                    .push_back(hook);
411            }
412            // To avoid a missed wakeup, re-check disconnect status here because the channel might have
413            // gotten shut down before we had a chance to push our hook
414            if self.receiver.shared.is_disconnected() {
415                // And now, to avoid a race condition between the first recv attempt and the disconnect check we
416                // just performed, attempt to recv again just in case we missed something.
417                Poll::Ready(
418                    self.receiver
419                        .shared
420                        .recv_sync(None)
421                        .map(Ok)
422                        .unwrap_or(Err(RecvError::Disconnected)),
423                )
424            } else {
425                Poll::Pending
426            }
427        } else {
428            let mut_self = self.get_mut();
429            let (shared, this_hook) = (&mut_self.receiver.shared, &mut mut_self.hook);
430
431            shared.recv(
432                // should_block
433                true,
434                // make_signal
435                || Hook::trigger(AsyncSignal::new(cx, stream)),
436                // do_block
437                |hook| {
438                    *this_hook = Some(hook);
439                    Poll::Pending
440                },
441            )
442                .map(|r| r.map_err(|err| match err {
443                    TryRecvTimeoutError::Disconnected => RecvError::Disconnected,
444                    _ => unreachable!(),
445                }))
446        }
447    }
448
449    /// See [`Receiver::is_disconnected`].
450    pub fn is_disconnected(&self) -> bool {
451        self.receiver.is_disconnected()
452    }
453
454    /// See [`Receiver::is_empty`].
455    pub fn is_empty(&self) -> bool {
456        self.receiver.is_empty()
457    }
458
459    /// See [`Receiver::is_full`].
460    pub fn is_full(&self) -> bool {
461        self.receiver.is_full()
462    }
463
464    /// See [`Receiver::len`].
465    pub fn len(&self) -> usize {
466        self.receiver.len()
467    }
468
469    /// See [`Receiver::capacity`].
470    pub fn capacity(&self) -> Option<usize> {
471        self.receiver.capacity()
472    }
473}
474
475impl<'a, T> Debug for RecvFut<'a, T> {
476    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
477        f.debug_struct("RecvFut").finish()
478    }
479}
480
481impl<'a, T> Drop for RecvFut<'a, T> {
482    fn drop(&mut self) {
483        self.reset_hook();
484    }
485}
486
487impl<'a, T> Future for RecvFut<'a, T> {
488    type Output = Result<T, RecvError>;
489
490    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
491        self.poll_inner(cx, false) // stream = false
492    }
493}
494
495impl<'a, T> FusedFuture for RecvFut<'a, T> {
496    fn is_terminated(&self) -> bool {
497        self.receiver.shared.is_disconnected() && self.receiver.shared.is_empty()
498    }
499}
500
501/// A stream which allows asynchronously receiving messages.
502///
503/// Can be created via [`Receiver::stream`] or [`Receiver::into_stream`].
504pub struct RecvStream<'a, T>(RecvFut<'a, T>);
505
506impl<'a, T> RecvStream<'a, T> {
507    /// See [`Receiver::is_disconnected`].
508    pub fn is_disconnected(&self) -> bool {
509        self.0.is_disconnected()
510    }
511
512    /// See [`Receiver::is_empty`].
513    pub fn is_empty(&self) -> bool {
514        self.0.is_empty()
515    }
516
517    /// See [`Receiver::is_full`].
518    pub fn is_full(&self) -> bool {
519        self.0.is_full()
520    }
521
522    /// See [`Receiver::len`].
523    pub fn len(&self) -> usize {
524        self.0.len()
525    }
526
527    /// See [`Receiver::capacity`].
528    pub fn capacity(&self) -> Option<usize> {
529        self.0.capacity()
530    }
531
532    /// Returns whether the SendSinks are belong to the same channel.
533    pub fn same_channel(&self, other: &Self) -> bool {
534        self.0.receiver.same_channel(&*other.0.receiver)
535    }
536}
537
538impl<'a, T> Debug for RecvStream<'a, T> {
539    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
540        f.debug_struct("RecvStream").finish()
541    }
542}
543
544impl<'a, T> Clone for RecvStream<'a, T> {
545    fn clone(&self) -> RecvStream<'a, T> {
546        RecvStream(RecvFut::new(self.0.receiver.clone()))
547    }
548}
549
550impl<'a, T> Stream for RecvStream<'a, T> {
551    type Item = T;
552
553    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
554        match Pin::new(&mut self.0).poll_inner(cx, true) { // stream = true
555            Poll::Pending => Poll::Pending,
556            Poll::Ready(item) => {
557                self.0.reset_hook();
558                Poll::Ready(item.ok())
559            }
560        }
561    }
562}
563
564impl<'a, T> FusedStream for RecvStream<'a, T> {
565    fn is_terminated(&self) -> bool {
566        self.0.is_terminated()
567    }
568}