1#[cfg(feature="threading")]
5use std::cmp::min;
6#[cfg(feature="threading")]
7use std::str::FromStr;
8#[cfg(feature="threading")]
9use once_cell::sync::Lazy;
10
11#[cfg(feature="threading")]
12pub use thread_tree::ThreadTree as ThreadPool;
13#[cfg(feature="threading")]
14pub use thread_tree::ThreadTreeCtx as ThreadPoolCtx;
15
16use crate::kernel::GemmKernel;
17use crate::util::RangeChunk;
18
/// Stand-in for the real thread pool when the `threading` feature is off.
///
/// It is zero-sized and never spawns anything; it only exists so the rest of
/// the crate can be written against the same `ThreadPool` / `ThreadPoolCtx`
/// names in both configurations.
#[cfg(not(feature="threading"))]
pub(crate) struct ThreadPool;

/// Context handle for the single-threaded stand-in: just a reference to `()`.
#[cfg(not(feature="threading"))]
pub(crate) type ThreadPoolCtx<'a> = &'a ();

#[cfg(not(feature="threading"))]
impl ThreadPool {
    /// Return the (trivial) root context of the stand-in pool.
    pub(crate) fn top(&self) -> ThreadPoolCtx<'_> {
        let root: ThreadPoolCtx<'_> = &();
        root
    }
}
31
32pub(crate) fn get_thread_pool<'a>() -> (usize, ThreadPoolCtx<'a>) {
33 let reg = &*REGISTRY;
34 (reg.nthreads, reg.thread_pool().top())
35}
36
37struct Registry {
38 nthreads: usize,
39 #[cfg(feature="threading")]
40 thread_pool: Box<ThreadPool>,
41}
42
43impl Registry {
44 fn thread_pool(&self) -> &ThreadPool {
45 #[cfg(feature="threading")]
46 return &*REGISTRY.thread_pool;
47 #[cfg(not(feature="threading"))]
48 return &ThreadPool;
49 }
50}
51
52#[cfg(not(feature="threading"))]
53const REGISTRY: &'static Registry = &Registry { nthreads: 1 };
54
55#[cfg(feature="threading")]
56const MAX_THREADS: usize = 4;
58
59#[cfg(feature="threading")]
60static REGISTRY: Lazy<Registry> = Lazy::new(|| {
61 let var = ::std::env::var("MATMUL_NUM_THREADS").ok();
62 let threads = match var {
63 Some(s) if !s.is_empty() => {
64 if let Ok(nt) = usize::from_str(&s) {
65 nt
66 } else {
67 eprintln!("Failed to parse MATMUL_NUM_THREADS");
68 1
69 }
70 }
71 _otherwise => num_cpus::get_physical(),
72 };
73
74 let threads = 1.max(threads).min(MAX_THREADS);
76
77 let tp = if threads <= 1 {
78 Box::new(ThreadPool::new_level0())
79 } else if threads <= 3 {
80 ThreadPool::new_with_level(1)
81 } else {
82 ThreadPool::new_with_level(2)
83 };
84
85 Registry {
86 nthreads: threads,
87 thread_pool: tp,
88 }
89});
90
91#[derive(Copy, Clone)]
93pub(crate) struct LoopThreadConfig {
94 pub(crate) loop3: u8,
96 pub(crate) loop2: u8,
98}
99
100impl LoopThreadConfig {
101 pub(crate) fn new<K>(m: usize, k: usize, n: usize, max_threads: usize) -> Self
103 where K: GemmKernel
104 {
105 let default_config = LoopThreadConfig { loop3: 1, loop2: 1 };
106
107 #[cfg(not(feature="threading"))]
108 {
109 let _ = (m, k, n, max_threads); return default_config;
111 }
112
113 #[cfg(feature="threading")]
114 {
115 if max_threads == 1 {
116 return default_config;
117 }
118
119 Self::new_impl(m, k, n, max_threads, K::mc())
120 }
121 }
122
123 #[cfg(feature="threading")]
124 fn new_impl(m: usize, k: usize, n: usize, max_threads: usize, kmc: usize) -> Self {
125 let size_factor = m * k + k * n;
127 let thread_factor = 1 << 14;
128 let arch_factor = if cfg!(target_arch="arm") {
130 20
131 } else {
132 1
133 };
134
135 let matrix_max_threads = size_factor / (thread_factor / arch_factor);
141 let mut max_threads = max_threads.min(matrix_max_threads);
142
143 let loop3 = if max_threads >= 2 && m >= 3 * (kmc / 2) {
144 max_threads /= 2;
145 2
146 } else {
147 1
148 };
149 let loop2 = if max_threads >= 2 { 2 } else { 1 };
150
151 LoopThreadConfig {
152 loop3,
153 loop2,
154 }
155 }
156
157 #[inline(always)]
159 pub(crate) fn num_pack_a(&self) -> usize { self.loop3 as usize }
160}
161
162
163impl RangeChunk {
164 pub(crate) fn parallel(self, nthreads: u8, pool: ThreadPoolCtx) -> RangeChunkParallel<fn()> {
166 fn nop() {}
167
168 RangeChunkParallel {
169 nthreads,
170 pool,
171 range: self,
172 thread_local: nop,
173 }
174 }
175}
176
177pub(crate) struct RangeChunkParallel<'a, G> {
179 range: RangeChunk,
180 nthreads: u8,
181 pool: ThreadPoolCtx<'a>,
182 thread_local: G,
183}
184
185impl<'a, G> RangeChunkParallel<'a, G> {
186 #[cfg(feature="threading")]
187 pub(crate) fn thread_local<G2, R>(self, func: G2) -> RangeChunkParallel<'a, G2>
189 where G2: Fn(usize, usize) -> R + Sync
190 {
191 RangeChunkParallel {
192 nthreads: self.nthreads,
193 pool: self.pool,
194 thread_local: func,
195 range: self.range,
196 }
197 }
198
199 #[cfg(not(feature="threading"))]
200 pub(crate) fn thread_local<G2, R>(self, func: G2) -> RangeChunkParallel<'a, G2>
202 where G2: FnOnce(usize, usize) -> R + Sync
203 {
204 RangeChunkParallel {
205 nthreads: self.nthreads,
206 pool: self.pool,
207 thread_local: func,
208 range: self.range,
209 }
210 }
211}
212
213#[cfg(not(feature="threading"))]
214impl<G, R> RangeChunkParallel<'_, G>
215 where G: FnOnce(usize, usize) -> R + Sync,
216{
217 pub(crate) fn for_each<F>(self, for_each: F)
218 where F: Fn(ThreadPoolCtx<'_>, &mut R, usize, usize) + Sync,
219 {
220 let mut local = (self.thread_local)(0, 1);
221 for (ln, chunk_size) in self.range {
222 for_each(self.pool, &mut local, ln, chunk_size)
223 }
224 }
225}
226
227
#[cfg(feature="threading")]
impl<G, R> RangeChunkParallel<'_, G>
    where G: Fn(usize, usize) -> R + Sync,
{
    /// Process every chunk of the range, split across up to `nthreads`
    /// workers forked from the pool.
    ///
    /// Worker `index` first builds its own state with
    /// `thread_local(index, nthreads)`. It then calls
    /// `for_each(ctx, &mut state, ln, chunk_size)` for each chunk in
    /// `range.part(index, nthreads)`.
    ///
    /// The thread count is clamped to `1..=4`:
    /// - Above 4, the count is reduced to 4 and a debug assertion fires, since
    ///   `join4` is the widest fork used here.
    /// - A count of 0 is treated as 1, so the whole range runs on the current
    ///   thread instead of being split into zero parts.
    pub(crate) fn for_each<F>(self, for_each: F)
        where F: Fn(ThreadPoolCtx<'_>, &mut R, usize, usize) + Sync,
    {
        /// Run part `index` of `nthreads` parts of `range` on the current worker.
        fn inner<F, G, R>(range: RangeChunk, index: usize, nthreads: usize, pool: ThreadPoolCtx<'_>,
                          thread_local: G, for_each: F)
            where G: Fn(usize, usize) -> R + Sync,
                  F: Fn(ThreadPoolCtx<'_>, &mut R, usize, usize) + Sync
        {
            let mut local = thread_local(index, nthreads);
            for (ln, chunk_size) in range.part(index, nthreads) {
                for_each(pool, &mut local, ln, chunk_size)
            }
        }

        debug_assert!(self.nthreads <= 4, "this method does not support nthreads > 4, got {}",
                      self.nthreads);
        let pool = self.pool;
        let range = self.range;
        let for_each = &for_each;
        let local = &self.thread_local;
        // Never split into zero parts; never fork wider than join4.
        let nthreads = min(self.nthreads as usize, 4).max(1);
        let f = move |ctx: ThreadPoolCtx<'_>, i| inner(range, i, nthreads, ctx, local, for_each);
        match nthreads {
            4 => {
                pool.join4(&f);
            }
            3 => {
                pool.join3l(&f);
            }
            2 => {
                pool.join(|ctx| f(ctx, 0), |ctx| f(ctx, 1));
            }
            _ => f(pool, 0),
        }
    }

}
274