// matrixmultiply/threading.rs

//! Threading support functions and statics
3
4#[cfg(feature="threading")]
5use std::cmp::min;
6#[cfg(feature="threading")]
7use std::str::FromStr;
8#[cfg(feature="threading")]
9use once_cell::sync::Lazy;
10
11#[cfg(feature="threading")]
12pub use thread_tree::ThreadTree as ThreadPool;
13#[cfg(feature="threading")]
14pub use thread_tree::ThreadTreeCtx as ThreadPoolCtx;
15
16use crate::kernel::GemmKernel;
17use crate::util::RangeChunk;
18
/// Stand-in thread pool for builds without the `threading` feature.
///
/// It holds no state at all; it only provides the `ThreadPool` name that the threaded build
/// re-exports from `thread_tree`.
#[cfg(not(feature="threading"))]
pub(crate) struct ThreadPool;
22
/// Stand-in thread pool context for builds without the `threading` feature: a borrowed unit
/// value, since the dummy pool has nothing to refer to.
#[cfg(not(feature="threading"))]
pub(crate) type ThreadPoolCtx<'a> = &'a ();
25
26#[cfg(not(feature="threading"))]
27impl ThreadPool {
28    /// Get top dummy thread pool context
29    pub(crate) fn top(&self) -> ThreadPoolCtx<'_> { &() }
30}
31
32pub(crate) fn get_thread_pool<'a>() -> (usize, ThreadPoolCtx<'a>) {
33    let reg = &*REGISTRY;
34    (reg.nthreads, reg.thread_pool().top())
35}
36
/// Global threading state: the thread count and, when threading is enabled, the pool itself.
struct Registry {
    /// Number of threads reported to callers of `get_thread_pool`.
    nthreads: usize,
    /// The thread pool; only present when the `threading` feature is enabled.
    #[cfg(feature="threading")]
    thread_pool: Box<ThreadPool>,
}
42
43impl Registry {
44    fn thread_pool(&self) -> &ThreadPool {
45        #[cfg(feature="threading")]
46        return &*REGISTRY.thread_pool;
47        #[cfg(not(feature="threading"))]
48        return &ThreadPool;
49    }
50}
51
52#[cfg(not(feature="threading"))]
53const REGISTRY: &'static Registry = &Registry { nthreads: 1 };
54
/// Largest number of threads that is (usefully) supported at the moment.
#[cfg(feature="threading")]
const MAX_THREADS: usize = 4;
58
/// Global registry for the threaded build, created lazily on first access.
///
/// The thread count is taken from the `MATMUL_NUM_THREADS` environment variable when it holds a
/// non-empty value. Whitespace around the number is ignored, so a value such as `"4\n"` is
/// accepted. A value that still fails to parse is reported on stderr and treated as 1. When the
/// variable is unset, unreadable or empty, `num_cpus::get_physical()` is used instead. The
/// result is always clamped to `1..=MAX_THREADS`.
#[cfg(feature="threading")]
static REGISTRY: Lazy<Registry> = Lazy::new(|| {
    let requested = match ::std::env::var("MATMUL_NUM_THREADS") {
        Ok(s) if !s.is_empty() => {
            // Trim first: values exported from scripts often carry stray whitespace, which
            // would otherwise be rejected and silently degrade to a single thread.
            usize::from_str(s.trim()).unwrap_or_else(|_| {
                eprintln!("Failed to parse MATMUL_NUM_THREADS");
                1
            })
        }
        _ => num_cpus::get_physical(),
    };

    // Ensure threads in 1 <= threads <= MAX_THREADS
    let threads = requested.max(1).min(MAX_THREADS);

    // Pick the pool depth from the thread count; 2 and 3 threads share the level 1 pool.
    // NOTE(review): presumably level 1 provides 2 threads and level 2 provides 4; confirm
    // against the thread_tree documentation.
    let thread_pool = match threads {
        0 | 1 => Box::new(ThreadPool::new_level0()),
        2 | 3 => ThreadPool::new_with_level(1),
        _ => ThreadPool::new_with_level(2),
    };

    Registry {
        nthreads: threads,
        thread_pool,
    }
});
90
/// Describe how many threads we use in each loop
///
/// The type is a plain pair of small counters, so besides `Copy`/`Clone` it also derives
/// `Debug`, `PartialEq` and `Eq`; configurations can be compared and printed directly.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) struct LoopThreadConfig {
    /// Loop 3 threads
    pub(crate) loop3: u8,
    /// Loop 2 threads
    pub(crate) loop2: u8,
}
99
100impl LoopThreadConfig {
101    /// Decide how many threads to use in each loop
102    pub(crate) fn new<K>(m: usize, k: usize, n: usize, max_threads: usize) -> Self
103        where K: GemmKernel
104    {
105        let default_config = LoopThreadConfig { loop3: 1, loop2: 1 };
106
107        #[cfg(not(feature="threading"))]
108        {
109            let _ = (m, k, n, max_threads); // used
110            return default_config;
111        }
112
113        #[cfg(feature="threading")]
114        {
115            if max_threads == 1 {
116                return default_config;
117            }
118
119            Self::new_impl(m, k, n, max_threads, K::mc())
120        }
121    }
122
123    #[cfg(feature="threading")]
124    fn new_impl(m: usize, k: usize, n: usize, max_threads: usize, kmc: usize) -> Self {
125        // use a heuristic to try not to use too many threads for smaller matrices
126        let size_factor = m * k + k * n;
127        let thread_factor = 1 << 14;
128        // pure guesswork in terms of what the default should be
129        let arch_factor = if cfg!(target_arch="arm") {
130            20
131        } else {
132            1
133        };
134
135        // At the moment only a configuration of 1, 2, or 4 threads is supported.
136        //
137        // Prefer to split Loop 3 if only 2 threads are available, (because it was better in a
138        // square matrix benchmark).
139
140        let matrix_max_threads = size_factor / (thread_factor / arch_factor);
141        let mut max_threads = max_threads.min(matrix_max_threads);
142
143        let loop3 = if max_threads >= 2 && m >= 3 * (kmc / 2) {
144            max_threads /= 2;
145            2
146        } else {
147            1
148        };
149        let loop2 = if max_threads >= 2 { 2 } else { 1 };
150
151        LoopThreadConfig {
152            loop3,
153            loop2,
154        }
155    }
156
157    /// Number of packing buffers for A
158    #[inline(always)]
159    pub(crate) fn num_pack_a(&self) -> usize { self.loop3 as usize }
160}
161
162
163impl RangeChunk {
164    /// "Builder" method to create a RangeChunkParallel
165    pub(crate) fn parallel(self, nthreads: u8, pool: ThreadPoolCtx) -> RangeChunkParallel<fn()> {
166        fn nop() {}
167
168        RangeChunkParallel {
169            nthreads,
170            pool,
171            range: self,
172            thread_local: nop,
173        }
174    }
175}
176
177/// Intermediate struct for building the parallel execution of a range chunk.
178pub(crate) struct RangeChunkParallel<'a, G> {
179    range: RangeChunk,
180    nthreads: u8,
181    pool: ThreadPoolCtx<'a>,
182    thread_local: G,
183}
184
185impl<'a, G> RangeChunkParallel<'a, G> {
186    #[cfg(feature="threading")]
187    /// Set thread local setup function - called once per thread to setup thread local data.
188    pub(crate) fn thread_local<G2, R>(self, func: G2) -> RangeChunkParallel<'a, G2>
189        where G2: Fn(usize, usize) -> R + Sync
190    {
191        RangeChunkParallel {
192            nthreads: self.nthreads,
193            pool: self.pool,
194            thread_local: func,
195            range: self.range,
196        }
197    }
198
199    #[cfg(not(feature="threading"))]
200    /// Set thread local setup function - called once per thread to setup thread local data.
201    pub(crate) fn thread_local<G2, R>(self, func: G2) -> RangeChunkParallel<'a, G2>
202        where G2: FnOnce(usize, usize) -> R + Sync
203    {
204        RangeChunkParallel {
205            nthreads: self.nthreads,
206            pool: self.pool,
207            thread_local: func,
208            range: self.range,
209        }
210    }
211}
212
213#[cfg(not(feature="threading"))]
214impl<G, R> RangeChunkParallel<'_, G>
215    where G: FnOnce(usize, usize) -> R + Sync,
216{
217    pub(crate) fn for_each<F>(self, for_each: F)
218        where F: Fn(ThreadPoolCtx<'_>, &mut R, usize, usize) + Sync,
219    {
220        let mut local = (self.thread_local)(0, 1);
221        for (ln, chunk_size) in self.range {
222            for_each(self.pool, &mut local, ln, chunk_size)
223        }
224    }
225}
226
227
#[cfg(feature="threading")]
impl<G, R> RangeChunkParallel<'_, G>
    where G: Fn(usize, usize) -> R + Sync,
{
    /// Run the loop iterations with the given closure, split over up to 4 threads.
    ///
    /// For every iteration the closure is handed:
    ///
    /// - Thread pool context (used for child threads)
    /// - Mutable reference to thread local data
    /// - index of chunk (like RangeChunk)
    /// - size of chunk (like RangeChunk)
    ///
    /// Thread counts above 4 are not supported: debug builds assert, other builds clamp to 4.
    pub(crate) fn for_each<F>(self, for_each: F)
        where F: Fn(ThreadPoolCtx<'_>, &mut R, usize, usize) + Sync,
    {
        /// Work of a single thread: create its local data, then visit its part of the range.
        fn run_part<F, G, R>(range: RangeChunk, index: usize, nthreads: usize,
                             pool: ThreadPoolCtx<'_>, thread_local: G, for_each: F)
            where G: Fn(usize, usize) -> R + Sync,
                  F: Fn(ThreadPoolCtx<'_>, &mut R, usize, usize) + Sync
        {
            let mut local = thread_local(index, nthreads);
            for (ln, chunk_size) in range.part(index, nthreads) {
                for_each(pool, &mut local, ln, chunk_size)
            }
        }

        debug_assert!(self.nthreads <= 4, "this method does not support nthreads > 4, got {}",
                      self.nthreads);
        let nthreads = min(self.nthreads as usize, 4);
        let pool = self.pool;
        let range = self.range;
        // Share the user closures by reference so that every thread can call them.
        let body = &for_each;
        let setup = &self.thread_local;
        let job = move |ctx: ThreadPoolCtx<'_>, i| run_part(range, i, nthreads, ctx, setup, body);

        // `nthreads` is at most 4 here, so the catch-all arm is exactly the 4 thread case.
        // NOTE(review): presumably join/join3l/join4 invoke the job once per thread index on
        // the pool; confirm against the thread_tree documentation.
        match nthreads {
            0 | 1 => job(pool, 0),
            2 => {
                pool.join(|ctx| job(ctx, 0), |ctx| job(ctx, 1));
            }
            3 => {
                pool.join3l(&job);
            }
            _ => {
                pool.join4(&job);
            }
        }
    }

}
274