rosrust/tcpros/util/streamfork.rs

1use crate::util::lossy_channel::{lossy_channel, LossyReceiver, LossySender};
2use crate::util::FAILED_TO_LOCK;
3use crossbeam::channel::{self, unbounded, Receiver, Sender};
4use std::io::Write;
5use std::sync::{Arc, Mutex};
6use std::thread;
7
8pub fn fork<T: Write + Send + 'static>(queue_size: usize) -> (TargetList<T>, DataStream) {
9    let (streams_sender, streams) = unbounded();
10    let (data_sender, data) = lossy_channel(queue_size);
11
12    let mut fork_thread = ForkThread::new();
13    let target_names = Arc::clone(&fork_thread.target_names);
14
15    thread::spawn(move || fork_thread.run(&streams, &data));
16
17    (
18        TargetList(streams_sender),
19        DataStream {
20            sender: data_sender,
21            target_names,
22        },
23    )
24}
25
26struct ForkThread<T: Write + Send + 'static> {
27    targets: Vec<SubscriberInfo<T>>,
28    target_names: Arc<Mutex<TargetNames>>,
29}
30
31impl<T: Write + Send + 'static> ForkThread<T> {
32    pub fn new() -> Self {
33        Self {
34            targets: vec![],
35            target_names: Arc::new(Mutex::new(TargetNames {
36                targets: Vec::new(),
37            })),
38        }
39    }
40
41    fn publish_buffer_and_prune_targets(&mut self, buffer: &[u8]) {
42        let mut dropped_targets = vec![];
43        for (idx, target) in self.targets.iter_mut().enumerate() {
44            if target.stream.write_all(buffer).is_err() {
45                dropped_targets.push(idx);
46            }
47        }
48
49        if !dropped_targets.is_empty() {
50            // We reverse the order, to remove bigger indices first.
51            for idx in dropped_targets.into_iter().rev() {
52                self.targets.swap_remove(idx);
53            }
54            self.update_target_names();
55        }
56    }
57
58    fn add_target(&mut self, target: SubscriberInfo<T>) {
59        self.targets.push(target);
60        self.update_target_names();
61    }
62
63    fn update_target_names(&self) {
64        let targets = self
65            .targets
66            .iter()
67            .map(|target| target.caller_id.clone())
68            .collect();
69        *self.target_names.lock().expect(FAILED_TO_LOCK) = TargetNames { targets };
70    }
71
72    fn step(
73        &mut self,
74        streams: &Receiver<SubscriberInfo<T>>,
75        data: &LossyReceiver<Arc<Vec<u8>>>,
76    ) -> Result<(), channel::RecvError> {
77        channel::select! {
78            recv(data.kill_rx.kill_rx) -> msg => {
79                return msg.and(Err(channel::RecvError));
80            }
81            recv(data.data_rx) -> msg => {
82                self.publish_buffer_and_prune_targets(&msg?);
83            }
84            recv(streams) -> target => {
85                self.add_target(target?);
86            }
87        }
88        Ok(())
89    }
90
91    pub fn run(
92        &mut self,
93        streams: &Receiver<SubscriberInfo<T>>,
94        data: &LossyReceiver<Arc<Vec<u8>>>,
95    ) {
96        while self.step(streams, data).is_ok() {}
97    }
98}
99
/// Outcome of a fork operation; failures carry no further detail.
pub type ForkResult = Result<(), ()>;
101
102pub struct TargetList<T: Write + Send + 'static>(Sender<SubscriberInfo<T>>);
103
104impl<T: Write + Send + 'static> TargetList<T> {
105    pub fn add(&self, caller_id: String, stream: T) -> ForkResult {
106        self.0
107            .send(SubscriberInfo { caller_id, stream })
108            .or(Err(()))
109    }
110}
111
/// A target stream paired with the caller id it was registered under.
struct SubscriberInfo<T> {
    /// Identifier copied into the shared `TargetNames` snapshot.
    caller_id: String,
    /// Stream that published buffers are written to.
    stream: T,
}
116
117#[derive(Clone)]
118pub struct DataStream {
119    sender: LossySender<Arc<Vec<u8>>>,
120    target_names: Arc<Mutex<TargetNames>>,
121}
122
123impl DataStream {
124    pub fn send(&self, data: Arc<Vec<u8>>) -> ForkResult {
125        self.sender.try_send(data).or(Err(()))
126    }
127
128    #[inline]
129    pub fn target_count(&self) -> usize {
130        self.target_names.lock().expect(FAILED_TO_LOCK).count()
131    }
132
133    #[inline]
134    pub fn target_names(&self) -> Vec<String> {
135        self.target_names.lock().expect(FAILED_TO_LOCK).names()
136    }
137
138    #[inline]
139    pub fn set_queue_size(&self, queue_size: usize) {
140        self.sender.set_queue_size(queue_size);
141    }
142
143    #[inline]
144    pub fn set_queue_size_max(&self, queue_size: usize) {
145        self.sender.set_queue_size_max(queue_size);
146    }
147}
148
/// Snapshot of the caller ids of the fork thread's current targets.
#[derive(Debug)]
pub struct TargetNames {
    targets: Vec<String>,
}

impl TargetNames {
    /// Number of caller ids in the snapshot.
    #[inline]
    pub fn count(&self) -> usize {
        self.targets.len()
    }

    /// Owned copy of the caller ids in the snapshot, in stored order.
    #[inline]
    pub fn names(&self) -> Vec<String> {
        self.targets.to_vec()
    }
}