matrixmultiply/sgemm_kernel.rs

// Copyright 2016 - 2023 Ulrik Sverdrup "bluss"
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::kernel::GemmKernel;
use crate::kernel::GemmSelect;
use crate::kernel::{U4, U8};
use crate::archparam;

#[cfg(target_arch="x86")]
use core::arch::x86::*;
#[cfg(target_arch="x86_64")]
use core::arch::x86_64::*;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
use crate::x86::{FusedMulAdd, AvxMulAdd, SMultiplyAdd};

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFmaAvx2;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelSse2;

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
struct KernelNeon;
struct KernelFallback;

type T = f32;

/// Detect which implementation to use and select it using the selector's
/// .select(Kernel) method.
///
/// This function is called one or more times during a program's execution;
/// it may be called once per gemm kernel invocation, or fewer times.
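/// If no supported SIMD feature is detected, the portable fallback kernel is selected.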
#[inline]
pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
    // dispatch to specific compiled versions
    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    {
        if is_x86_feature_detected_!("fma") {
            if is_x86_feature_detected_!("avx2") {
                return selector.select(KernelFmaAvx2);
            }
            return selector.select(KernelFma);
        } else if is_x86_feature_detected_!("avx") {
            return selector.select(KernelAvx);
        } else if is_x86_feature_detected_!("sse2") {
            return selector.select(KernelSse2);
        }
    }
    #[cfg(target_arch="aarch64")]
    #[cfg(has_aarch64_simd)]
    {
        if is_aarch64_feature_detected_!("neon") {
            return selector.select(KernelNeon);
        }
    }
    return selector.select(KernelFallback);
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; }
#[cfg(all(test, any(target_arch="x86", target_arch="x86_64")))]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop8!($j, $e) }; }

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelAvx {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U8;

    #[inline(always)]
    fn align_to() -> usize { 32 }

    #[inline(always)]
    fn always_masked() -> bool { false }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFma {
    type Elem = T;

    type MRTy = <KernelAvx as GemmKernel>::MRTy;
    type NRTy = <KernelAvx as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { KernelAvx::align_to() }

    #[inline(always)]
    fn always_masked() -> bool { KernelAvx::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFmaAvx2 {
    type Elem = T;

    type MRTy = <KernelAvx as GemmKernel>::MRTy;
    type NRTy = <KernelAvx as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { KernelAvx::align_to() }

    #[inline(always)]
    fn always_masked() -> bool { KernelAvx::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

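    // Override packing to use the avx2-specialized routines for the A and B panels.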
    #[inline]
    unsafe fn pack_mr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        // safety: Avx2 is enabled
        crate::packing::pack_avx2::<Self::MRTy, T>(kc, mc, pack, a, rsa, csa)
    }

    #[inline]
    unsafe fn pack_nr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        // safety: Avx2 is enabled
        crate::packing::pack_avx2::<Self::NRTy, T>(kc, mc, pack, a, rsa, csa)
    }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelSse2 {
    type Elem = T;

    type MRTy = <KernelFallback as GemmKernel>::MRTy;
    type NRTy = <KernelFallback as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { 16 }

    #[inline(always)]
    fn always_masked() -> bool { KernelFallback::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc)
    }
}


#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
impl GemmKernel for KernelNeon {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U8;

    #[inline(always)]
    fn align_to() -> usize { 32 }

    #[inline(always)]
    fn always_masked() -> bool { false }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
    }
}

impl GemmKernel for KernelFallback {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 0 }

    #[inline(always)]
    fn always_masked() -> bool { true }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
    }
}

// no inline for unmasked kernels
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="fma")]
unsafe fn kernel_target_fma(k: usize, alpha: T, a: *const T, b: *const T,
                            beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_x86_avx::<FusedMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}

// no inline for unmasked kernels
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="avx")]
unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
                            beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_x86_avx::<AvxMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}

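// The sse2 "kernel" runs the portable fallback loop, compiled with sse2 enabled.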
#[inline]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="sse2")]
unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
}

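// 8x8 f32 microkernel using 256-bit AVX vectors; the MA parameter selects
// fused (fma) or separate (mul + add) multiply-add.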
#[inline(always)]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
    where MA: SMultiplyAdd,
{
    const MR: usize = KernelAvx::MR;
    const NR: usize = KernelAvx::NR;

    debug_assert_ne!(k, 0);

    let mut ab = [_mm256_setzero_ps(); MR];

    // this kernel can operate in either transposition (C = A B or C^T = B^T A^T)
    let prefer_row_major_c = rsc != 1;

    let (mut a, mut b) = if prefer_row_major_c { (a, b) } else { (b, a) };
    let (rsc, csc) = if prefer_row_major_c { (rsc, csc) } else { (csc, rsc) };

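    // shuffle_mask/permute_mask pack four 2-bit lane selectors into one
    // immediate byte; permute2f128_mask packs two 4-bit selectors for the
    // 128-bit halves.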
    macro_rules! shuffle_mask {
        ($z:expr, $y:expr, $x:expr, $w:expr) => {
            ($z << 6) | ($y << 4) | ($x << 2) | $w
        }
    }
    macro_rules! permute_mask {
        ($z:expr, $y:expr, $x:expr, $w:expr) => {
            ($z << 6) | ($y << 4) | ($x << 2) | $w
        }
    }

    macro_rules! permute2f128_mask {
        ($y:expr, $x:expr) => {
            (($y << 4) | $x)
        }
    }

    // Start data load before each iteration
    let mut av = _mm256_load_ps(a);
    let mut bv = _mm256_load_ps(b);

    // Compute A B
    unroll_by_with_last!(4 => k, is_last, {
        // We compute abij = ai bj
        //
        // Load b as one contiguous vector
        // Load a as striped vectors
        //
        // Shuffle the abij elements in order after the loop.
        //
        // Note: this scheme is copied and transposed from the BLIS 8x8 sgemm
        // microkernel.
        //
        // Our a indices are striped and our b indices are linear. In
        // the variable names below, we always have doubled indices so
        // for example a0246 corresponds to a vector of a0 a0 a2 a2 a4 a4 a6 a6.
        //
        // ab0246: ab2064: ab4602: ab6420:
        // ( ab00  ( ab20  ( ab40  ( ab60
        //   ab01    ab21    ab41    ab61
        //   ab22    ab02    ab62    ab42
        //   ab23    ab03    ab63    ab43
        //   ab44    ab64    ab04    ab24
        //   ab45    ab65    ab05    ab25
        //   ab66    ab46    ab26    ab06
        //   ab67 )  ab47 )  ab27 )  ab07 )
        //
        // ab1357: ab3175: ab5713: ab7531:
        // ( ab10  ( ab30  ( ab50  ( ab70
        //   ab11    ab31    ab51    ab71
        //   ab32    ab12    ab72    ab52
        //   ab33    ab13    ab73    ab53
        //   ab54    ab74    ab14    ab34
        //   ab55    ab75    ab15    ab35
        //   ab76    ab56    ab36    ab16
        //   ab77 )  ab57 )  ab37 )  ab17 )

        const PERM32_2301: i32 = permute_mask!(1, 0, 3, 2);
        const PERM128_30: i32 = permute2f128_mask!(0, 3);

        // _mm256_moveldup_ps(av):
        // vmovsldup ymm2, ymmword ptr [rax]
        //
        // Load and duplicate each even word:
        // ymm2 ← [a0 a0 a2 a2 a4 a4 a6 a6]
        //
        // _mm256_movehdup_ps(av):
        // vmovshdup ymm2, ymmword ptr [rax]
        //
        // Load and duplicate each odd word:
        // ymm2 ← [a1 a1 a3 a3 a5 a5 a7 a7]
        //

        let a0246 = _mm256_moveldup_ps(av); // Load: a0 a0 a2 a2 a4 a4 a6 a6
        let a2064 = _mm256_permute_ps(a0246, PERM32_2301);

        let a1357 = _mm256_movehdup_ps(av); // Load: a1 a1 a3 a3 a5 a5 a7 a7
        let a3175 = _mm256_permute_ps(a1357, PERM32_2301);

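        // bv with its 128-bit halves swapped: b4 b5 b6 b7 b0 b1 b2 b3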
        let bv_lh = _mm256_permute2f128_ps(bv, bv, PERM128_30);

        ab[0] = MA::multiply_add(a0246, bv, ab[0]);
        ab[1] = MA::multiply_add(a2064, bv, ab[1]);
        ab[2] = MA::multiply_add(a0246, bv_lh, ab[2]);
        ab[3] = MA::multiply_add(a2064, bv_lh, ab[3]);

        ab[4] = MA::multiply_add(a1357, bv, ab[4]);
        ab[5] = MA::multiply_add(a3175, bv, ab[5]);
        ab[6] = MA::multiply_add(a1357, bv_lh, ab[6]);
        ab[7] = MA::multiply_add(a3175, bv_lh, ab[7]);

        if !is_last {
            a = a.add(MR);
            b = b.add(NR);

            bv = _mm256_load_ps(b);
            av = _mm256_load_ps(a);
        }
    });

    let alphav = _mm256_set1_ps(alpha);

    // Permute to put the abij elements in order
    //
    // shufps 0xe4: 22006644 00224466 -> 22226666
    //
    // vperm2 0x30: 00004444 44440000 -> 00000000
    // vperm2 0x12: 00004444 44440000 -> 44444444
    //

    let ab0246 = ab[0];
    let ab2064 = ab[1];
    let ab4602 = ab[2]; // reverse order
    let ab6420 = ab[3]; // reverse order

    let ab1357 = ab[4];
    let ab3175 = ab[5];
    let ab5713 = ab[6]; // reverse order
    let ab7531 = ab[7]; // reverse order

    const SHUF_0123: i32 = shuffle_mask!(3, 2, 1, 0);
    debug_assert_eq!(SHUF_0123, 0xE4);

    const PERM128_02: i32 = permute2f128_mask!(2, 0);
    const PERM128_31: i32 = permute2f128_mask!(1, 3);

    // No elements are "shuffled" in truth, they all stay at their index
    // but we combine vectors to de-stripe them.
    //
    // For example, the first shuffle below uses 0 1 2 3 which
    // corresponds to the X0 X1 Y2 Y3 sequence etc:
    //
    //                                             variable
    // X ab00 ab01 ab22 ab23 ab44 ab45 ab66 ab67   ab0246
    // Y ab20 ab21 ab02 ab03 ab64 ab65 ab46 ab47   ab2064
    //
    //   X0   X1   Y2   Y3   X4   X5   Y6   Y7
    // = ab00 ab01 ab02 ab03 ab44 ab45 ab46 ab47   ab0044

    let ab0044 = _mm256_shuffle_ps(ab0246, ab2064, SHUF_0123);
    let ab2266 = _mm256_shuffle_ps(ab2064, ab0246, SHUF_0123);

    let ab4400 = _mm256_shuffle_ps(ab4602, ab6420, SHUF_0123);
    let ab6622 = _mm256_shuffle_ps(ab6420, ab4602, SHUF_0123);

    let ab1155 = _mm256_shuffle_ps(ab1357, ab3175, SHUF_0123);
    let ab3377 = _mm256_shuffle_ps(ab3175, ab1357, SHUF_0123);

    let ab5511 = _mm256_shuffle_ps(ab5713, ab7531, SHUF_0123);
    let ab7733 = _mm256_shuffle_ps(ab7531, ab5713, SHUF_0123);

    let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_02);
    let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_31);

    let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_02);
    let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_31);

    let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_02);
    let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_31);

    let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_02);
    let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_31);

    ab[0] = ab0000;
    ab[1] = ab1111;
    ab[2] = ab2222;
    ab[3] = ab3333;
    ab[4] = ab4444;
    ab[5] = ab5555;
    ab[6] = ab6666;
    ab[7] = ab7777;

    // Compute α (A B)
    // Compute here if we don't have fma, else pick up α further down
    if !MA::IS_FUSED {
        loop_m!(i, ab[i] = _mm256_mul_ps(alphav, ab[i]));
    }

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    // C ← α A B + β C
    let mut cv = [_mm256_setzero_ps(); MR];
    if beta != 0. {
        let betav = _mm256_set1_ps(beta);
        // Read C
        if csc == 1 {
            loop_m!(i, cv[i] = _mm256_loadu_ps(c![i, 0]));
        } else {
            loop_m!(i, cv[i] = _mm256_setr_ps(*c![i, 0], *c![i, 1], *c![i, 2], *c![i, 3],
                                              *c![i, 4], *c![i, 5], *c![i, 6], *c![i, 7]));
        }
        // Compute β C
        loop_m!(i, cv[i] = _mm256_mul_ps(cv[i], betav));
    }

    // Compute (α A B) + (β C)
    if !MA::IS_FUSED {
        loop_m!(i, cv[i] = _mm256_add_ps(cv[i], ab[i]));
    } else {
        loop_m!(i, cv[i] = MA::multiply_add(alphav, ab[i], cv[i]));
    }

    // Store C back to memory
    if csc == 1 {
        loop_m!(i, _mm256_storeu_ps(c![i, 0], cv[i]));
    } else {
        // Permute to bring each element in the vector to the front and store
        loop_m!(i, {
            let cvlo = _mm256_extractf128_ps(cv[i], 0);
            let cvhi = _mm256_extractf128_ps(cv[i], 1);

            _mm_store_ss(c![i, 0], cvlo);
            let cperm = _mm_permute_ps(cvlo, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 1], cperm);
            let cperm = _mm_permute_ps(cperm, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 2], cperm);
            let cperm = _mm_permute_ps(cperm, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 3], cperm);

            _mm_store_ss(c![i, 4], cvhi);
            let cperm = _mm_permute_ps(cvhi, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 5], cperm);
            let cperm = _mm_permute_ps(cperm, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 6], cperm);
            let cperm = _mm_permute_ps(cperm, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 7], cperm);
        });
    }
}

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
#[target_feature(enable="neon")]
unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
{
    use core::arch::aarch64::*;
    const MR: usize = KernelNeon::MR;
    const NR: usize = KernelNeon::NR;

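    // Like the AVX kernel, operate on C in row-major order; if C is column-major
    // (rsc == 1), swap a/b and the strides to compute C^T = B^T A^T instead.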
    let (mut a, mut b, rsc, csc) = if rsc == 1 { (b, a, csc, rsc) } else { (a, b, rsc, csc) };

    // Kernel 8 x 8 (a x b)
    // Four quadrants of 4 x 4
    let mut ab11 = [vmovq_n_f32(0.); 4];
    let mut ab12 = [vmovq_n_f32(0.); 4];
    let mut ab21 = [vmovq_n_f32(0.); 4];
    let mut ab22 = [vmovq_n_f32(0.); 4];

    // Compute
    // ab_ij = a_i * b_j for all i, j
    macro_rules! ab_ij_equals_ai_bj {
        ($dest:ident, $av:expr, $bv:expr) => {
            $dest[0] = vfmaq_laneq_f32($dest[0], $bv, $av, 0);
            $dest[1] = vfmaq_laneq_f32($dest[1], $bv, $av, 1);
            $dest[2] = vfmaq_laneq_f32($dest[2], $bv, $av, 2);
            $dest[3] = vfmaq_laneq_f32($dest[3], $bv, $av, 3);
        }
    }

    for _ in 0..k {
        let a1 = vld1q_f32(a);
        let b1 = vld1q_f32(b);
        let a2 = vld1q_f32(a.add(4));
        let b2 = vld1q_f32(b.add(4));

        // compute an outer product ab = a (*) b in four quadrants ab11, ab12, ab21, ab22

        // ab11: [a1 a2 a3 a4] (*) [b1 b2 b3 b4]
        // ab11: a1b1 a1b2 a1b3 a1b4
        //       a2b1 a2b2 a2b3 a2b4
        //       a3b1 a3b2 a3b3 a3b4
        //       a4b1 a4b2 a4b3 a4b4
        //  etc
        ab_ij_equals_ai_bj!(ab11, a1, b1);
        ab_ij_equals_ai_bj!(ab12, a1, b2);
        ab_ij_equals_ai_bj!(ab21, a2, b1);
        ab_ij_equals_ai_bj!(ab22, a2, b2);

        a = a.add(MR);
        b = b.add(NR);
    }

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    // ab *= alpha
    loop4!(i, ab11[i] = vmulq_n_f32(ab11[i], alpha));
    loop4!(i, ab12[i] = vmulq_n_f32(ab12[i], alpha));
    loop4!(i, ab21[i] = vmulq_n_f32(ab21[i], alpha));
    loop4!(i, ab22[i] = vmulq_n_f32(ab22[i], alpha));

    // load one float32x4_t from four pointers
    macro_rules! loadq_from_pointers {
        ($p0:expr, $p1:expr, $p2:expr, $p3:expr) => (
            {
                let v = vld1q_dup_f32($p0);
                let v = vld1q_lane_f32($p1, v, 1);
                let v = vld1q_lane_f32($p2, v, 2);
                let v = vld1q_lane_f32($p3, v, 3);
                v
            }
        );
    }

    if beta != 0. {
        // load existing value in C
        let mut c11 = [vmovq_n_f32(0.); 4];
        let mut c12 = [vmovq_n_f32(0.); 4];
        let mut c21 = [vmovq_n_f32(0.); 4];
        let mut c22 = [vmovq_n_f32(0.); 4];

        if csc == 1 {
            loop4!(i, c11[i] = vld1q_f32(c![i + 0, 0]));
            loop4!(i, c12[i] = vld1q_f32(c![i + 0, 4]));
            loop4!(i, c21[i] = vld1q_f32(c![i + 4, 0]));
            loop4!(i, c22[i] = vld1q_f32(c![i + 4, 4]));
        } else {
            loop4!(i, c11[i] = loadq_from_pointers!(c![i + 0, 0], c![i + 0, 1], c![i + 0, 2], c![i + 0, 3]));
            loop4!(i, c12[i] = loadq_from_pointers!(c![i + 0, 4], c![i + 0, 5], c![i + 0, 6], c![i + 0, 7]));
            loop4!(i, c21[i] = loadq_from_pointers!(c![i + 4, 0], c![i + 4, 1], c![i + 4, 2], c![i + 4, 3]));
            loop4!(i, c22[i] = loadq_from_pointers!(c![i + 4, 4], c![i + 4, 5], c![i + 4, 6], c![i + 4, 7]));
        }

        let betav = vmovq_n_f32(beta);

        // ab += β C
        loop4!(i, ab11[i] = vfmaq_f32(ab11[i], c11[i], betav));
        loop4!(i, ab12[i] = vfmaq_f32(ab12[i], c12[i], betav));
        loop4!(i, ab21[i] = vfmaq_f32(ab21[i], c21[i], betav));
        loop4!(i, ab22[i] = vfmaq_f32(ab22[i], c22[i], betav));
    }

    // c <- ab
    // which is in full
    //   C <- α A B (+ β C)
    if csc == 1 {
        loop4!(i, vst1q_f32(c![i + 0, 0], ab11[i]));
        loop4!(i, vst1q_f32(c![i + 0, 4], ab12[i]));
        loop4!(i, vst1q_f32(c![i + 4, 0], ab21[i]));
        loop4!(i, vst1q_f32(c![i + 4, 4], ab22[i]));
    } else {
        loop4!(i, vst1q_lane_f32(c![i + 0, 0], ab11[i], 0));
        loop4!(i, vst1q_lane_f32(c![i + 0, 1], ab11[i], 1));
        loop4!(i, vst1q_lane_f32(c![i + 0, 2], ab11[i], 2));
        loop4!(i, vst1q_lane_f32(c![i + 0, 3], ab11[i], 3));

        loop4!(i, vst1q_lane_f32(c![i + 0, 4], ab12[i], 0));
        loop4!(i, vst1q_lane_f32(c![i + 0, 5], ab12[i], 1));
        loop4!(i, vst1q_lane_f32(c![i + 0, 6], ab12[i], 2));
        loop4!(i, vst1q_lane_f32(c![i + 0, 7], ab12[i], 3));

        loop4!(i, vst1q_lane_f32(c![i + 4, 0], ab21[i], 0));
        loop4!(i, vst1q_lane_f32(c![i + 4, 1], ab21[i], 1));
        loop4!(i, vst1q_lane_f32(c![i + 4, 2], ab21[i], 2));
        loop4!(i, vst1q_lane_f32(c![i + 4, 3], ab21[i], 3));

        loop4!(i, vst1q_lane_f32(c![i + 4, 4], ab22[i], 0));
        loop4!(i, vst1q_lane_f32(c![i + 4, 5], ab22[i], 1));
        loop4!(i, vst1q_lane_f32(c![i + 4, 6], ab22[i], 2));
        loop4!(i, vst1q_lane_f32(c![i + 4, 7], ab22[i], 3));
    }
}

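// Portable scalar kernel with an 8 x 4 accumulator. KernelFallback is always
// masked, so beta is always zero by the time this is called.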
#[inline]
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
                               beta: T, c: *mut T, rsc: isize, csc: isize)
{
    const MR: usize = KernelFallback::MR;
    const NR: usize = KernelFallback::NR;
    let mut ab: [[T; NR]; MR] = [[0.; NR]; MR];
    let mut a = a;
    let mut b = b;
    debug_assert_eq!(beta, 0., "Beta must be 0 or is not masked");

    // Compute A B into ab[i][j]
    unroll_by!(4 => k, {
        loop8!(i, loop4!(j, ab[i][j] += at(a, i) * at(b, j)));

        a = a.offset(MR as isize);
        b = b.offset(NR as isize);
    });

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    // set C = α A B
    loop4!(j, loop8!(i, *c![i, j] = alpha * ab[i][j]));
}

#[inline(always)]
unsafe fn at(ptr: *const T, i: usize) -> T {
    *ptr.offset(i as isize)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::test::test_a_kernel;

    #[test]
    fn test_kernel_fallback_impl() {
        test_a_kernel::<KernelFallback, _>("kernel");
    }

    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    #[test]
    fn test_loop_m_n() {
        let mut m = [[0; KernelAvx::NR]; KernelAvx::MR];
        loop_m!(i, loop_n!(j, m[i][j] += 1));
        for arr in &m[..] {
            for elt in &arr[..] {
                assert_eq!(*elt, 1);
            }
        }
    }

    #[cfg(any(target_arch="aarch64"))]
    #[cfg(has_aarch64_simd)]
    mod test_kernel_aarch64 {
        use super::test_a_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;

        macro_rules! test_arch_kernels_aarch64 {
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
                $(
                #[test]
                fn $name() {
                    if is_aarch64_feature_detected_!($feature_name) {
                        test_a_kernel::<$kernel_ty, _>(stringify!($name));
                    } else {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have feature: {:?}", $feature_name);
                    }
                }
                )*
            }
        }

        test_arch_kernels_aarch64! {
            "neon", neon8x8, KernelNeon
        }
    }

    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    mod test_kernel_x86 {
        use super::test_a_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;

        macro_rules! test_arch_kernels_x86 {
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
                $(
                #[test]
                fn $name() {
                    if is_x86_feature_detected_!($feature_name) {
                        test_a_kernel::<$kernel_ty, _>(stringify!($name));
                    } else {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have feature: {:?}", $feature_name);
                    }
                }
                )*
            }
        }

        test_arch_kernels_x86! {
            "fma", fma, KernelFma,
            "avx", avx, KernelAvx,
            "sse2", sse2, KernelSse2
        }

        #[test]
        fn ensure_target_features_tested() {
            // If enabled, this test ensures that the requested feature actually
            // was enabled on this configuration, so that it was tested.
            let should_ensure_feature = !option_env!("MMTEST_ENSUREFEATURE")
                                                    .unwrap_or("").is_empty();
            if !should_ensure_feature {
                // skip
                return;
            }
            let feature_name = option_env!("MMTEST_FEATURE")
                                          .expect("No MMTEST_FEATURE configured!");
            let detected = match feature_name {
                "avx" => is_x86_feature_detected_!("avx"),
                "fma" => is_x86_feature_detected_!("fma"),
                "sse2" => is_x86_feature_detected_!("sse2"),
                _ => false,
            };
            assert!(detected, "Feature {:?} was not detected, so it could not be tested",
                    feature_name);
        }
    }
}