matrixmultiply/
kernel.rs

1// Copyright 2016 - 2021 Ulrik Sverdrup "bluss"
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::archparam;
10use crate::packing::pack;
11
12/// General matrix multiply kernel
13pub(crate) trait GemmKernel {
14    type Elem: Element;
15
16    /// Kernel rows
17    const MR: usize = Self::MRTy::VALUE;
18    /// Kernel cols
19    const NR: usize = Self::NRTy::VALUE;
20    /// Kernel rows as const num type
21    type MRTy: ConstNum;
22    /// Kernel cols as const num type
23    type NRTy: ConstNum;
24
25    /// align inputs to this
26    fn align_to() -> usize;
27
28    /// Whether to always use the masked wrapper around the kernel.
29    fn always_masked() -> bool;
30
31    // These should ideally be tuned per kernel and per microarch
32    #[inline(always)]
33    fn nc() -> usize { archparam::S_NC }
34    #[inline(always)]
35    fn kc() -> usize { archparam::S_KC }
36    #[inline(always)]
37    fn mc() -> usize { archparam::S_MC }
38
39    /// Pack matrix A into its packing buffer.
40    ///
41    /// See pack for more documentation.
42    ///
43    /// Override only if the default packing function does not
44    /// use the right layout.
45    #[inline]
46    unsafe fn pack_mr(kc: usize, mc: usize, pack_buf: &mut [Self::Elem],
47                      a: *const Self::Elem, rsa: isize, csa: isize)
48    {
49        pack::<Self::MRTy, _>(kc, mc, pack_buf, a, rsa, csa)
50    }
51
52    /// Pack matrix B into its packing buffer
53    ///
54    /// See pack for more documentation.
55    ///
56    /// Override only if the default packing function does not
57    /// use the right layout.
58    #[inline]
59    unsafe fn pack_nr(kc: usize, mc: usize, pack_buf: &mut [Self::Elem],
60                      a: *const Self::Elem, rsa: isize, csa: isize)
61    {
62        pack::<Self::NRTy, _>(kc, mc, pack_buf, a, rsa, csa)
63    }
64
65
66    /// Matrix multiplication kernel
67    ///
68    /// This does the matrix multiplication:
69    ///
70    /// C ← α A B + β C
71    ///
72    /// + `k`: length of data in a, b
73    /// + a, b are packed
74    /// + c has general strides
75    /// + rsc: row stride of c
76    /// + csc: col stride of c
77    /// + `alpha`: scaling factor for A B product
78    /// + `beta`: scaling factor for c.
79    ///   Note: if `beta` is `0.`, the kernel should not (and must not)
80    ///   read from c, its value is to be treated as if it was zero.
81    ///
82    /// When masked, the kernel is always called with β=0 but α is passed
83    /// as usual. (This is only useful information if you return `true` from
84    /// `always_masked`.)
85    unsafe fn kernel(
86        k: usize,
87        alpha: Self::Elem,
88        a: *const Self::Elem,
89        b: *const Self::Elem,
90        beta: Self::Elem,
91        c: *mut Self::Elem, rsc: isize, csc: isize);
92}
93
/// Scalar element type that the GEMM kernels operate on.
///
/// Implementors are plain `Copy` values that may be shared across threads
/// (`Send + Sync`). Arithmetic is expressed through the in-place methods
/// below so generic code can work with both real and complex elements.
pub(crate) trait Element: Copy + Send + Sync {
    /// The additive identity.
    fn zero() -> Self;

    /// The multiplicative identity.
    fn one() -> Self;

    /// A value used to populate matrices in the kernel tests.
    #[cfg_attr(not(test), allow(unused))]
    fn test_value() -> Self;

    /// Returns whether `self` is zero.
    fn is_zero(&self) -> bool;

    /// In-place addition: `self ← self + rhs`.
    fn add_assign(&mut self, rhs: Self);

    /// In-place multiplication: `self ← self · rhs`.
    fn mul_assign(&mut self, rhs: Self);
}
103
104impl Element for f32 {
105    fn zero() -> Self { 0. }
106    fn one() -> Self { 1. }
107    fn test_value() -> Self { 1. }
108    fn is_zero(&self) -> bool { *self == 0. }
109    fn add_assign(&mut self, rhs: Self) { *self += rhs; }
110    fn mul_assign(&mut self, rhs: Self) { *self *= rhs; }
111}
112
113impl Element for f64 {
114    fn zero() -> Self { 0. }
115    fn one() -> Self { 1. }
116    fn test_value() -> Self { 1. }
117    fn is_zero(&self) -> bool { *self == 0. }
118    fn add_assign(&mut self, rhs: Self) { *self += rhs; }
119    fn mul_assign(&mut self, rhs: Self) { *self *= rhs; }
120}
121
122/// Kernel selector
123pub(crate) trait GemmSelect<T> {
124    /// Call `select` with the selected kernel for this configuration
125    fn select<K>(self, kernel: K)
126        where K: GemmKernel<Elem=T>,
127              T: Element;
128}
129
/// Complex single-precision number stored as `[re, im]`.
#[allow(non_camel_case_types)]
#[cfg(feature = "cgemm")]
pub(crate) type c32 = [f32; 2];

/// Complex double-precision number stored as `[re, im]`.
#[allow(non_camel_case_types)]
#[cfg(feature = "cgemm")]
pub(crate) type c64 = [f64; 2];
137
/// Complex single-precision elements, stored as `[re, im]`.
#[cfg(feature = "cgemm")]
impl Element for c32 {
    fn zero() -> Self {
        [0.0, 0.0]
    }

    fn one() -> Self {
        [1.0, 0.0]
    }

    fn test_value() -> Self {
        [2.0, 1.0]
    }

    fn is_zero(&self) -> bool {
        self[0] == 0.0 && self[1] == 0.0
    }

    /// Component-wise addition of real and imaginary parts.
    #[inline(always)]
    fn add_assign(&mut self, y: Self) {
        let [re, im] = y;
        self[0] += re;
        self[1] += im;
    }

    /// Complex multiplication, delegated to `c32_mul`.
    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        let product = c32_mul(*self, rhs);
        *self = product;
    }
}
156
/// Complex double-precision elements, stored as `[re, im]`.
#[cfg(feature = "cgemm")]
impl Element for c64 {
    fn zero() -> Self {
        [0.0, 0.0]
    }

    fn one() -> Self {
        [1.0, 0.0]
    }

    fn test_value() -> Self {
        [2.0, 1.0]
    }

    fn is_zero(&self) -> bool {
        self[0] == 0.0 && self[1] == 0.0
    }

    /// Component-wise addition of real and imaginary parts.
    #[inline(always)]
    fn add_assign(&mut self, y: Self) {
        let [re, im] = y;
        self[0] += re;
        self[1] += im;
    }

    /// Complex multiplication, delegated to `c64_mul`.
    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        let product = c64_mul(*self, rhs);
        *self = product;
    }
}
175
/// Complex product `x · y` of two `c32` values stored as `[re, im]`.
#[cfg(feature = "cgemm")]
#[inline(always)]
pub(crate) fn c32_mul(x: c32, y: c32) -> c32 {
    let (xr, xi) = (x[0], x[1]);
    let (yr, yi) = (y[0], y[1]);
    // (xr + i·xi)(yr + i·yi) = (xr·yr − xi·yi) + i·(xi·yr + xr·yi)
    let re = xr * yr - xi * yi;
    let im = xi * yr + xr * yi;
    [re, im]
}
183
/// Complex product `x · y` of two `c64` values stored as `[re, im]`.
#[cfg(feature = "cgemm")]
#[inline(always)]
pub(crate) fn c64_mul(x: c64, y: c64) -> c64 {
    let (xr, xi) = (x[0], x[1]);
    let (yr, yi) = (y[0], y[1]);
    // (xr + i·xi)(yr + i·yi) = (xr·yr − xi·yi) + i·(xi·yr + xr·yi)
    let re = xr * yr - xi * yi;
    let im = xi * yr + xr * yi;
    [re, im]
}
191
192
/// A type-level unsigned number.
///
/// Used to carry kernel dimensions in the type system (see
/// `GemmKernel::MRTy` / `GemmKernel::NRTy`).
pub(crate) trait ConstNum {
    /// The number this type stands for.
    const VALUE: usize;
}
196
197#[cfg(feature = "cgemm")]
198pub(crate) struct U2;
199pub(crate) struct U4;
200pub(crate) struct U8;
201
202#[cfg(feature = "cgemm")]
203impl ConstNum for U2 { const VALUE: usize = 2; }
204impl ConstNum for U4 { const VALUE: usize = 4; }
205impl ConstNum for U8 { const VALUE: usize = 8; }
206
207
#[cfg(test)]
pub(crate) mod test {
    use std::fmt;

    use super::GemmKernel;
    use super::Element;
    use crate::aligned_alloc::Alloc;

    /// Allocate `n` elements aligned to `K::align_to()`, each initialized to `elt`.
    pub(crate) fn aligned_alloc<K>(elt: K::Elem, n: usize) -> Alloc<K::Elem>
        where K: GemmKernel,
              K::Elem: Copy,
    {
        // SAFETY: the fresh allocation is initialized by `init_with` before it
        // is returned. NOTE(review): confirm against `Alloc::new`'s contract.
        unsafe {
            Alloc::new(n, K::align_to()).init_with(elt)
        }
    }

    /// Add distinct, increasing values to every element of `buf`.
    ///
    /// Element `i` receives `T::test_value()` added `i + 1` times. Repeated
    /// addition is used so that only `Element::add_assign` is required.
    fn fill_increasing<T: Element>(buf: &mut [T]) {
        for (idx, elt) in buf.iter_mut().enumerate() {
            for _ in 0..=idx {
                elt.add_assign(T::test_value());
            }
        }
    }

    /// Assert that we can compute A I == A and I B == B for the kernel (truncated, if needed)
    ///
    /// Tests C col major (A I == A) and row major (I B == B).
    /// Tests beta == 0 (and no other option)
    pub(crate) fn test_a_kernel<K, T>(_name: &str)
    where
        K: GemmKernel<Elem = T>,
        T: Element + fmt::Debug + PartialEq,
    {
        const K: usize = 16;
        let mr = K::MR;
        let nr = K::NR;

        // The identity setup and comparisons below assume the packed layouts
        // A[i, k] at `k * mr + i` and B[k, j] at `k * nr + j`.

        // 1. Test A I == A (variables a, b, c)
        // b looks like an identity matrix (truncated, depending on MR/NR)

        let mut a = aligned_alloc::<K>(T::zero(), mr * K);
        let mut b = aligned_alloc::<K>(T::zero(), nr * K);

        fill_increasing(&mut a[..]);
        for i in 0..Ord::min(K, nr) {
            b[i + i * nr] = T::one();
        }

        let mut c = vec![T::zero(); mr * nr];
        // SAFETY: `a` and `b` hold `K * mr` and `K * nr` elements, and `c`
        // holds the whole `mr × nr` block addressed by strides (1, mr).
        unsafe {
            // col major C
            K::kernel(K, T::one(), a.as_ptr(), b.as_ptr(), T::zero(), c.as_mut_ptr(), 1, mr as isize);
        }
        let common_len = Ord::min(a.len(), c.len());
        assert_eq!(&a[..common_len], &c[..common_len]);

        // 2. Test I B == B (variables a, b, c)
        // a looks like an identity matrix (truncated, depending on MR/NR)

        let mut a = aligned_alloc::<K>(T::zero(), mr * K);
        let mut b = aligned_alloc::<K>(T::zero(), nr * K);

        for i in 0..Ord::min(K, mr) {
            a[i + i * mr] = T::one();
        }
        fill_increasing(&mut b[..]);

        let mut c = vec![T::zero(); mr * nr];
        // SAFETY: as above, with `c` addressed by strides (nr, 1).
        unsafe {
            // row major C
            K::kernel(K, T::one(), a.as_ptr(), b.as_ptr(), T::zero(), c.as_mut_ptr(), nr as isize, 1);
        }
        let common_len = Ord::min(b.len(), c.len());
        assert_eq!(&b[..common_len], &c[..common_len]);
    }

    #[cfg(feature="cgemm")]
    /// Assert that we can compute A I == A for a kernel that needs the complex
    /// packing layout (truncated, if needed)
    ///
    /// Tests C col major only
    /// Tests beta == 0 (and no other option)
    pub(crate) fn test_complex_packed_kernel<K, T, TReal>(_name: &str)
    where
        K: GemmKernel<Elem = T>,
        T: Element + fmt::Debug + PartialEq,
        TReal: Element + fmt::Debug + PartialEq,
    {
        use crate::cgemm_common::pack_complex;

        const K: usize = 16;
        let mr = K::MR;
        let nr = K::NR;

        // 1. Test A I == A (variables a, b, c)
        // a is read as a column-major mr × K matrix (strides 1, mr).
        // b is a row-major K × nr identity matrix (truncated, depending on MR/NR).

        let mut a = aligned_alloc::<K>(T::zero(), mr * K);
        let mut apack = aligned_alloc::<K>(T::zero(), mr * K);
        let mut b = aligned_alloc::<K>(T::zero(), nr * K);
        let mut bpack = aligned_alloc::<K>(T::zero(), nr * K);

        fill_increasing(&mut a[..]);
        for i in 0..Ord::min(K, nr) {
            b[i + i * nr] = T::one();
        }

        // Unlike test_a_kernel, we need custom packing for these kernels.
        // Both operands go through the A-style packer `(kc, mc, buf, ptr, rs, cs)`:
        //  - A: `mr` rows of depth K; row stride 1, depth stride mr.
        //  - B: packed as its transpose, i.e. its `nr` columns of depth K;
        //    column stride 1, depth stride nr. This mirrors the A call.
        //    (Previously kc/mc and the strides were swapped here, which only
        //    worked because the truncated identity is symmetric.)
        // SAFETY: `a`/`apack` hold `mr * K` elements and `b`/`bpack` hold
        // `nr * K`; every element addressed by the given sizes and strides
        // lies inside those buffers.
        unsafe {
            pack_complex::<K::MRTy, T, TReal>(K, mr, &mut apack[..], a.ptr_mut(), 1, mr as isize);
            pack_complex::<K::NRTy, T, TReal>(K, nr, &mut bpack[..], b.ptr_mut(), 1, nr as isize);
        }

        let mut c = vec![T::zero(); mr * nr];
        // SAFETY: the packed buffers hold `K * mr` / `K * nr` elements, and `c`
        // holds the whole `mr × nr` block addressed by strides (1, mr).
        unsafe {
            // col major C
            K::kernel(K, T::one(), apack.as_ptr(), bpack.as_ptr(), T::zero(), c.as_mut_ptr(), 1, mr as isize);
        }
        let common_len = Ord::min(a.len(), c.len());
        assert_eq!(&a[..common_len], &c[..common_len]);
    }

}
347    }
348
349}