flume/
select.rs

1//! Types that permit waiting upon multiple blocking operations using the [`Selector`] interface.
2
3use crate::*;
4use spin1::Mutex as Spinlock;
5use std::{any::Any, marker::PhantomData};
6
7#[cfg(feature = "eventual-fairness")]
8use nanorand::Rng;
9
/// A unique token identifying one operation registered with a [`Selector`].
///
/// It is the operation's index in the selector's list of selections. When a
/// selection's signal fires, the token is pushed onto the selector's shared
/// queue so the waiting thread knows which selection to poll.
type Token = usize;
12
/// Signal installed in a channel hook on behalf of a [`Selector`].
///
/// The fields are positional and are accessed by index (`.0`–`.3`).
struct SelectSignal(
    /// The thread waiting in the selector; unparked when the signal fires.
    thread::Thread,
    /// The [`Token`] of the selection this signal belongs to.
    Token,
    /// Set to `true` once the signal has fired. A receive selection reads it on
    /// teardown to decide whether an unconsumed wake-up must be passed on.
    AtomicBool,
    /// The selector's queue of fired tokens, shared by all of its selections.
    Arc<Spinlock<VecDeque<Token>>>,
);
19
20impl Signal for SelectSignal {
21    fn fire(&self) -> bool {
22        self.2.store(true, Ordering::SeqCst);
23        self.3.lock().push_back(self.1);
24        self.0.unpark();
25        false
26    }
27
28    fn as_any(&self) -> &(dyn Any + 'static) {
29        self
30    }
31    fn as_ptr(&self) -> *const () {
32        self as *const _ as *const ()
33    }
34}
35
/// One operation (a send or a receive) registered with a [`Selector`].
trait Selection<'a, T> {
    /// Register the operation with its channel.
    ///
    /// Returns `Some` with the handler's result if the operation completed
    /// immediately, or `None` if a hook was registered and the operation is pending.
    fn init(&mut self) -> Option<T>;
    /// Check, without blocking, whether the pending operation has completed.
    ///
    /// Returns `Some` with the handler's result when it has, `None` otherwise.
    fn poll(&mut self) -> Option<T>;
    /// Unregister any hook still installed in the channel. The selector calls this
    /// on every selection once it has finished waiting.
    fn deinit(&mut self);
}
41
/// Error returned by [`Selector::wait_timeout`] and [`Selector::wait_deadline`] when
/// none of the registered operations completes in time.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SelectError {
    /// A timeout occurred when waiting on a `Selector`.
    Timeout,
}

impl fmt::Display for SelectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            Self::Timeout => "timeout occurred",
        };
        // `pad` is what `str`'s own `Display` uses, so width/alignment flags still apply.
        f.pad(text)
    }
}

impl std::error::Error for SelectError {}
58
/// A type used to wait upon multiple blocking operations at once.
///
/// A [`Selector`] implements [`select`](https://en.wikipedia.org/wiki/Select_(Unix))-like behaviour,
/// allowing a thread to wait upon the result of more than one operation at once.
///
/// A `Selector` is neither `Send` nor `Sync`.
///
/// # Examples
/// ```
/// let (tx0, rx0) = flume::unbounded();
/// let (tx1, rx1) = flume::unbounded();
///
/// std::thread::spawn(move || {
///     tx0.send(true).unwrap();
///     tx1.send(42).unwrap();
/// });
///
/// flume::Selector::new()
///     .recv(&rx0, |b| println!("Received {:?}", b))
///     .recv(&rx1, |n| println!("Received {:?}", n))
///     .wait();
/// ```
pub struct Selector<'a, T: 'a> {
    // Registered operations; an operation's `Token` is its index here.
    selections: Vec<Box<dyn Selection<'a, T> + 'a>>,
    // Index of the selection to try first; advanced round-robin while initialising and polling.
    next_poll: usize,
    // Tokens of selections whose signals have fired, oldest first; shared with every `SelectSignal`.
    signalled: Arc<Spinlock<VecDeque<Token>>>,
    // Chooses a random starting selection for each wait.
    #[cfg(feature = "eventual-fairness")]
    rng: nanorand::WyRand,
    // The raw-pointer marker makes `Selector` `!Send` and `!Sync`. NOTE(review): presumably because its
    // signals capture `thread::current()` and must be waited on from that same thread — confirm.
    phantom: PhantomData<*const ()>,
}
87
88impl<'a, T: 'a> Default for Selector<'a, T> {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl<'a, T: 'a> fmt::Debug for Selector<'a, T> {
95    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96        f.debug_struct("Selector").finish()
97    }
98}
99
100impl<'a, T> Selector<'a, T> {
101    /// Create a new selector.
102    pub fn new() -> Self {
103        Self {
104            selections: Vec::new(),
105            next_poll: 0,
106            signalled: Arc::default(),
107            phantom: PhantomData::default(),
108            #[cfg(feature = "eventual-fairness")]
109            rng: nanorand::WyRand::new(),
110        }
111    }
112
113    /// Add a send operation to the selector that sends the provided value.
114    ///
115    /// Once added, the selector can be used to run the provided handler function on completion of this operation.
116    pub fn send<U, F: FnMut(Result<(), SendError<U>>) -> T + 'a>(
117        mut self,
118        sender: &'a Sender<U>,
119        msg: U,
120        mapper: F,
121    ) -> Self {
122        struct SendSelection<'a, T, F, U> {
123            sender: &'a Sender<U>,
124            msg: Option<U>,
125            token: Token,
126            signalled: Arc<Spinlock<VecDeque<Token>>>,
127            hook: Option<Arc<Hook<U, SelectSignal>>>,
128            mapper: F,
129            phantom: PhantomData<T>,
130        }
131
132        impl<'a, T, F, U> Selection<'a, T> for SendSelection<'a, T, F, U>
133        where
134            F: FnMut(Result<(), SendError<U>>) -> T,
135        {
136            fn init(&mut self) -> Option<T> {
137                let token = self.token;
138                let signalled = self.signalled.clone();
139                let r = self.sender.shared.send(
140                    self.msg.take().unwrap(),
141                    true,
142                    |msg| {
143                        Hook::slot(
144                            Some(msg),
145                            SelectSignal(
146                                thread::current(),
147                                token,
148                                AtomicBool::new(false),
149                                signalled,
150                            ),
151                        )
152                    },
153                    // Always runs
154                    |h| {
155                        self.hook = Some(h);
156                        Ok(())
157                    },
158                );
159
160                if self.hook.is_none() {
161                    Some((self.mapper)(match r {
162                        Ok(()) => Ok(()),
163                        Err(TrySendTimeoutError::Disconnected(msg)) => Err(SendError(msg)),
164                        _ => unreachable!(),
165                    }))
166                } else {
167                    None
168                }
169            }
170
171            fn poll(&mut self) -> Option<T> {
172                let res = if self.sender.shared.is_disconnected() {
173                    // Check the hook one last time
174                    if let Some(msg) = self.hook.as_ref()?.try_take() {
175                        Err(SendError(msg))
176                    } else {
177                        Ok(())
178                    }
179                } else if self.hook.as_ref().unwrap().is_empty() {
180                    // The message was sent
181                    Ok(())
182                } else {
183                    return None;
184                };
185
186                Some((&mut self.mapper)(res))
187            }
188
189            fn deinit(&mut self) {
190                if let Some(hook) = self.hook.take() {
191                    // Remove hook
192                    let hook: Arc<Hook<U, dyn Signal>> = hook;
193                    wait_lock(&self.sender.shared.chan)
194                        .sending
195                        .as_mut()
196                        .unwrap()
197                        .1
198                        .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
199                }
200            }
201        }
202
203        let token = self.selections.len();
204        self.selections.push(Box::new(SendSelection {
205            sender,
206            msg: Some(msg),
207            token,
208            signalled: self.signalled.clone(),
209            hook: None,
210            mapper,
211            phantom: Default::default(),
212        }));
213
214        self
215    }
216
217    /// Add a receive operation to the selector.
218    ///
219    /// Once added, the selector can be used to run the provided handler function on completion of this operation.
220    pub fn recv<U, F: FnMut(Result<U, RecvError>) -> T + 'a>(
221        mut self,
222        receiver: &'a Receiver<U>,
223        mapper: F,
224    ) -> Self {
225        struct RecvSelection<'a, T, F, U> {
226            receiver: &'a Receiver<U>,
227            token: Token,
228            signalled: Arc<Spinlock<VecDeque<Token>>>,
229            hook: Option<Arc<Hook<U, SelectSignal>>>,
230            mapper: F,
231            received: bool,
232            phantom: PhantomData<T>,
233        }
234
235        impl<'a, T, F, U> Selection<'a, T> for RecvSelection<'a, T, F, U>
236        where
237            F: FnMut(Result<U, RecvError>) -> T,
238        {
239            fn init(&mut self) -> Option<T> {
240                let token = self.token;
241                let signalled = self.signalled.clone();
242                let r = self.receiver.shared.recv(
243                    true,
244                    || {
245                        Hook::trigger(SelectSignal(
246                            thread::current(),
247                            token,
248                            AtomicBool::new(false),
249                            signalled,
250                        ))
251                    },
252                    // Always runs
253                    |h| {
254                        self.hook = Some(h);
255                        Err(TryRecvTimeoutError::Timeout)
256                    },
257                );
258
259                if self.hook.is_none() {
260                    Some((self.mapper)(match r {
261                        Ok(msg) => Ok(msg),
262                        Err(TryRecvTimeoutError::Disconnected) => Err(RecvError::Disconnected),
263                        _ => unreachable!(),
264                    }))
265                } else {
266                    None
267                }
268            }
269
270            fn poll(&mut self) -> Option<T> {
271                let res = if let Ok(msg) = self.receiver.try_recv() {
272                    self.received = true;
273                    Ok(msg)
274                } else if self.receiver.shared.is_disconnected() {
275                    Err(RecvError::Disconnected)
276                } else {
277                    return None;
278                };
279
280                Some((&mut self.mapper)(res))
281            }
282
283            fn deinit(&mut self) {
284                if let Some(hook) = self.hook.take() {
285                    // Remove hook
286                    let hook: Arc<Hook<U, dyn Signal>> = hook;
287                    wait_lock(&self.receiver.shared.chan)
288                        .waiting
289                        .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
290                    // If we were woken, but never polled, wake up another
291                    if !self.received
292                        && hook
293                            .signal()
294                            .as_any()
295                            .downcast_ref::<SelectSignal>()
296                            .unwrap()
297                            .2
298                            .load(Ordering::SeqCst)
299                    {
300                        wait_lock(&self.receiver.shared.chan).try_wake_receiver_if_pending();
301                    }
302                }
303            }
304        }
305
306        let token = self.selections.len();
307        self.selections.push(Box::new(RecvSelection {
308            receiver,
309            token,
310            signalled: self.signalled.clone(),
311            hook: None,
312            mapper,
313            received: false,
314            phantom: Default::default(),
315        }));
316
317        self
318    }
319
320    fn wait_inner(mut self, deadline: Option<Instant>) -> Option<T> {
321        #[cfg(feature = "eventual-fairness")]
322        {
323            self.next_poll = self.rng.generate_range(0..self.selections.len());
324        }
325
326        let res = 'outer: loop {
327            // Init signals
328            for _ in 0..self.selections.len() {
329                if let Some(val) = self.selections[self.next_poll].init() {
330                    break 'outer Some(val);
331                }
332                self.next_poll = (self.next_poll + 1) % self.selections.len();
333            }
334
335            // Speculatively poll
336            if let Some(msg) = self.poll() {
337                break 'outer Some(msg);
338            }
339
340            loop {
341                if let Some(deadline) = deadline {
342                    if let Some(dur) = deadline.checked_duration_since(Instant::now()) {
343                        thread::park_timeout(dur);
344                    }
345                } else {
346                    thread::park();
347                }
348
349                if deadline.map(|d| Instant::now() >= d).unwrap_or(false) {
350                    break 'outer self.poll();
351                }
352
353                let token = if let Some(token) = self.signalled.lock().pop_front() {
354                    token
355                } else {
356                    // Spurious wakeup, park again
357                    continue;
358                };
359
360                // Attempt to receive a message
361                if let Some(msg) = self.selections[token].poll() {
362                    break 'outer Some(msg);
363                }
364            }
365        };
366
367        // Deinit signals
368        for s in &mut self.selections {
369            s.deinit();
370        }
371
372        res
373    }
374
375    fn poll(&mut self) -> Option<T> {
376        for _ in 0..self.selections.len() {
377            if let Some(val) = self.selections[self.next_poll].poll() {
378                return Some(val);
379            }
380            self.next_poll = (self.next_poll + 1) % self.selections.len();
381        }
382        None
383    }
384
385    /// Wait until one of the events associated with this [`Selector`] has completed. If the `eventual-fairness`
386    /// feature flag is enabled, this method is fair and will handle a random event of those that are ready.
387    pub fn wait(self) -> T {
388        self.wait_inner(None).unwrap()
389    }
390
391    /// Wait until one of the events associated with this [`Selector`] has completed or the timeout has expired. If the
392    /// `eventual-fairness` feature flag is enabled, this method is fair and will handle a random event of those that
393    /// are ready.
394    pub fn wait_timeout(self, dur: Duration) -> Result<T, SelectError> {
395        self.wait_inner(Some(Instant::now() + dur))
396            .ok_or(SelectError::Timeout)
397    }
398
399    /// Wait until one of the events associated with this [`Selector`] has completed or the deadline has been reached.
400    /// If the `eventual-fairness` feature flag is enabled, this method is fair and will handle a random event of those
401    /// that are ready.
402    pub fn wait_deadline(self, deadline: Instant) -> Result<T, SelectError> {
403        self.wait_inner(Some(deadline)).ok_or(SelectError::Timeout)
404    }
405}