use crate::kernel::GemmKernel;
use crate::kernel::GemmSelect;
use crate::kernel::{U4, U8};
use crate::archparam;

#[cfg(target_arch="x86")]
use core::arch::x86::*;
#[cfg(target_arch="x86_64")]
use core::arch::x86_64::*;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
use crate::x86::{FusedMulAdd, AvxMulAdd, SMultiplyAdd};

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFmaAvx2;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelSse2;

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
struct KernelNeon;
struct KernelFallback;

type T = f32;

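/// Pick the best available kernel for the current CPU at runtime and hand it
/// to `selector`; falls back to the portable kernel when no supported SIMD
/// feature is detected.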
#[inline]
pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    {
        if is_x86_feature_detected_!("fma") {
            if is_x86_feature_detected_!("avx2") {
                return selector.select(KernelFmaAvx2);
            }
            return selector.select(KernelFma);
        } else if is_x86_feature_detected_!("avx") {
            return selector.select(KernelAvx);
        } else if is_x86_feature_detected_!("sse2") {
            return selector.select(KernelSse2);
        }
    }
    #[cfg(target_arch="aarch64")]
    #[cfg(has_aarch64_simd)]
    {
        if is_aarch64_feature_detected_!("neon") {
            return selector.select(KernelNeon);
        }
    }
    return selector.select(KernelFallback);
}

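// The AVX/FMA kernels use an 8x8 micro-tile: loop_m iterates over its rows
// and loop_n (used by the tests) over its columns.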
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; }
#[cfg(all(test, any(target_arch="x86", target_arch="x86_64")))]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop8!($j, $e) }; }

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelAvx {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U8;

    #[inline(always)]
    fn align_to() -> usize { 32 }

    #[inline(always)]
    fn always_masked() -> bool { false }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFma {
    type Elem = T;

    type MRTy = <KernelAvx as GemmKernel>::MRTy;
    type NRTy = <KernelAvx as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { KernelAvx::align_to() }

    #[inline(always)]
    fn always_masked() -> bool { KernelAvx::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFmaAvx2 {
    type Elem = T;

    type MRTy = <KernelAvx as GemmKernel>::MRTy;
    type NRTy = <KernelAvx as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { KernelAvx::align_to() }

    #[inline(always)]
    fn always_masked() -> bool { KernelAvx::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

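    // AVX2 is known to be available for this kernel, so override the default
    // packing with the vectorized AVX2 routine for both the row (MR) and
    // column (NR) panels.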
    #[inline]
    unsafe fn pack_mr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        crate::packing::pack_avx2::<Self::MRTy, T>(kc, mc, pack, a, rsa, csa)
    }

    #[inline]
    unsafe fn pack_nr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        crate::packing::pack_avx2::<Self::NRTy, T>(kc, mc, pack, a, rsa, csa)
    }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelSse2 {
    type Elem = T;

    type MRTy = <KernelFallback as GemmKernel>::MRTy;
    type NRTy = <KernelFallback as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { 16 }

    #[inline(always)]
    fn always_masked() -> bool { KernelFallback::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc)
    }
}


#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
impl GemmKernel for KernelNeon {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U8;

    #[inline(always)]
    fn align_to() -> usize { 32 }

    #[inline(always)]
    fn always_masked() -> bool { false }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
    }
}

impl GemmKernel for KernelFallback {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 0 }

    #[inline(always)]
    fn always_masked() -> bool { true }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="fma")]
unsafe fn kernel_target_fma(k: usize, alpha: T, a: *const T, b: *const T,
                            beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_x86_avx::<FusedMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="avx")]
unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
                            beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_x86_avx::<AvxMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}

#[inline]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="sse2")]
unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
}

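// 8x8 f32 micro-kernel shared by the AVX and FMA targets; the `MA` parameter
// selects between separate multiply+add (AVX) and fused multiply-add (FMA).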
#[inline(always)]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
    where MA: SMultiplyAdd,
{
    const MR: usize = KernelAvx::MR;
    const NR: usize = KernelAvx::NR;

    debug_assert_ne!(k, 0);

    let mut ab = [_mm256_setzero_ps(); MR];

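    // When C is effectively column-major (rsc == 1), swap the roles of a and b
    // and transpose the stride pair so the writeback below can treat C as
    // row-major.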
    let prefer_row_major_c = rsc != 1;

    let (mut a, mut b) = if prefer_row_major_c { (a, b) } else { (b, a) };
    let (rsc, csc) = if prefer_row_major_c { (rsc, csc) } else { (csc, rsc) };

    macro_rules! shuffle_mask {
        ($z:expr, $y:expr, $x:expr, $w:expr) => {
            ($z << 6) | ($y << 4) | ($x << 2) | $w
        }
    }
    macro_rules! permute_mask {
        ($z:expr, $y:expr, $x:expr, $w:expr) => {
            ($z << 6) | ($y << 4) | ($x << 2) | $w
        }
    }

    macro_rules! permute2f128_mask {
        ($y:expr, $x:expr) => {
            (($y << 4) | $x)
        }
    }

    let mut av = _mm256_load_ps(a);
    let mut bv = _mm256_load_ps(b);

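    // Each unrolled step multiplies permuted copies of the current a column
    // with the b row (and its lane-swapped variant) and accumulates into ab;
    // the next a/b panels are loaded at the end of every step except the last.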
    unroll_by_with_last!(4 => k, is_last, {
        const PERM32_2301: i32 = permute_mask!(1, 0, 3, 2);
        const PERM128_30: i32 = permute2f128_mask!(0, 3);

        let a0246 = _mm256_moveldup_ps(av);
        let a2064 = _mm256_permute_ps(a0246, PERM32_2301);

        let a1357 = _mm256_movehdup_ps(av);
        let a3175 = _mm256_permute_ps(a1357, PERM32_2301);

        let bv_lh = _mm256_permute2f128_ps(bv, bv, PERM128_30);

        ab[0] = MA::multiply_add(a0246, bv, ab[0]);
        ab[1] = MA::multiply_add(a2064, bv, ab[1]);
        ab[2] = MA::multiply_add(a0246, bv_lh, ab[2]);
        ab[3] = MA::multiply_add(a2064, bv_lh, ab[3]);

        ab[4] = MA::multiply_add(a1357, bv, ab[4]);
        ab[5] = MA::multiply_add(a3175, bv, ab[5]);
        ab[6] = MA::multiply_add(a1357, bv_lh, ab[6]);
        ab[7] = MA::multiply_add(a3175, bv_lh, ab[7]);

        if !is_last {
            a = a.add(MR);
            b = b.add(NR);

            bv = _mm256_load_ps(b);
            av = _mm256_load_ps(a);
        }
    });

    let alphav = _mm256_set1_ps(alpha);

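    // The accumulators hold lane-permuted products; the shuffles and 128-bit
    // permutes below rearrange them so that ab[i] ends up holding row i of the
    // 8x8 product before alpha and beta are applied.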
    let ab0246 = ab[0];
    let ab2064 = ab[1];
    let ab4602 = ab[2];
    let ab6420 = ab[3];
    let ab1357 = ab[4];
    let ab3175 = ab[5];
    let ab5713 = ab[6];
    let ab7531 = ab[7];

    const SHUF_0123: i32 = shuffle_mask!(3, 2, 1, 0);
    debug_assert_eq!(SHUF_0123, 0xE4);

    const PERM128_02: i32 = permute2f128_mask!(2, 0);
    const PERM128_31: i32 = permute2f128_mask!(1, 3);

    let ab0044 = _mm256_shuffle_ps(ab0246, ab2064, SHUF_0123);
    let ab2266 = _mm256_shuffle_ps(ab2064, ab0246, SHUF_0123);

    let ab4400 = _mm256_shuffle_ps(ab4602, ab6420, SHUF_0123);
    let ab6622 = _mm256_shuffle_ps(ab6420, ab4602, SHUF_0123);

    let ab1155 = _mm256_shuffle_ps(ab1357, ab3175, SHUF_0123);
    let ab3377 = _mm256_shuffle_ps(ab3175, ab1357, SHUF_0123);

    let ab5511 = _mm256_shuffle_ps(ab5713, ab7531, SHUF_0123);
    let ab7733 = _mm256_shuffle_ps(ab7531, ab5713, SHUF_0123);

    let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_02);
    let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_31);

    let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_02);
    let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_31);

    let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_02);
    let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_31);

    let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_02);
    let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_31);

    ab[0] = ab0000;
    ab[1] = ab1111;
    ab[2] = ab2222;
    ab[3] = ab3333;
    ab[4] = ab4444;
    ab[5] = ab5555;
    ab[6] = ab6666;
    ab[7] = ab7777;

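    // With plain AVX the alpha scaling happens here; the FMA path defers it
    // and fuses alpha into the multiply-add against C below.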
    if !MA::IS_FUSED {
        loop_m!(i, ab[i] = _mm256_mul_ps(alphav, ab[i]));
    }

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

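    // C = alpha * A * B + beta * C: the existing C block is loaded and scaled
    // by beta only when beta is nonzero, using vector loads when the rows are
    // contiguous and element-wise gathers otherwise.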
    let mut cv = [_mm256_setzero_ps(); MR];
    if beta != 0. {
        let betav = _mm256_set1_ps(beta);
        if csc == 1 {
            loop_m!(i, cv[i] = _mm256_loadu_ps(c![i, 0]));
        } else {
            loop_m!(i, cv[i] = _mm256_setr_ps(*c![i, 0], *c![i, 1], *c![i, 2], *c![i, 3],
                                              *c![i, 4], *c![i, 5], *c![i, 6], *c![i, 7]));
        }
        loop_m!(i, cv[i] = _mm256_mul_ps(cv[i], betav));
    }

    if !MA::IS_FUSED {
        loop_m!(i, cv[i] = _mm256_add_ps(cv[i], ab[i]));
    } else {
        loop_m!(i, cv[i] = MA::multiply_add(alphav, ab[i], cv[i]));
    }

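    // Write the block back to C: contiguous rows (csc == 1) use unaligned
    // vector stores; otherwise each lane is stored separately, rotating the
    // vector between scalar stores.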
    if csc == 1 {
        loop_m!(i, _mm256_storeu_ps(c![i, 0], cv[i]));
    } else {
        loop_m!(i, {
            let cvlo = _mm256_extractf128_ps(cv[i], 0);
            let cvhi = _mm256_extractf128_ps(cv[i], 1);

            _mm_store_ss(c![i, 0], cvlo);
            let cperm = _mm_permute_ps(cvlo, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 1], cperm);
            let cperm = _mm_permute_ps(cperm, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 2], cperm);
            let cperm = _mm_permute_ps(cperm, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 3], cperm);

            _mm_store_ss(c![i, 4], cvhi);
            let cperm = _mm_permute_ps(cvhi, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 5], cperm);
            let cperm = _mm_permute_ps(cperm, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 6], cperm);
            let cperm = _mm_permute_ps(cperm, permute_mask!(0, 3, 2, 1));
            _mm_store_ss(c![i, 7], cperm);
        });
    }
}

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
#[target_feature(enable="neon")]
unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
{
    use core::arch::aarch64::*;
    const MR: usize = KernelNeon::MR;
    const NR: usize = KernelNeon::NR;

    let (mut a, mut b, rsc, csc) = if rsc == 1 { (b, a, csc, rsc) } else { (a, b, rsc, csc) };

    let mut ab11 = [vmovq_n_f32(0.); 4];
    let mut ab12 = [vmovq_n_f32(0.); 4];
    let mut ab21 = [vmovq_n_f32(0.); 4];
    let mut ab22 = [vmovq_n_f32(0.); 4];

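    // The 8x8 result is accumulated as four 4x4 quadrants (ab11, ab12, ab21,
    // ab22); vfmaq_laneq_f32 broadcasts one lane of the a vector and fuses the
    // multiply with the accumulation.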
    macro_rules! ab_ij_equals_ai_bj {
        ($dest:ident, $av:expr, $bv:expr) => {
            $dest[0] = vfmaq_laneq_f32($dest[0], $bv, $av, 0);
            $dest[1] = vfmaq_laneq_f32($dest[1], $bv, $av, 1);
            $dest[2] = vfmaq_laneq_f32($dest[2], $bv, $av, 2);
            $dest[3] = vfmaq_laneq_f32($dest[3], $bv, $av, 3);
        }
    }

    for _ in 0..k {
        let a1 = vld1q_f32(a);
        let b1 = vld1q_f32(b);
        let a2 = vld1q_f32(a.add(4));
        let b2 = vld1q_f32(b.add(4));

        ab_ij_equals_ai_bj!(ab11, a1, b1);
        ab_ij_equals_ai_bj!(ab12, a1, b2);
        ab_ij_equals_ai_bj!(ab21, a2, b1);
        ab_ij_equals_ai_bj!(ab22, a2, b2);

        a = a.add(MR);
        b = b.add(NR);
    }

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    loop4!(i, ab11[i] = vmulq_n_f32(ab11[i], alpha));
    loop4!(i, ab12[i] = vmulq_n_f32(ab12[i], alpha));
    loop4!(i, ab21[i] = vmulq_n_f32(ab21[i], alpha));
    loop4!(i, ab22[i] = vmulq_n_f32(ab22[i], alpha));

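    // Gather four strided C elements into a single NEON vector; used by the
    // beta path when columns are not contiguous.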
    macro_rules! loadq_from_pointers {
        ($p0:expr, $p1:expr, $p2:expr, $p3:expr) => (
            {
                let v = vld1q_dup_f32($p0);
                let v = vld1q_lane_f32($p1, v, 1);
                let v = vld1q_lane_f32($p2, v, 2);
                let v = vld1q_lane_f32($p3, v, 3);
                v
            }
        );
    }

    if beta != 0. {
        let mut c11 = [vmovq_n_f32(0.); 4];
        let mut c12 = [vmovq_n_f32(0.); 4];
        let mut c21 = [vmovq_n_f32(0.); 4];
        let mut c22 = [vmovq_n_f32(0.); 4];

        if csc == 1 {
            loop4!(i, c11[i] = vld1q_f32(c![i + 0, 0]));
            loop4!(i, c12[i] = vld1q_f32(c![i + 0, 4]));
            loop4!(i, c21[i] = vld1q_f32(c![i + 4, 0]));
            loop4!(i, c22[i] = vld1q_f32(c![i + 4, 4]));
        } else {
            loop4!(i, c11[i] = loadq_from_pointers!(c![i + 0, 0], c![i + 0, 1], c![i + 0, 2], c![i + 0, 3]));
            loop4!(i, c12[i] = loadq_from_pointers!(c![i + 0, 4], c![i + 0, 5], c![i + 0, 6], c![i + 0, 7]));
            loop4!(i, c21[i] = loadq_from_pointers!(c![i + 4, 0], c![i + 4, 1], c![i + 4, 2], c![i + 4, 3]));
            loop4!(i, c22[i] = loadq_from_pointers!(c![i + 4, 4], c![i + 4, 5], c![i + 4, 6], c![i + 4, 7]));
        }

        let betav = vmovq_n_f32(beta);

        loop4!(i, ab11[i] = vfmaq_f32(ab11[i], c11[i], betav));
        loop4!(i, ab12[i] = vfmaq_f32(ab12[i], c12[i], betav));
        loop4!(i, ab21[i] = vfmaq_f32(ab21[i], c21[i], betav));
        loop4!(i, ab22[i] = vfmaq_f32(ab22[i], c22[i], betav));
    }

    if csc == 1 {
        loop4!(i, vst1q_f32(c![i + 0, 0], ab11[i]));
        loop4!(i, vst1q_f32(c![i + 0, 4], ab12[i]));
        loop4!(i, vst1q_f32(c![i + 4, 0], ab21[i]));
        loop4!(i, vst1q_f32(c![i + 4, 4], ab22[i]));
    } else {
        loop4!(i, vst1q_lane_f32(c![i + 0, 0], ab11[i], 0));
        loop4!(i, vst1q_lane_f32(c![i + 0, 1], ab11[i], 1));
        loop4!(i, vst1q_lane_f32(c![i + 0, 2], ab11[i], 2));
        loop4!(i, vst1q_lane_f32(c![i + 0, 3], ab11[i], 3));

        loop4!(i, vst1q_lane_f32(c![i + 0, 4], ab12[i], 0));
        loop4!(i, vst1q_lane_f32(c![i + 0, 5], ab12[i], 1));
        loop4!(i, vst1q_lane_f32(c![i + 0, 6], ab12[i], 2));
        loop4!(i, vst1q_lane_f32(c![i + 0, 7], ab12[i], 3));

        loop4!(i, vst1q_lane_f32(c![i + 4, 0], ab21[i], 0));
        loop4!(i, vst1q_lane_f32(c![i + 4, 1], ab21[i], 1));
        loop4!(i, vst1q_lane_f32(c![i + 4, 2], ab21[i], 2));
        loop4!(i, vst1q_lane_f32(c![i + 4, 3], ab21[i], 3));

        loop4!(i, vst1q_lane_f32(c![i + 4, 4], ab22[i], 0));
        loop4!(i, vst1q_lane_f32(c![i + 4, 5], ab22[i], 1));
        loop4!(i, vst1q_lane_f32(c![i + 4, 6], ab22[i], 2));
        loop4!(i, vst1q_lane_f32(c![i + 4, 7], ab22[i], 3));
    }
}

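// Portable scalar kernel with an 8x4 micro-tile. This kernel reports
// always_masked(), so it is expected to be called with beta == 0 (masked
// output), as the debug assertion below checks.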
#[inline]
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
                               beta: T, c: *mut T, rsc: isize, csc: isize)
{
    const MR: usize = KernelFallback::MR;
    const NR: usize = KernelFallback::NR;
    let mut ab: [[T; NR]; MR] = [[0.; NR]; MR];
    let mut a = a;
    let mut b = b;
    debug_assert_eq!(beta, 0., "Beta must be 0 or is not masked");

    unroll_by!(4 => k, {
        loop8!(i, loop4!(j, ab[i][j] += at(a, i) * at(b, j)));

        a = a.offset(MR as isize);
        b = b.offset(NR as isize);
    });

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    loop4!(j, loop8!(i, *c![i, j] = alpha * ab[i][j]));
}

#[inline(always)]
unsafe fn at(ptr: *const T, i: usize) -> T {
    *ptr.offset(i as isize)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::test::test_a_kernel;

    #[test]
    fn test_kernel_fallback_impl() {
        test_a_kernel::<KernelFallback, _>("kernel");
    }

    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    #[test]
    fn test_loop_m_n() {
        let mut m = [[0; KernelAvx::NR]; KernelAvx::MR];
        loop_m!(i, loop_n!(j, m[i][j] += 1));
        for arr in &m[..] {
            for elt in &arr[..] {
                assert_eq!(*elt, 1);
            }
        }
    }

    #[cfg(any(target_arch="aarch64"))]
    #[cfg(has_aarch64_simd)]
    mod test_kernel_aarch64 {
        use super::test_a_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;

        macro_rules! test_arch_kernels_aarch64 {
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
                $(
                    #[test]
                    fn $name() {
                        if is_aarch64_feature_detected_!($feature_name) {
                            test_a_kernel::<$kernel_ty, _>(stringify!($name));
                        } else {
                            #[cfg(feature = "std")]
                            println!("Skipping, host does not have feature: {:?}", $feature_name);
                        }
                    }
                )*
            }
        }

        test_arch_kernels_aarch64! {
            "neon", neon8x8, KernelNeon
        }
    }

    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    mod test_kernel_x86 {
        use super::test_a_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;

        macro_rules! test_arch_kernels_x86 {
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
                $(
                    #[test]
                    fn $name() {
                        if is_x86_feature_detected_!($feature_name) {
                            test_a_kernel::<$kernel_ty, _>(stringify!($name));
                        } else {
                            #[cfg(feature = "std")]
                            println!("Skipping, host does not have feature: {:?}", $feature_name);
                        }
                    }
                )*
            }
        }

        test_arch_kernels_x86! {
            "fma", fma, KernelFma,
            "avx", avx, KernelAvx,
            "sse2", sse2, KernelSse2
        }

        #[test]
        fn ensure_target_features_tested() {
            let should_ensure_feature = !option_env!("MMTEST_ENSUREFEATURE")
                .unwrap_or("").is_empty();
            if !should_ensure_feature {
                return;
            }
            let feature_name = option_env!("MMTEST_FEATURE")
                .expect("No MMTEST_FEATURE configured!");
            let detected = match feature_name {
                "avx" => is_x86_feature_detected_!("avx"),
                "fma" => is_x86_feature_detected_!("fma"),
                "sse2" => is_x86_feature_detected_!("sse2"),
                _ => false,
            };
            assert!(detected, "Feature {:?} was not detected, so it could not be tested",
                    feature_name);
        }
    }
}