1#![cfg(feature = "std")]
2use core::mem;
3use std;
4
5use alloc::{Allocator, SliceWrapper};
6use enc::backward_references::UnionHasher;
7use enc::fixed_queue::{FixedQueue, MAX_THREADS};
8use enc::threading::{
9 BatchSpawnableLite, BrotliEncoderThreadError, CompressMulti, CompressionThreadResult,
10 InternalOwned, InternalSendAlloc, Joinable, Owned, SendAlloc,
11};
12use enc::BrotliAlloc;
13use enc::BrotliEncoderParams;
14use std::sync::{Arc, Condvar, Mutex};
15use std::sync::RwLock;
18
/// The output of one finished job.
///
/// It is tagged with the `work_id` of the request that produced it, so the
/// matching joiner can find it in the shared results queue.
struct JobReply<T>
where
    T: Send + 'static,
{
    /// Identifier copied from the originating `JobRequest`.
    work_id: u64,
    /// Value returned by the job's function.
    result: T,
}
23
24struct JobRequest<
25 ReturnValue: Send + 'static,
26 ExtraInput: Send + 'static,
27 Alloc: BrotliAlloc + Send + 'static,
28 U: Send + 'static + Sync,
29> {
30 func: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
31 extra_input: ExtraInput,
32 index: usize,
33 thread_size: usize,
34 data: Arc<RwLock<U>>,
35 alloc: Alloc,
36 work_id: u64,
37}
38
39struct WorkQueue<
40 ReturnValue: Send + 'static,
41 ExtraInput: Send + 'static,
42 Alloc: BrotliAlloc + Send + 'static,
43 U: Send + 'static + Sync,
44> {
45 jobs: FixedQueue<JobRequest<ReturnValue, ExtraInput, Alloc, U>>,
46 results: FixedQueue<JobReply<ReturnValue>>,
47 shutdown: bool,
48 immediate_shutdown: bool,
49 num_in_progress: usize,
50 cur_work_id: u64,
51}
52impl<
53 ReturnValue: Send + 'static,
54 ExtraInput: Send + 'static,
55 Alloc: BrotliAlloc + Send + 'static,
56 U: Send + 'static + Sync,
57 > Default for WorkQueue<ReturnValue, ExtraInput, Alloc, U>
58{
59 fn default() -> Self {
60 WorkQueue {
61 jobs: FixedQueue::default(),
62 results: FixedQueue::default(),
63 num_in_progress: 0,
64 immediate_shutdown: false,
65 shutdown: false,
66 cur_work_id: 0,
67 }
68 }
69}
70
71pub struct GuardedQueue<
72 ReturnValue: Send + 'static,
73 ExtraInput: Send + 'static,
74 Alloc: BrotliAlloc + Send + 'static,
75 U: Send + 'static + Sync,
76>(Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>);
77pub struct WorkerPool<
78 ReturnValue: Send + 'static,
79 ExtraInput: Send + 'static,
80 Alloc: BrotliAlloc + Send + 'static,
81 U: Send + 'static + Sync,
82> {
83 queue: GuardedQueue<ReturnValue, ExtraInput, Alloc, U>,
84 join: [Option<std::thread::JoinHandle<()>>; MAX_THREADS],
85}
86
87impl<
88 ReturnValue: Send + 'static,
89 ExtraInput: Send + 'static,
90 Alloc: BrotliAlloc + Send + 'static,
91 U: Send + 'static + Sync,
92 > Drop for WorkerPool<ReturnValue, ExtraInput, Alloc, U>
93{
94 fn drop(&mut self) {
95 {
96 let (lock, cvar) = &*self.queue.0;
97 let mut local_queue = lock.lock().unwrap();
98 local_queue.immediate_shutdown = true;
99 cvar.notify_all();
100 }
101 for thread_handle in self.join.iter_mut() {
102 if let Some(th) = thread_handle.take() {
103 th.join().unwrap();
104 }
105 }
106 }
107}
108impl<
109 ReturnValue: Send + 'static,
110 ExtraInput: Send + 'static,
111 Alloc: BrotliAlloc + Send + 'static,
112 U: Send + 'static + Sync,
113 > WorkerPool<ReturnValue, ExtraInput, Alloc, U>
114{
115 fn do_work(queue: Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>) {
116 loop {
117 let ret;
118 {
119 let possible_job;
124 {
125 let (lock, cvar) = &*queue;
126 let mut local_queue = lock.lock().unwrap();
127 if local_queue.immediate_shutdown {
128 break;
129 }
130 possible_job = if let Some(res) = local_queue.jobs.pop() {
131 cvar.notify_all();
132 local_queue.num_in_progress += 1;
133 res
134 } else if local_queue.shutdown {
135 break;
136 } else {
137 let _lock = cvar.wait(local_queue); continue;
139 };
140 }
141 ret = if let Ok(job_data) = possible_job.data.read() {
142 JobReply {
143 result: (possible_job.func)(
144 possible_job.extra_input,
145 possible_job.index,
146 possible_job.thread_size,
147 &*job_data,
148 possible_job.alloc,
149 ),
150 work_id: possible_job.work_id,
151 }
152 } else {
153 break; };
155 }
156 {
157 let (lock, cvar) = &*queue;
158 let mut local_queue = lock.lock().unwrap();
159 local_queue.num_in_progress -= 1;
160 local_queue.results.push(ret).unwrap();
161 cvar.notify_all();
162 }
163 }
164 }
165 fn _push_job(&mut self, job: JobRequest<ReturnValue, ExtraInput, Alloc, U>) {
166 let (lock, cvar) = &*self.queue.0;
167 let mut local_queue = lock.lock().unwrap();
168 loop {
169 if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
170 < MAX_THREADS
171 {
172 local_queue.jobs.push(job).unwrap();
173 cvar.notify_all();
174 break;
175 }
176 local_queue = cvar.wait(local_queue).unwrap();
177 }
178 }
179 fn _try_push_job(
180 &mut self,
181 job: JobRequest<ReturnValue, ExtraInput, Alloc, U>,
182 ) -> Result<(), JobRequest<ReturnValue, ExtraInput, Alloc, U>> {
183 let (lock, cvar) = &*self.queue.0;
184 let mut local_queue = lock.lock().unwrap();
185 if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
186 < MAX_THREADS
187 {
188 local_queue.jobs.push(job).unwrap();
189 cvar.notify_all();
190 Ok(())
191 } else {
192 Err(job)
193 }
194 }
195 fn start(
196 queue: Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>,
197 ) -> std::thread::JoinHandle<()> {
198 std::thread::spawn(move || Self::do_work(queue))
199 }
200 pub fn new(num_threads: usize) -> Self {
201 let queue = Arc::new((Mutex::new(WorkQueue::default()), Condvar::new()));
202 WorkerPool {
203 queue: GuardedQueue(queue.clone()),
204 join: [
205 Some(Self::start(queue.clone())),
206 if 1 < num_threads {
207 Some(Self::start(queue.clone()))
208 } else {
209 None
210 },
211 if 2 < num_threads {
212 Some(Self::start(queue.clone()))
213 } else {
214 None
215 },
216 if 3 < num_threads {
217 Some(Self::start(queue.clone()))
218 } else {
219 None
220 },
221 if 4 < num_threads {
222 Some(Self::start(queue.clone()))
223 } else {
224 None
225 },
226 if 5 < num_threads {
227 Some(Self::start(queue.clone()))
228 } else {
229 None
230 },
231 if 6 < num_threads {
232 Some(Self::start(queue.clone()))
233 } else {
234 None
235 },
236 if 7 < num_threads {
237 Some(Self::start(queue.clone()))
238 } else {
239 None
240 },
241 if 8 < num_threads {
242 Some(Self::start(queue.clone()))
243 } else {
244 None
245 },
246 if 9 < num_threads {
247 Some(Self::start(queue.clone()))
248 } else {
249 None
250 },
251 if 10 < num_threads {
252 Some(Self::start(queue.clone()))
253 } else {
254 None
255 },
256 if 11 < num_threads {
257 Some(Self::start(queue.clone()))
258 } else {
259 None
260 },
261 if 12 < num_threads {
262 Some(Self::start(queue.clone()))
263 } else {
264 None
265 },
266 if 13 < num_threads {
267 Some(Self::start(queue.clone()))
268 } else {
269 None
270 },
271 if 14 < num_threads {
272 Some(Self::start(queue.clone()))
273 } else {
274 None
275 },
276 if 15 < num_threads {
277 Some(Self::start(queue.clone()))
278 } else {
279 None
280 },
281 ],
282 }
283 }
284}
285
286pub fn new_work_pool<
287 Alloc: BrotliAlloc + Send + 'static,
288 SliceW: SliceWrapper<u8> + Send + 'static + Sync,
289>(
290 num_threads: usize,
291) -> WorkerPool<
292 CompressionThreadResult<Alloc>,
293 UnionHasher<Alloc>,
294 Alloc,
295 (SliceW, BrotliEncoderParams),
296>
297where
298 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
299 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
300 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
301{
302 WorkerPool::new(num_threads)
303}
304
305pub struct WorkerJoinable<
306 ReturnValue: Send + 'static,
307 ExtraInput: Send + 'static,
308 Alloc: BrotliAlloc + Send + 'static,
309 U: Send + 'static + Sync,
310> {
311 queue: GuardedQueue<ReturnValue, ExtraInput, Alloc, U>,
312 work_id: u64,
313}
314impl<
315 ReturnValue: Send + 'static,
316 ExtraInput: Send + 'static,
317 Alloc: BrotliAlloc + Send + 'static,
318 U: Send + 'static + Sync,
319 > Joinable<ReturnValue, BrotliEncoderThreadError>
320 for WorkerJoinable<ReturnValue, ExtraInput, Alloc, U>
321{
322 fn join(self) -> Result<ReturnValue, BrotliEncoderThreadError> {
323 let (lock, cvar) = &*self.queue.0;
324 let mut local_queue = lock.lock().unwrap();
325 loop {
326 match local_queue
327 .results
328 .remove(|data: &Option<JobReply<ReturnValue>>| {
329 if let Some(ref item) = *data {
330 item.work_id == self.work_id
331 } else {
332 false
333 }
334 }) {
335 Some(matched) => return Ok(matched.result),
336 None => local_queue = cvar.wait(local_queue).unwrap(),
337 };
338 }
339 }
340}
341
342impl<
343 ReturnValue: Send + 'static,
344 ExtraInput: Send + 'static,
345 Alloc: BrotliAlloc + Send + 'static,
346 U: Send + 'static + Sync,
347 > BatchSpawnableLite<ReturnValue, ExtraInput, Alloc, U>
348 for WorkerPool<ReturnValue, ExtraInput, Alloc, U>
349where
350 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
351 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
352 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
353{
354 type FinalJoinHandle = Arc<RwLock<U>>;
355 type JoinHandle = WorkerJoinable<ReturnValue, ExtraInput, Alloc, U>;
356
357 fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
358 std::sync::Arc::<RwLock<U>>::new(RwLock::new(
359 mem::replace(input, Owned(InternalOwned::Borrowed)).unwrap(),
360 ))
361 }
362 fn spawn(
363 &mut self,
364 locked_input: &mut Self::FinalJoinHandle,
365 work: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
366 index: usize,
367 num_threads: usize,
368 f: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
369 ) {
370 assert!(num_threads <= MAX_THREADS);
371 let (lock, cvar) = &*self.queue.0;
372 let mut local_queue = lock.lock().unwrap();
373 loop {
374 if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
375 <= MAX_THREADS
376 {
377 let work_id = local_queue.cur_work_id;
378 local_queue.cur_work_id += 1;
379 let (local_alloc, local_extra) = work.replace_with_default();
380 local_queue
381 .jobs
382 .push(JobRequest {
383 func: f,
384 extra_input: local_extra,
385 index,
386 thread_size: num_threads,
387 data: locked_input.clone(),
388 alloc: local_alloc,
389 work_id,
390 })
391 .unwrap();
392 *work = SendAlloc(InternalSendAlloc::Join(WorkerJoinable {
393 queue: GuardedQueue(self.queue.0.clone()),
394 work_id,
395 }));
396 cvar.notify_all();
397 break;
398 } else {
399 local_queue = cvar.wait(local_queue).unwrap(); }
401 }
402 }
403}
404
405pub fn compress_worker_pool<
406 Alloc: BrotliAlloc + Send + 'static,
407 SliceW: SliceWrapper<u8> + Send + 'static + Sync,
408>(
409 params: &BrotliEncoderParams,
410 owned_input: &mut Owned<SliceW>,
411 output: &mut [u8],
412 alloc_per_thread: &mut [SendAlloc<
413 CompressionThreadResult<Alloc>,
414 UnionHasher<Alloc>,
415 Alloc,
416 <WorkerPool<
417 CompressionThreadResult<Alloc>,
418 UnionHasher<Alloc>,
419 Alloc,
420 (SliceW, BrotliEncoderParams),
421 > as BatchSpawnableLite<
422 CompressionThreadResult<Alloc>,
423 UnionHasher<Alloc>,
424 Alloc,
425 (SliceW, BrotliEncoderParams),
426 >>::JoinHandle,
427 >],
428 work_pool: &mut WorkerPool<
429 CompressionThreadResult<Alloc>,
430 UnionHasher<Alloc>,
431 Alloc,
432 (SliceW, BrotliEncoderParams),
433 >,
434) -> Result<usize, BrotliEncoderThreadError>
435where
436 <Alloc as Allocator<u8>>::AllocatedMemory: Send,
437 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
438 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
439{
440 CompressMulti(params, owned_input, output, alloc_per_thread, work_pool)
441}
442
443