1use crate::*;
4use spin1::Mutex as Spinlock;
5use std::{any::Any, marker::PhantomData};
6
7#[cfg(feature = "eventual-fairness")]
8use nanorand::Rng;
9
/// Index of a selection within its [`Selector`] (its position in `Selector::selections`).
///
/// Carried by [`SelectSignal`] so that, after a wakeup, the waiting thread knows which
/// selection to poll.
type Token = usize;
12
/// The signal attached to every hook a [`Selector`] registers with a channel.
///
/// Firing it records that it fired, queues its token on the selector's shared queue and
/// unparks the waiting thread.
struct SelectSignal(
    /// Thread to unpark when the signal fires (the thread that registered the hook).
    thread::Thread,
    /// Token of the selection that owns this signal.
    Token,
    /// Set to `true` by `fire`; read when deregistering a receive hook.
    AtomicBool,
    /// The owning selector's queue of fired tokens.
    Arc<Spinlock<VecDeque<Token>>>,
);
19
20impl Signal for SelectSignal {
21 fn fire(&self) -> bool {
22 self.2.store(true, Ordering::SeqCst);
23 self.3.lock().push_back(self.1);
24 self.0.unpark();
25 false
26 }
27
28 fn as_any(&self) -> &(dyn Any + 'static) {
29 self
30 }
31 fn as_ptr(&self) -> *const () {
32 self as *const _ as *const ()
33 }
34}
35
/// One operation (a send or a receive) registered with a [`Selector`].
trait Selection<'a, T> {
    /// Attempt the operation right away.
    ///
    /// Returns `Some` with the mapped result if it finished immediately. Otherwise it leaves a
    /// hook registered with the channel and returns `None`.
    fn init(&mut self) -> Option<T>;
    /// Re-check an operation whose hook was registered by `init`; `Some` once it has finished.
    fn poll(&mut self) -> Option<T>;
    /// Remove this operation's hook, if any, from the channel.
    fn deinit(&mut self);
}
41
/// Error returned by [`Selector::wait_timeout`] and [`Selector::wait_deadline`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SelectError {
    /// No operation in the selector completed before the deadline.
    Timeout,
}

impl fmt::Display for SelectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            Self::Timeout => "timeout occurred",
        };
        // `pad` honours width/fill/precision flags exactly as `<str as Display>::fmt` does.
        f.pad(text)
    }
}

impl std::error::Error for SelectError {}
58
/// A set of channel operations (sends and receives) that are waited on together.
///
/// Waiting completes with the mapped result of the first operation found to be ready, or
/// reports a timeout when a deadline passes first.
pub struct Selector<'a, T: 'a> {
    // One entry per operation added with `send`/`recv`; an entry's index is its `Token`.
    selections: Vec<Box<dyn Selection<'a, T> + 'a>>,
    // Index of the selection that the round-robin polling loops try next.
    next_poll: usize,
    // Tokens of selections whose hooks have fired, pushed by `SelectSignal::fire`.
    signalled: Arc<Spinlock<VecDeque<Token>>>,
    // Randomises which selection each wait tries first when `eventual-fairness` is enabled.
    #[cfg(feature = "eventual-fairness")]
    rng: nanorand::WyRand,
    // Raw-pointer marker: makes `Selector` neither `Send` nor `Sync`.
    // NOTE(review): presumably deliberate, since hooks capture `thread::current()` — confirm.
    phantom: PhantomData<*const ()>,
}
87
88impl<'a, T: 'a> Default for Selector<'a, T> {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl<'a, T: 'a> fmt::Debug for Selector<'a, T> {
95 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96 f.debug_struct("Selector").finish()
97 }
98}
99
100impl<'a, T> Selector<'a, T> {
101 pub fn new() -> Self {
103 Self {
104 selections: Vec::new(),
105 next_poll: 0,
106 signalled: Arc::default(),
107 phantom: PhantomData::default(),
108 #[cfg(feature = "eventual-fairness")]
109 rng: nanorand::WyRand::new(),
110 }
111 }
112
113 pub fn send<U, F: FnMut(Result<(), SendError<U>>) -> T + 'a>(
117 mut self,
118 sender: &'a Sender<U>,
119 msg: U,
120 mapper: F,
121 ) -> Self {
122 struct SendSelection<'a, T, F, U> {
123 sender: &'a Sender<U>,
124 msg: Option<U>,
125 token: Token,
126 signalled: Arc<Spinlock<VecDeque<Token>>>,
127 hook: Option<Arc<Hook<U, SelectSignal>>>,
128 mapper: F,
129 phantom: PhantomData<T>,
130 }
131
132 impl<'a, T, F, U> Selection<'a, T> for SendSelection<'a, T, F, U>
133 where
134 F: FnMut(Result<(), SendError<U>>) -> T,
135 {
136 fn init(&mut self) -> Option<T> {
137 let token = self.token;
138 let signalled = self.signalled.clone();
139 let r = self.sender.shared.send(
140 self.msg.take().unwrap(),
141 true,
142 |msg| {
143 Hook::slot(
144 Some(msg),
145 SelectSignal(
146 thread::current(),
147 token,
148 AtomicBool::new(false),
149 signalled,
150 ),
151 )
152 },
153 |h| {
155 self.hook = Some(h);
156 Ok(())
157 },
158 );
159
160 if self.hook.is_none() {
161 Some((self.mapper)(match r {
162 Ok(()) => Ok(()),
163 Err(TrySendTimeoutError::Disconnected(msg)) => Err(SendError(msg)),
164 _ => unreachable!(),
165 }))
166 } else {
167 None
168 }
169 }
170
171 fn poll(&mut self) -> Option<T> {
172 let res = if self.sender.shared.is_disconnected() {
173 if let Some(msg) = self.hook.as_ref()?.try_take() {
175 Err(SendError(msg))
176 } else {
177 Ok(())
178 }
179 } else if self.hook.as_ref().unwrap().is_empty() {
180 Ok(())
182 } else {
183 return None;
184 };
185
186 Some((&mut self.mapper)(res))
187 }
188
189 fn deinit(&mut self) {
190 if let Some(hook) = self.hook.take() {
191 let hook: Arc<Hook<U, dyn Signal>> = hook;
193 wait_lock(&self.sender.shared.chan)
194 .sending
195 .as_mut()
196 .unwrap()
197 .1
198 .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
199 }
200 }
201 }
202
203 let token = self.selections.len();
204 self.selections.push(Box::new(SendSelection {
205 sender,
206 msg: Some(msg),
207 token,
208 signalled: self.signalled.clone(),
209 hook: None,
210 mapper,
211 phantom: Default::default(),
212 }));
213
214 self
215 }
216
217 pub fn recv<U, F: FnMut(Result<U, RecvError>) -> T + 'a>(
221 mut self,
222 receiver: &'a Receiver<U>,
223 mapper: F,
224 ) -> Self {
225 struct RecvSelection<'a, T, F, U> {
226 receiver: &'a Receiver<U>,
227 token: Token,
228 signalled: Arc<Spinlock<VecDeque<Token>>>,
229 hook: Option<Arc<Hook<U, SelectSignal>>>,
230 mapper: F,
231 received: bool,
232 phantom: PhantomData<T>,
233 }
234
235 impl<'a, T, F, U> Selection<'a, T> for RecvSelection<'a, T, F, U>
236 where
237 F: FnMut(Result<U, RecvError>) -> T,
238 {
239 fn init(&mut self) -> Option<T> {
240 let token = self.token;
241 let signalled = self.signalled.clone();
242 let r = self.receiver.shared.recv(
243 true,
244 || {
245 Hook::trigger(SelectSignal(
246 thread::current(),
247 token,
248 AtomicBool::new(false),
249 signalled,
250 ))
251 },
252 |h| {
254 self.hook = Some(h);
255 Err(TryRecvTimeoutError::Timeout)
256 },
257 );
258
259 if self.hook.is_none() {
260 Some((self.mapper)(match r {
261 Ok(msg) => Ok(msg),
262 Err(TryRecvTimeoutError::Disconnected) => Err(RecvError::Disconnected),
263 _ => unreachable!(),
264 }))
265 } else {
266 None
267 }
268 }
269
270 fn poll(&mut self) -> Option<T> {
271 let res = if let Ok(msg) = self.receiver.try_recv() {
272 self.received = true;
273 Ok(msg)
274 } else if self.receiver.shared.is_disconnected() {
275 Err(RecvError::Disconnected)
276 } else {
277 return None;
278 };
279
280 Some((&mut self.mapper)(res))
281 }
282
283 fn deinit(&mut self) {
284 if let Some(hook) = self.hook.take() {
285 let hook: Arc<Hook<U, dyn Signal>> = hook;
287 wait_lock(&self.receiver.shared.chan)
288 .waiting
289 .retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
290 if !self.received
292 && hook
293 .signal()
294 .as_any()
295 .downcast_ref::<SelectSignal>()
296 .unwrap()
297 .2
298 .load(Ordering::SeqCst)
299 {
300 wait_lock(&self.receiver.shared.chan).try_wake_receiver_if_pending();
301 }
302 }
303 }
304 }
305
306 let token = self.selections.len();
307 self.selections.push(Box::new(RecvSelection {
308 receiver,
309 token,
310 signalled: self.signalled.clone(),
311 hook: None,
312 mapper,
313 received: false,
314 phantom: Default::default(),
315 }));
316
317 self
318 }
319
320 fn wait_inner(mut self, deadline: Option<Instant>) -> Option<T> {
321 #[cfg(feature = "eventual-fairness")]
322 {
323 self.next_poll = self.rng.generate_range(0..self.selections.len());
324 }
325
326 let res = 'outer: loop {
327 for _ in 0..self.selections.len() {
329 if let Some(val) = self.selections[self.next_poll].init() {
330 break 'outer Some(val);
331 }
332 self.next_poll = (self.next_poll + 1) % self.selections.len();
333 }
334
335 if let Some(msg) = self.poll() {
337 break 'outer Some(msg);
338 }
339
340 loop {
341 if let Some(deadline) = deadline {
342 if let Some(dur) = deadline.checked_duration_since(Instant::now()) {
343 thread::park_timeout(dur);
344 }
345 } else {
346 thread::park();
347 }
348
349 if deadline.map(|d| Instant::now() >= d).unwrap_or(false) {
350 break 'outer self.poll();
351 }
352
353 let token = if let Some(token) = self.signalled.lock().pop_front() {
354 token
355 } else {
356 continue;
358 };
359
360 if let Some(msg) = self.selections[token].poll() {
362 break 'outer Some(msg);
363 }
364 }
365 };
366
367 for s in &mut self.selections {
369 s.deinit();
370 }
371
372 res
373 }
374
375 fn poll(&mut self) -> Option<T> {
376 for _ in 0..self.selections.len() {
377 if let Some(val) = self.selections[self.next_poll].poll() {
378 return Some(val);
379 }
380 self.next_poll = (self.next_poll + 1) % self.selections.len();
381 }
382 None
383 }
384
385 pub fn wait(self) -> T {
388 self.wait_inner(None).unwrap()
389 }
390
391 pub fn wait_timeout(self, dur: Duration) -> Result<T, SelectError> {
395 self.wait_inner(Some(Instant::now() + dur))
396 .ok_or(SelectError::Timeout)
397 }
398
399 pub fn wait_deadline(self, deadline: Instant) -> Result<T, SelectError> {
403 self.wait_inner(Some(deadline)).ok_or(SelectError::Timeout)
404 }
405}