brotli/enc/worker_pool.rs

1#![cfg(feature = "std")]
2use core::mem;
3use std;
4
5use alloc::{Allocator, SliceWrapper};
6use enc::backward_references::UnionHasher;
7use enc::fixed_queue::{FixedQueue, MAX_THREADS};
8use enc::threading::{
9    BatchSpawnableLite, BrotliEncoderThreadError, CompressMulti, CompressionThreadResult,
10    InternalOwned, InternalSendAlloc, Joinable, Owned, SendAlloc,
11};
12use enc::BrotliAlloc;
13use enc::BrotliEncoderParams;
14use std::sync::{Arc, Condvar, Mutex};
15// in-place thread create
16
17use std::sync::RwLock;
18
/// Output of one finished job, tagged with the id of the request that produced it so the
/// matching `WorkerJoinable` can pick it out of the results queue.
struct JobReply<T>
where
    T: Send + 'static,
{
    /// Value returned by the job's function.
    result: T,
    /// `work_id` copied from the originating `JobRequest`.
    work_id: u64,
}
23
24struct JobRequest<
25    ReturnValue: Send + 'static,
26    ExtraInput: Send + 'static,
27    Alloc: BrotliAlloc + Send + 'static,
28    U: Send + 'static + Sync,
29> {
30    func: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
31    extra_input: ExtraInput,
32    index: usize,
33    thread_size: usize,
34    data: Arc<RwLock<U>>,
35    alloc: Alloc,
36    work_id: u64,
37}
38
39struct WorkQueue<
40    ReturnValue: Send + 'static,
41    ExtraInput: Send + 'static,
42    Alloc: BrotliAlloc + Send + 'static,
43    U: Send + 'static + Sync,
44> {
45    jobs: FixedQueue<JobRequest<ReturnValue, ExtraInput, Alloc, U>>,
46    results: FixedQueue<JobReply<ReturnValue>>,
47    shutdown: bool,
48    immediate_shutdown: bool,
49    num_in_progress: usize,
50    cur_work_id: u64,
51}
52impl<
53        ReturnValue: Send + 'static,
54        ExtraInput: Send + 'static,
55        Alloc: BrotliAlloc + Send + 'static,
56        U: Send + 'static + Sync,
57    > Default for WorkQueue<ReturnValue, ExtraInput, Alloc, U>
58{
59    fn default() -> Self {
60        WorkQueue {
61            jobs: FixedQueue::default(),
62            results: FixedQueue::default(),
63            num_in_progress: 0,
64            immediate_shutdown: false,
65            shutdown: false,
66            cur_work_id: 0,
67        }
68    }
69}
70
71pub struct GuardedQueue<
72    ReturnValue: Send + 'static,
73    ExtraInput: Send + 'static,
74    Alloc: BrotliAlloc + Send + 'static,
75    U: Send + 'static + Sync,
76>(Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>);
77pub struct WorkerPool<
78    ReturnValue: Send + 'static,
79    ExtraInput: Send + 'static,
80    Alloc: BrotliAlloc + Send + 'static,
81    U: Send + 'static + Sync,
82> {
83    queue: GuardedQueue<ReturnValue, ExtraInput, Alloc, U>,
84    join: [Option<std::thread::JoinHandle<()>>; MAX_THREADS],
85}
86
87impl<
88        ReturnValue: Send + 'static,
89        ExtraInput: Send + 'static,
90        Alloc: BrotliAlloc + Send + 'static,
91        U: Send + 'static + Sync,
92    > Drop for WorkerPool<ReturnValue, ExtraInput, Alloc, U>
93{
94    fn drop(&mut self) {
95        {
96            let (lock, cvar) = &*self.queue.0;
97            let mut local_queue = lock.lock().unwrap();
98            local_queue.immediate_shutdown = true;
99            cvar.notify_all();
100        }
101        for thread_handle in self.join.iter_mut() {
102            if let Some(th) = thread_handle.take() {
103                th.join().unwrap();
104            }
105        }
106    }
107}
108impl<
109        ReturnValue: Send + 'static,
110        ExtraInput: Send + 'static,
111        Alloc: BrotliAlloc + Send + 'static,
112        U: Send + 'static + Sync,
113    > WorkerPool<ReturnValue, ExtraInput, Alloc, U>
114{
115    fn do_work(queue: Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>) {
116        loop {
117            let ret;
118            {
119                // need to drop possible job before the final lock is taken,
120                // so refcount of possible_job Arc is 0 by the time the job is delivered
121                // to the caller. We basically need a barrier (the lock) to happen
122                // after the destructor that decrefs possible_job
123                let possible_job;
124                {
125                    let (lock, cvar) = &*queue;
126                    let mut local_queue = lock.lock().unwrap();
127                    if local_queue.immediate_shutdown {
128                        break;
129                    }
130                    possible_job = if let Some(res) = local_queue.jobs.pop() {
131                        cvar.notify_all();
132                        local_queue.num_in_progress += 1;
133                        res
134                    } else if local_queue.shutdown {
135                        break;
136                    } else {
137                        let _lock = cvar.wait(local_queue); // unlock immediately, unfortunately
138                        continue;
139                    };
140                }
141                ret = if let Ok(job_data) = possible_job.data.read() {
142                    JobReply {
143                        result: (possible_job.func)(
144                            possible_job.extra_input,
145                            possible_job.index,
146                            possible_job.thread_size,
147                            &*job_data,
148                            possible_job.alloc,
149                        ),
150                        work_id: possible_job.work_id,
151                    }
152                } else {
153                    break; // poisoned lock
154                };
155            }
156            {
157                let (lock, cvar) = &*queue;
158                let mut local_queue = lock.lock().unwrap();
159                local_queue.num_in_progress -= 1;
160                local_queue.results.push(ret).unwrap();
161                cvar.notify_all();
162            }
163        }
164    }
165    fn _push_job(&mut self, job: JobRequest<ReturnValue, ExtraInput, Alloc, U>) {
166        let (lock, cvar) = &*self.queue.0;
167        let mut local_queue = lock.lock().unwrap();
168        loop {
169            if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
170                < MAX_THREADS
171            {
172                local_queue.jobs.push(job).unwrap();
173                cvar.notify_all();
174                break;
175            }
176            local_queue = cvar.wait(local_queue).unwrap();
177        }
178    }
179    fn _try_push_job(
180        &mut self,
181        job: JobRequest<ReturnValue, ExtraInput, Alloc, U>,
182    ) -> Result<(), JobRequest<ReturnValue, ExtraInput, Alloc, U>> {
183        let (lock, cvar) = &*self.queue.0;
184        let mut local_queue = lock.lock().unwrap();
185        if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
186            < MAX_THREADS
187        {
188            local_queue.jobs.push(job).unwrap();
189            cvar.notify_all();
190            Ok(())
191        } else {
192            Err(job)
193        }
194    }
195    fn start(
196        queue: Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>,
197    ) -> std::thread::JoinHandle<()> {
198        std::thread::spawn(move || Self::do_work(queue))
199    }
200    pub fn new(num_threads: usize) -> Self {
201        let queue = Arc::new((Mutex::new(WorkQueue::default()), Condvar::new()));
202        WorkerPool {
203            queue: GuardedQueue(queue.clone()),
204            join: [
205                Some(Self::start(queue.clone())),
206                if 1 < num_threads {
207                    Some(Self::start(queue.clone()))
208                } else {
209                    None
210                },
211                if 2 < num_threads {
212                    Some(Self::start(queue.clone()))
213                } else {
214                    None
215                },
216                if 3 < num_threads {
217                    Some(Self::start(queue.clone()))
218                } else {
219                    None
220                },
221                if 4 < num_threads {
222                    Some(Self::start(queue.clone()))
223                } else {
224                    None
225                },
226                if 5 < num_threads {
227                    Some(Self::start(queue.clone()))
228                } else {
229                    None
230                },
231                if 6 < num_threads {
232                    Some(Self::start(queue.clone()))
233                } else {
234                    None
235                },
236                if 7 < num_threads {
237                    Some(Self::start(queue.clone()))
238                } else {
239                    None
240                },
241                if 8 < num_threads {
242                    Some(Self::start(queue.clone()))
243                } else {
244                    None
245                },
246                if 9 < num_threads {
247                    Some(Self::start(queue.clone()))
248                } else {
249                    None
250                },
251                if 10 < num_threads {
252                    Some(Self::start(queue.clone()))
253                } else {
254                    None
255                },
256                if 11 < num_threads {
257                    Some(Self::start(queue.clone()))
258                } else {
259                    None
260                },
261                if 12 < num_threads {
262                    Some(Self::start(queue.clone()))
263                } else {
264                    None
265                },
266                if 13 < num_threads {
267                    Some(Self::start(queue.clone()))
268                } else {
269                    None
270                },
271                if 14 < num_threads {
272                    Some(Self::start(queue.clone()))
273                } else {
274                    None
275                },
276                if 15 < num_threads {
277                    Some(Self::start(queue.clone()))
278                } else {
279                    None
280                },
281            ],
282        }
283    }
284}
285
286pub fn new_work_pool<
287    Alloc: BrotliAlloc + Send + 'static,
288    SliceW: SliceWrapper<u8> + Send + 'static + Sync,
289>(
290    num_threads: usize,
291) -> WorkerPool<
292    CompressionThreadResult<Alloc>,
293    UnionHasher<Alloc>,
294    Alloc,
295    (SliceW, BrotliEncoderParams),
296>
297where
298    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
299    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
300    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
301{
302    WorkerPool::new(num_threads)
303}
304
305pub struct WorkerJoinable<
306    ReturnValue: Send + 'static,
307    ExtraInput: Send + 'static,
308    Alloc: BrotliAlloc + Send + 'static,
309    U: Send + 'static + Sync,
310> {
311    queue: GuardedQueue<ReturnValue, ExtraInput, Alloc, U>,
312    work_id: u64,
313}
314impl<
315        ReturnValue: Send + 'static,
316        ExtraInput: Send + 'static,
317        Alloc: BrotliAlloc + Send + 'static,
318        U: Send + 'static + Sync,
319    > Joinable<ReturnValue, BrotliEncoderThreadError>
320    for WorkerJoinable<ReturnValue, ExtraInput, Alloc, U>
321{
322    fn join(self) -> Result<ReturnValue, BrotliEncoderThreadError> {
323        let (lock, cvar) = &*self.queue.0;
324        let mut local_queue = lock.lock().unwrap();
325        loop {
326            match local_queue
327                .results
328                .remove(|data: &Option<JobReply<ReturnValue>>| {
329                    if let Some(ref item) = *data {
330                        item.work_id == self.work_id
331                    } else {
332                        false
333                    }
334                }) {
335                Some(matched) => return Ok(matched.result),
336                None => local_queue = cvar.wait(local_queue).unwrap(),
337            };
338        }
339    }
340}
341
342impl<
343        ReturnValue: Send + 'static,
344        ExtraInput: Send + 'static,
345        Alloc: BrotliAlloc + Send + 'static,
346        U: Send + 'static + Sync,
347    > BatchSpawnableLite<ReturnValue, ExtraInput, Alloc, U>
348    for WorkerPool<ReturnValue, ExtraInput, Alloc, U>
349where
350    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
351    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
352    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
353{
354    type FinalJoinHandle = Arc<RwLock<U>>;
355    type JoinHandle = WorkerJoinable<ReturnValue, ExtraInput, Alloc, U>;
356
357    fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
358        std::sync::Arc::<RwLock<U>>::new(RwLock::new(
359            mem::replace(input, Owned(InternalOwned::Borrowed)).unwrap(),
360        ))
361    }
362    fn spawn(
363        &mut self,
364        locked_input: &mut Self::FinalJoinHandle,
365        work: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
366        index: usize,
367        num_threads: usize,
368        f: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
369    ) {
370        assert!(num_threads <= MAX_THREADS);
371        let (lock, cvar) = &*self.queue.0;
372        let mut local_queue = lock.lock().unwrap();
373        loop {
374            if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
375                <= MAX_THREADS
376            {
377                let work_id = local_queue.cur_work_id;
378                local_queue.cur_work_id += 1;
379                let (local_alloc, local_extra) = work.replace_with_default();
380                local_queue
381                    .jobs
382                    .push(JobRequest {
383                        func: f,
384                        extra_input: local_extra,
385                        index,
386                        thread_size: num_threads,
387                        data: locked_input.clone(),
388                        alloc: local_alloc,
389                        work_id,
390                    })
391                    .unwrap();
392                *work = SendAlloc(InternalSendAlloc::Join(WorkerJoinable {
393                    queue: GuardedQueue(self.queue.0.clone()),
394                    work_id,
395                }));
396                cvar.notify_all();
397                break;
398            } else {
399                local_queue = cvar.wait(local_queue).unwrap(); // hope room frees up
400            }
401        }
402    }
403}
404
405pub fn compress_worker_pool<
406    Alloc: BrotliAlloc + Send + 'static,
407    SliceW: SliceWrapper<u8> + Send + 'static + Sync,
408>(
409    params: &BrotliEncoderParams,
410    owned_input: &mut Owned<SliceW>,
411    output: &mut [u8],
412    alloc_per_thread: &mut [SendAlloc<
413        CompressionThreadResult<Alloc>,
414        UnionHasher<Alloc>,
415        Alloc,
416        <WorkerPool<
417            CompressionThreadResult<Alloc>,
418            UnionHasher<Alloc>,
419            Alloc,
420            (SliceW, BrotliEncoderParams),
421        > as BatchSpawnableLite<
422            CompressionThreadResult<Alloc>,
423            UnionHasher<Alloc>,
424            Alloc,
425            (SliceW, BrotliEncoderParams),
426        >>::JoinHandle,
427    >],
428    work_pool: &mut WorkerPool<
429        CompressionThreadResult<Alloc>,
430        UnionHasher<Alloc>,
431        Alloc,
432        (SliceW, BrotliEncoderParams),
433    >,
434) -> Result<usize, BrotliEncoderThreadError>
435where
436    <Alloc as Allocator<u8>>::AllocatedMemory: Send,
437    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
438    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
439{
440    CompressMulti(params, owned_input, output, alloc_per_thread, work_pool)
441}
442
443// out of place thread create