matrixmultiply/gemm.rs

// Copyright 2016 - 2018 Ulrik Sverdrup "bluss"
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#[cfg(feature="std")]
use core::cell::UnsafeCell;
use core::cmp::min;
use core::mem::size_of;
use core::slice;

use crate::aligned_alloc::Alloc;

use crate::ptr::Ptr;
use crate::util::range_chunk;
use crate::util::round_up_to;

use crate::kernel::Element;
use crate::kernel::GemmKernel;
use crate::kernel::GemmSelect;
#[cfg(feature = "cgemm")]
use crate::kernel::{c32, c64};
use crate::threading::{get_thread_pool, ThreadPoolCtx, LoopThreadConfig};
use crate::sgemm_kernel;
use crate::dgemm_kernel;
#[cfg(feature = "cgemm")]
use crate::cgemm_kernel;
#[cfg(feature = "cgemm")]
use crate::zgemm_kernel;
use rawpointer::PointerExt;

/// General matrix multiplication (f32)
///
/// C ← α A B + β C
///
/// + m, k, n: dimensions
/// + a, b, c: pointer to the first element in the matrix
/// + A: m by k matrix
/// + B: k by n matrix
/// + C: m by n matrix
/// + rs*x*: row stride of *x*
/// + cs*x*: col stride of *x*
///
/// Strides for A and B may be arbitrary. Strides for C must not result in
/// elements that alias each other; for example, they cannot be zero.
///
/// If β is zero, then C does not need to be initialized.
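///
/// # Example
///
/// A minimal usage sketch, assuming 2×2 row-major matrices stored as flat arrays
/// (row stride = number of columns, column stride = 1):
///
/// ```
/// use matrixmultiply::sgemm;
///
/// let a = [1., 2., 3., 4.0_f32]; // 2×2 A, row-major
/// let b = [5., 6., 7., 8.0_f32]; // 2×2 B, row-major
/// let mut c = [0.0_f32; 4];      // 2×2 C, row-major
/// unsafe {
///     // C ← 1.0 · A B + 0.0 · C
///     sgemm(2, 2, 2,
///           1.0, a.as_ptr(), 2, 1,
///                b.as_ptr(), 2, 1,
///           0.0, c.as_mut_ptr(), 2, 1);
/// }
/// assert_eq!(c, [19., 22., 43., 50.]);
/// ```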
pub unsafe fn sgemm(
    m: usize, k: usize, n: usize,
    alpha: f32,
    a: *const f32, rsa: isize, csa: isize,
    b: *const f32, rsb: isize, csb: isize,
    beta: f32,
    c: *mut f32, rsc: isize, csc: isize)
{
    sgemm_kernel::detect(GemmParameters { m, k, n,
                alpha,
                a, rsa, csa,
                b, rsb, csb,
                beta,
                c, rsc, csc})
}

/// General matrix multiplication (f64)
///
/// C ← α A B + β C
///
/// + m, k, n: dimensions
/// + a, b, c: pointer to the first element in the matrix
/// + A: m by k matrix
/// + B: k by n matrix
/// + C: m by n matrix
/// + rs*x*: row stride of *x*
/// + cs*x*: col stride of *x*
///
/// Strides for A and B may be arbitrary. Strides for C must not result in
/// elements that alias each other; for example, they cannot be zero.
///
/// If β is zero, then C does not need to be initialized.
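///
/// # Example
///
/// A minimal usage sketch, assuming 2×2 column-major matrices (row stride = 1,
/// column stride = number of rows), to show that strides other than row-major
/// work as well:
///
/// ```
/// use matrixmultiply::dgemm;
///
/// let a = [1., 3., 2., 4.0_f64]; // 2×2 A, column-major
/// let b = [5., 7., 6., 8.0_f64]; // 2×2 B, column-major
/// let mut c = [0.0_f64; 4];      // 2×2 C, column-major
/// unsafe {
///     // C ← 1.0 · A B + 0.0 · C
///     dgemm(2, 2, 2,
///           1.0, a.as_ptr(), 1, 2,
///                b.as_ptr(), 1, 2,
///           0.0, c.as_mut_ptr(), 1, 2);
/// }
/// // A B = [[19, 22], [43, 50]], stored column by column
/// assert_eq!(c, [19., 43., 22., 50.]);
/// ```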
pub unsafe fn dgemm(
    m: usize, k: usize, n: usize,
    alpha: f64,
    a: *const f64, rsa: isize, csa: isize,
    b: *const f64, rsb: isize, csb: isize,
    beta: f64,
    c: *mut f64, rsc: isize, csc: isize)
{
    dgemm_kernel::detect(GemmParameters { m, k, n,
                alpha,
                a, rsa, csa,
                b, rsb, csb,
                beta,
                c, rsc, csc})
}

/// cgemm/zgemm per-operand options
///
/// TBD.
#[cfg(feature = "cgemm")]
#[non_exhaustive]
#[derive(Copy, Clone, Debug)]
pub enum CGemmOption {
    /// Standard
    Standard,
}

#[cfg(feature = "cgemm")]
/// General matrix multiplication (complex f32)
///
/// C ← α A B + β C
///
/// + m, k, n: dimensions
/// + a, b, c: pointer to the first element in the matrix
/// + A: m by k matrix
/// + B: k by n matrix
/// + C: m by n matrix
/// + rs*x*: row stride of *x*
/// + cs*x*: col stride of *x*
///
/// Strides for A and B may be arbitrary. Strides for C must not result in
/// elements that alias each other; for example, they cannot be zero.
///
/// If β is zero, then C does not need to be initialized.
///
/// Requires crate feature `"cgemm"`
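///
/// # Example
///
/// A minimal sketch of a 1×1 product, assuming complex values are passed as
/// `[re, im]` pairs (`c32` being `[f32; 2]`) and that `cgemm` and `CGemmOption`
/// are re-exported at the crate root:
///
/// ```
/// use matrixmultiply::{cgemm, CGemmOption};
///
/// // (1 + 2i)(3 + 4i) = -5 + 10i
/// let a = [[1.0_f32, 2.0]];
/// let b = [[3.0_f32, 4.0]];
/// let mut c = [[0.0_f32, 0.0]];
/// unsafe {
///     cgemm(CGemmOption::Standard, CGemmOption::Standard,
///           1, 1, 1,
///           [1., 0.], a.as_ptr(), 1, 1,
///                     b.as_ptr(), 1, 1,
///           [0., 0.], c.as_mut_ptr(), 1, 1);
/// }
/// assert_eq!(c[0], [-5.0, 10.0]);
/// ```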
pub unsafe fn cgemm(
    flaga: CGemmOption, flagb: CGemmOption,
    m: usize, k: usize, n: usize,
    alpha: c32,
    a: *const c32, rsa: isize, csa: isize,
    b: *const c32, rsb: isize, csb: isize,
    beta: c32,
    c: *mut c32, rsc: isize, csc: isize)
{
    let _ = (flaga, flagb);
    cgemm_kernel::detect(GemmParameters { m, k, n,
                alpha,
                a, rsa, csa,
                b, rsb, csb,
                beta,
                c, rsc, csc})
}

#[cfg(feature = "cgemm")]
/// General matrix multiplication (complex f64)
///
/// C ← α A B + β C
///
/// + m, k, n: dimensions
/// + a, b, c: pointer to the first element in the matrix
/// + A: m by k matrix
/// + B: k by n matrix
/// + C: m by n matrix
/// + rs*x*: row stride of *x*
/// + cs*x*: col stride of *x*
///
/// Strides for A and B may be arbitrary. Strides for C must not result in
/// elements that alias each other; for example, they cannot be zero.
///
/// If β is zero, then C does not need to be initialized.
///
/// Requires crate feature `"cgemm"`
pub unsafe fn zgemm(
    flaga: CGemmOption, flagb: CGemmOption,
    m: usize, k: usize, n: usize,
    alpha: c64,
    a: *const c64, rsa: isize, csa: isize,
    b: *const c64, rsb: isize, csb: isize,
    beta: c64,
    c: *mut c64, rsc: isize, csc: isize)
{
    let _ = (flaga, flagb);
    zgemm_kernel::detect(GemmParameters { m, k, n,
                alpha,
                a, rsa, csa,
                b, rsb, csb,
                beta,
                c, rsc, csc})
}

struct GemmParameters<T> {
    // Parameters grouped logically in rows
    m: usize, k: usize, n: usize,
    alpha: T,
    a: *const T, rsa: isize, csa: isize,
    beta: T,
    b: *const T, rsb: isize, csb: isize,
    c:   *mut T, rsc: isize, csc: isize,
}

impl<T> GemmSelect<T> for GemmParameters<T> {
    fn select<K>(self, _kernel: K)
       where K: GemmKernel<Elem=T>,
             T: Element,
    {
        // This is where we enter with the configuration-specific kernel.
        // We could cache kernel-specific function pointers here, if we
        // needed to support more costly configuration detection.
        let GemmParameters {
            m, k, n,
            alpha,
            a, rsa, csa,
            b, rsb, csb,
            beta,
            c, rsc, csc} = self;

        unsafe {
            gemm_loop::<K>(
                m, k, n,
                alpha,
                a, rsa, csa,
                b, rsb, csb,
                beta,
                c, rsc, csc)
        }
    }
}


/// Ensure that GemmKernel parameters are supported
/// (alignment, microkernel size).
///
/// This function is optimized out for a supported configuration.
#[inline(always)]
fn ensure_kernel_params<K>()
    where K: GemmKernel
{
    let mr = K::MR;
    let nr = K::NR;
    // These are current limitations,
    // can change if corresponding code in gemm_loop is updated.
    assert!(mr > 0 && mr <= 8);
    assert!(nr > 0 && nr <= 8);
    assert!(mr * nr * size_of::<K::Elem>() <= 8 * 4 * 8);
    assert!(K::align_to() <= 32);
    // one row/col of the kernel limits the max align we can provide
    let max_align = size_of::<K::Elem>() * min(mr, nr);
    assert!(K::align_to() <= max_align);

    assert!(K::MR <= K::mc());
    assert!(K::mc() <= K::kc());
    assert!(K::kc() <= K::nc());
    assert!(K::nc() <= 65536);
}

/// Implement matrix multiply using packed buffers and a microkernel
/// strategy; the type parameter `K` is the gemm microkernel.
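///
/// Overview of the blocked strategy implemented below, following the LOOP
/// labels in the code: LOOP 5 splits n into chunks of nc (B, C); LOOP 4 splits
/// k into chunks of kc (A, B) and packs B into B~; LOOP 3 splits m into chunks
/// of mc (A, C), packs A into A~ and may run across threads; LOOP 2 and LOOP 1
/// over the nr/mr micropanels of B~ and A~ live in `gemm_packed`, which invokes
/// the microkernel for each MR × NR tile of C.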
// no inline is best for the default case, where we support many K per
// gemm entry point. FIXME: make this conditional on feature detection
#[inline(never)]
unsafe fn gemm_loop<K>(
    m: usize, k: usize, n: usize,
    alpha: K::Elem,
    a: *const K::Elem, rsa: isize, csa: isize,
    b: *const K::Elem, rsb: isize, csb: isize,
    beta: K::Elem,
    c: *mut K::Elem, rsc: isize, csc: isize)
    where K: GemmKernel
{
    debug_assert!(m <= 1 || n == 0 || rsc != 0);
    debug_assert!(m == 0 || n <= 1 || csc != 0);

    // if A or B have no elements, compute C ← βC and return
    if m == 0 || k == 0 || n == 0 {
        return c_to_beta_c(m, n, beta, c, rsc, csc);
    }

    let knc = K::nc();
    let kkc = K::kc();
    let kmc = K::mc();
    ensure_kernel_params::<K>();

    let a = Ptr(a);
    let b = Ptr(b);
    let c = Ptr(c);

    let (nthreads, tp) = get_thread_pool();
    let thread_config = LoopThreadConfig::new::<K>(m, k, n, nthreads);
    let nap = thread_config.num_pack_a();

    let (mut packing_buffer, ap_size, bp_size) = make_packing_buffer::<K>(m, k, n, nap);
    let app = Ptr(packing_buffer.ptr_mut());
    let bpp = app.add(ap_size * nap);

    // LOOP 5: split n into nc parts (B, C)
    for (l5, nc) in range_chunk(n, knc) {
        dprint!("LOOP 5, {}, nc={}", l5, nc);
        let b = b.stride_offset(csb, knc * l5);
        let c = c.stride_offset(csc, knc * l5);

        // LOOP 4: split k into kc parts (A, B)
        // This particular loop can't be parallelized because the
        // C chunk (writable) is shared between iterations.
        for (l4, kc) in range_chunk(k, kkc) {
            dprint!("LOOP 4, {}, kc={}", l4, kc);
            let b = b.stride_offset(rsb, kkc * l4);
            let a = a.stride_offset(csa, kkc * l4);

            // Pack B -> B~
            K::pack_nr(kc, nc, slice::from_raw_parts_mut(bpp.ptr(), bp_size),
                       b.ptr(), csb, rsb);

            // First time writing to C, use user's `beta`, else accumulate
            let betap = if l4 == 0 { beta } else { <_>::one() };

            // LOOP 3: split m into mc parts (A, C)
            range_chunk(m, kmc)
                .parallel(thread_config.loop3, tp)
                .thread_local(move |i, _nt| {
                    // a packing buffer A~ per thread
                    debug_assert!(i < nap);
                    app.add(ap_size * i)
                })
                .for_each(move |tp, &mut app, l3, mc| {
                    dprint!("LOOP 3, {}, mc={}", l3, mc);
                    let a = a.stride_offset(rsa, kmc * l3);
                    let c = c.stride_offset(rsc, kmc * l3);

                    // Pack A -> A~
                    K::pack_mr(kc, mc, slice::from_raw_parts_mut(app.ptr(), ap_size),
                               a.ptr(), rsa, csa);

                    // LOOP 2 and 1
                    gemm_packed::<K>(nc, kc, mc,
                                     alpha,
                                     app.to_const(), bpp.to_const(),
                                     betap,
                                     c, rsc, csc,
                                     tp, thread_config);
                });
        }
    }
}

// set up buffer for masked (redirected output of) kernel
const KERNEL_MAX_SIZE: usize = 8 * 8 * 4;
const KERNEL_MAX_ALIGN: usize = 32;
const MASK_BUF_SIZE: usize = KERNEL_MAX_SIZE + KERNEL_MAX_ALIGN - 1;

// Pointers into the buffer will be manually aligned anyway, due to bugs we
// have seen on certain platforms (macos) where we don't seem to get aligned
// allocations out of TLS: 16- and 8-byte allocations have been seen. So make
// the minimal align request we can; Align(32) would not work with TLS for s390x.
#[cfg_attr(
    not(any(
        target_os = "macos",
        // Target i686-win7-windows-msvc <https://github.com/rust-lang/rust/issues/138903>
        all(
            target_arch = "x86",
            target_vendor = "win7",
            target_os = "windows",
            target_env = "msvc"
        )
    )),
    repr(align(16))
)]
struct MaskBuffer {
    buffer: [u8; MASK_BUF_SIZE],
}

// Use thread local if we can; this is faster even in the single threaded case because
// it is possible to skip zeroing out the array.
#[cfg(feature = "std")]
thread_local! {
    static MASK_BUF: UnsafeCell<MaskBuffer> =
        UnsafeCell::new(MaskBuffer { buffer: [0; MASK_BUF_SIZE] });
}

/// Loops 1 and 2 around the µ-kernel
///
/// + app: packed A (A~)
/// + bpp: packed B (B~)
/// + nc: columns of packed B
/// + kc: columns of packed A / rows of packed B
/// + mc: rows of packed A
unsafe fn gemm_packed<K>(nc: usize, kc: usize, mc: usize,
                         alpha: K::Elem,
                         app: Ptr<*const K::Elem>, bpp: Ptr<*const K::Elem>,
                         beta: K::Elem,
                         c: Ptr<*mut K::Elem>, rsc: isize, csc: isize,
                         tp: ThreadPoolCtx, thread_config: LoopThreadConfig)
    where K: GemmKernel,
{
    let mr = K::MR;
    let nr = K::NR;
    // check that the mask buffer fits the 8 x 8 f32 and 8 x 4 f64 kernels and their alignment
    assert!(mr * nr * size_of::<K::Elem>() <= KERNEL_MAX_SIZE && K::align_to() <= KERNEL_MAX_ALIGN);

    #[cfg(not(feature = "std"))]
    let mut mask_buf = MaskBuffer { buffer: [0; MASK_BUF_SIZE] };

    // LOOP 2: through micropanels in packed `b` (B~, C)
    range_chunk(nc, nr)
        .parallel(thread_config.loop2, tp)
        .thread_local(|_i, _nt| {
            let mut ptr;
            #[cfg(not(feature = "std"))]
            {
                debug_assert_eq!(_nt, 1);
                ptr = mask_buf.buffer.as_mut_ptr();
            }
            #[cfg(feature = "std")]
            {
                ptr = MASK_BUF.with(|buf| (*buf.get()).buffer.as_mut_ptr());
            }
            ptr = align_ptr(K::align_to(), ptr);
            slice::from_raw_parts_mut(ptr as *mut K::Elem, KERNEL_MAX_SIZE / size_of::<K::Elem>())
        })
        .for_each(move |_tp, mask_buf, l2, nr_| {
            let bpp = bpp.stride_offset(1, kc * nr * l2);
            let c = c.stride_offset(csc, nr * l2);

            // LOOP 1: through micropanels in packed `a` while `b` is constant (A~, C)
            for (l1, mr_) in range_chunk(mc, mr) {
                let app = app.stride_offset(1, kc * mr * l1);
                let c = c.stride_offset(rsc, mr * l1);

                // GEMM KERNEL
                // NOTE: For the rust kernels, it performs better to simply
                // always use the masked kernel function!
                if K::always_masked() || nr_ < nr || mr_ < mr {
                    masked_kernel::<_, K>(kc, alpha, app.ptr(), bpp.ptr(),
                                          beta, c.ptr(), rsc, csc,
                                          mr_, nr_, mask_buf);
                    continue;
                } else {
                    K::kernel(kc, alpha, app.ptr(), bpp.ptr(), beta, c.ptr(), rsc, csc);
                }
            }
        });
}

/// Allocate a vector of uninitialized data to be used for both packing buffers.
///
/// + A~ needs to be KC x MC
/// + B~ needs to be KC x NC
/// but we can make them smaller if the matrix is smaller than this (just ensure
/// we have rounded up to a multiple of the kernel size).
///
/// na: Number of buffers to alloc for A
///
/// Return the packing buffer, the size of A~ (the offset to B~ is the A~ size
/// times `na`), and the size of B~.
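///
/// Worked example with illustrative numbers only: with `K::MR == 8`,
/// `K::NR == 4`, `m = 10`, `k = 100`, `n = 20` (all below the mc/kc/nc caps)
/// and `na = 1`, we get `apack_size = 100 * round_up_to(10, 8) = 1600` and
/// `bpack_size = 100 * round_up_to(20, 4) = 2000` elements, 3600 elements in total.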
unsafe fn make_packing_buffer<K>(m: usize, k: usize, n: usize, na: usize)
    -> (Alloc<K::Elem>, usize, usize)
    where K: GemmKernel,
{
    // max alignment requirement is a multiple of min(MR, NR) * sizeof<Elem>
    // because apack_size is a multiple of MR, start of b aligns fine
    let m = min(m, K::mc());
    let k = min(k, K::kc());
    let n = min(n, K::nc());
    // round up m and n to multiples of MR and NR
    debug_assert_ne!(na, 0);
    debug_assert!(na <= 128);
    let apack_size = k * round_up_to(m, K::MR);
    let bpack_size = k * round_up_to(n, K::NR);
    let nelem = apack_size * na + bpack_size;

    dprint!("packed nelem={}, apack={}, bpack={},
             m={} k={} n={}, na={}",
             nelem, apack_size, bpack_size,
             m,k,n, na);

    (Alloc::new(nelem, K::align_to()), apack_size, bpack_size)
}

/// offset the ptr forwards to align to a specific byte count
/// Safety: align_to must be a power of two and ptr valid for the pointer arithmetic
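///
/// Worked example (illustrative): with `align_to = 32`, `T = f32` (4 bytes) and
/// a pointer whose address is 8 modulo 32, `cur_align = 8`, so the pointer is
/// advanced by `(32 - 8) / 4 = 6` elements (24 bytes) onto the next 32-byte
/// boundary.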
#[inline]
unsafe fn align_ptr<T>(mut align_to: usize, mut ptr: *mut T) -> *mut T {
    // always ensure minimal alignment on macos
    if cfg!(target_os = "macos") {
        align_to = Ord::max(align_to, 8);
    }

    if align_to != 0 {
        let cur_align = ptr as usize % align_to;
        if cur_align != 0 {
            ptr = ptr.offset(((align_to - cur_align) / size_of::<T>()) as isize);
        }
    }
    ptr
}

/// Call the GEMM kernel with a "masked" output C.
///
/// Simply redirect the MR by NR kernel output to the passed-in
/// `mask_buf`, and copy the unmasked region to the real C.
///
/// + rows: rows of the kernel output that are unmasked
/// + cols: cols of the kernel output that are unmasked
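///
/// For example, with an 8×8 kernel (`MR == NR == 8`) but `rows = 3`, `cols = 5`,
/// the kernel writes the full 8×8 tile into `mask_buf` and only its top-left
/// 3×5 corner is copied into C.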
#[inline(never)]
unsafe fn masked_kernel<T, K>(k: usize, alpha: T,
                              a: *const T,
                              b: *const T,
                              beta: T,
                              c: *mut T, rsc: isize, csc: isize,
                              rows: usize, cols: usize,
                              mask_buf: &mut [T])
    where K: GemmKernel<Elem=T>, T: Element,
{
    // use column major order for `mask_buf`
    K::kernel(k, alpha, a, b, T::zero(), mask_buf.as_mut_ptr(), 1, K::MR as isize);
    c_to_masked_ab_beta_c::<_, K>(beta, c, rsc, csc, rows, cols, &*mask_buf);
}

/// Copy output in `mask_buf` to the actual c matrix
///
/// C ← M + βC  where M is the `mask_buf`
#[inline]
unsafe fn c_to_masked_ab_beta_c<T, K>(beta: T,
                                      c: *mut T, rsc: isize, csc: isize,
                                      rows: usize, cols: usize,
                                      mask_buf: &[T])
    where K: GemmKernel<Elem=T>, T: Element,
{
    // note: use separate function here with `&T` argument for mask buf,
    // so that the compiler sees that `c` and `mask_buf` never alias.
    let mr = K::MR;
    let nr = K::NR;
    let mut ab = mask_buf.as_ptr();
    for j in 0..nr {
        for i in 0..mr {
            if i < rows && j < cols {
                let cptr = c.stride_offset(rsc, i)
                            .stride_offset(csc, j);
                if beta.is_zero() {
                    *cptr = *ab; // initialize
                } else {
                    (*cptr).mul_assign(beta);
                    (*cptr).add_assign(*ab);
                }
            }
            ab.inc();
        }
    }
}

// Compute just C ← βC
#[inline(never)]
unsafe fn c_to_beta_c<T>(m: usize, n: usize, beta: T,
                         c: *mut T, rsc: isize, csc: isize)
    where T: Element
{
    for i in 0..m {
        for j in 0..n {
            let cptr = c.stride_offset(rsc, i)
                        .stride_offset(csc, j);
            if beta.is_zero() {
                *cptr = T::zero(); // initialize C
            } else {
                (*cptr).mul_assign(beta);
            }
        }
    }
}