use crate::kernel::GemmKernel;
use crate::kernel::GemmSelect;
#[allow(unused)]
use crate::kernel::{U4, U8};
use crate::archparam;

#[cfg(target_arch="x86")]
use core::arch::x86::*;
#[cfg(target_arch="x86_64")]
use core::arch::x86_64::*;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
use crate::x86::{FusedMulAdd, AvxMulAdd, DMultiplyAdd};

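// Zero-sized kernel selector types; `detect` below picks one of these based
// on the instruction set extensions available at runtime and hands it to the
// caller's `GemmSelect` implementation.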
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFmaAvx2;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelSse2;

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
struct KernelNeon;

struct KernelFallback;

type T = f64;

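/// Inspect the CPU at runtime and hand the fastest available f64 kernel to
/// `selector`. Falls back to the portable scalar kernel when no supported
/// SIMD feature is detected (or on other architectures).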
#[inline]
pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    {
        if is_x86_feature_detected_!("fma") {
            if is_x86_feature_detected_!("avx2") {
                return selector.select(KernelFmaAvx2);
            }
            return selector.select(KernelFma);
        } else if is_x86_feature_detected_!("avx") {
            return selector.select(KernelAvx);
        } else if is_x86_feature_detected_!("sse2") {
            return selector.select(KernelSse2);
        }
    }

    #[cfg(target_arch="aarch64")]
    #[cfg(has_aarch64_simd)]
    {
        if is_aarch64_feature_detected_!("neon") {
            return selector.select(KernelNeon);
        }
    }

    return selector.select(KernelFallback);
}

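// loop_m! visits the MR (= 8) rows of the AVX microkernel's accumulator,
// mirroring KernelAvx's MRTy = U8.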
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m {
    ($i:ident, $e:expr) => { loop8!($i, $e) };
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelAvx {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 32 }

    #[inline(always)]
    fn always_masked() -> bool { false }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFma {
    type Elem = T;

    type MRTy = <KernelAvx as GemmKernel>::MRTy;
    type NRTy = <KernelAvx as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { KernelAvx::align_to() }

    #[inline(always)]
    fn always_masked() -> bool { KernelAvx::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFmaAvx2 {
    type Elem = T;

    type MRTy = <KernelAvx as GemmKernel>::MRTy;
    type NRTy = <KernelAvx as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { KernelAvx::align_to() }

    #[inline(always)]
    fn always_masked() -> bool { KernelAvx::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline]
    unsafe fn pack_mr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        crate::packing::pack_avx2::<Self::MRTy, T>(kc, mc, pack, a, rsa, csa)
    }

    #[inline]
    unsafe fn pack_nr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        crate::packing::pack_avx2::<Self::NRTy, T>(kc, mc, pack, a, rsa, csa)
    }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelSse2 {
    type Elem = T;

    type MRTy = U4;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 16 }

    #[inline(always)]
    fn always_masked() -> bool { true }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
impl GemmKernel for KernelNeon {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 32 }

    #[inline(always)]
    fn always_masked() -> bool { false }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
    }
}

impl GemmKernel for KernelFallback {
    type Elem = T;

    type MRTy = U4;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 0 }

    #[inline(always)]
    fn always_masked() -> bool { true }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="fma")]
unsafe fn kernel_target_fma(k: usize, alpha: T, a: *const T, b: *const T,
                            beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_x86_avx::<FusedMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="avx")]
unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
                            beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_x86_avx::<AvxMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}

#[inline]
#[target_feature(enable="sse2")]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
}

#[inline(always)]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
    where MA: DMultiplyAdd
{
    const MR: usize = KernelAvx::MR;
    const NR: usize = KernelAvx::NR;

    debug_assert_ne!(k, 0);

    let mut ab = [_mm256_setzero_pd(); MR];

    let (mut a, mut b) = (a, b);

    let mut a_0123 = _mm256_load_pd(a);
    let mut a_4567 = _mm256_load_pd(a.add(4));

    let mut b_0123 = _mm256_load_pd(b);

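    // Each iteration multiplies a_0123 / a_4567 with four rotations of b
    // (0123, 1032, 2301, 3210), so every ab[i] accumulates a fixed
    // a-lane/b-lane pairing. The accumulators are untangled into plain rows
    // or columns of the 8x4 block after the k-loop.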
    unroll_by_with_last!(4 => k, is_last, {
        let b_1032 = _mm256_shuffle_pd(b_0123, b_0123, 0b0101);

        let b_3210 = _mm256_permute2f128_pd(b_1032, b_1032, 0b0011);
        let b_2301 = _mm256_shuffle_pd(b_3210, b_3210, 0b0101);

        ab[0] = MA::multiply_add(a_0123, b_0123, ab[0]);
        ab[1] = MA::multiply_add(a_0123, b_1032, ab[1]);
        ab[2] = MA::multiply_add(a_0123, b_2301, ab[2]);
        ab[3] = MA::multiply_add(a_0123, b_3210, ab[3]);

        ab[4] = MA::multiply_add(a_4567, b_0123, ab[4]);
        ab[5] = MA::multiply_add(a_4567, b_1032, ab[5]);
        ab[6] = MA::multiply_add(a_4567, b_2301, ab[6]);
        ab[7] = MA::multiply_add(a_4567, b_3210, ab[7]);

        if !is_last {
            a = a.add(MR);
            b = b.add(NR);

            a_0123 = _mm256_load_pd(a);
            a_4567 = _mm256_load_pd(a.add(4));
            b_0123 = _mm256_load_pd(b);
        }
    });

    if csc == 1 {
        let a0b0_a0b1_a2b2_a2b3 = _mm256_shuffle_pd(ab[0], ab[1], 0b0000);

        let a1b0_a1b1_a3b2_a3b3 = _mm256_shuffle_pd(ab[1], ab[0], 0b1111);

        let a0b2_a0b3_a2b0_a2b1 = _mm256_shuffle_pd(ab[2], ab[3], 0b0000);

        let a1b2_a1b3_a3b0_a3b1 = _mm256_shuffle_pd(ab[3], ab[2], 0b1111);

        let a4b0_a4b1_a6b2_a6b3 = _mm256_shuffle_pd(ab[4], ab[5], 0b0000);
        let a5b0_a5b1_a7b2_a7b3 = _mm256_shuffle_pd(ab[5], ab[4], 0b1111);

        let a4b2_a4b3_a6b0_a6b1 = _mm256_shuffle_pd(ab[6], ab[7], 0b0000);
        let a5b2_a5b3_a7b0_a7b1 = _mm256_shuffle_pd(ab[7], ab[6], 0b1111);

        let a0b0_a0b1_a0b2_a0b3 = _mm256_permute2f128_pd(
            a0b0_a0b1_a2b2_a2b3,
            a0b2_a0b3_a2b0_a2b1,
            0x20
        );
        let a2b0_a2b1_a2b2_a2b3 = _mm256_permute2f128_pd(
            a0b0_a0b1_a2b2_a2b3,
            a0b2_a0b3_a2b0_a2b1,
            0x13
        );
        let a1b0_a1b1_a1b2_a1b3 = _mm256_permute2f128_pd(
            a1b0_a1b1_a3b2_a3b3,
            a1b2_a1b3_a3b0_a3b1,
            0x20
        );
        let a3b0_a3b1_a3b2_a3b3 = _mm256_permute2f128_pd(
            a1b0_a1b1_a3b2_a3b3,
            a1b2_a1b3_a3b0_a3b1,
            0x13
        );

        let a4b0_a4b1_a4b2_a4b3 = _mm256_permute2f128_pd(
            a4b0_a4b1_a6b2_a6b3,
            a4b2_a4b3_a6b0_a6b1,
            0x20
        );

        let a6b0_a6b1_a6b2_a6b3 = _mm256_permute2f128_pd(
            a4b0_a4b1_a6b2_a6b3,
            a4b2_a4b3_a6b0_a6b1,
            0x13
        );

        let a5b0_a5b1_a5b2_a5b3 = _mm256_permute2f128_pd(
            a5b0_a5b1_a7b2_a7b3,
            a5b2_a5b3_a7b0_a7b1,
            0x20
        );

        let a7b0_a7b1_a7b2_a7b3 = _mm256_permute2f128_pd(
            a5b0_a5b1_a7b2_a7b3,
            a5b2_a5b3_a7b0_a7b1,
            0x13
        );

        ab[0] = a0b0_a0b1_a0b2_a0b3;
        ab[1] = a1b0_a1b1_a1b2_a1b3;
        ab[2] = a2b0_a2b1_a2b2_a2b3;
        ab[3] = a3b0_a3b1_a3b2_a3b3;

        ab[4] = a4b0_a4b1_a4b2_a4b3;
        ab[5] = a5b0_a5b1_a5b2_a5b3;
        ab[6] = a6b0_a6b1_a6b2_a6b3;
        ab[7] = a7b0_a7b1_a7b2_a7b3;

    } else {
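        // General (strided) case: rearrange so each register holds one
        // column of a four-row slab; ab[0..4] become columns 0..=3 of rows
        // 0..=3 and ab[4..8] the same columns of rows 4..=7, ready for
        // column-wise loads and stores of C below.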
        let a0b0_a1b0_a2b2_a3b2 = _mm256_blend_pd(ab[0], ab[1], 0b1010);
        let a0b1_a1b1_a2b3_a3b3 = _mm256_blend_pd(ab[1], ab[0], 0b1010);

        let a0b2_a1b2_a2b0_a3b0 = _mm256_blend_pd(ab[2], ab[3], 0b1010);
        let a0b3_a1b3_a2b1_a3b1 = _mm256_blend_pd(ab[3], ab[2], 0b1010);

        let a4b0_a5b0_a6b2_a7b2 = _mm256_blend_pd(ab[4], ab[5], 0b1010);
        let a4b1_a5b1_a6b3_a7b3 = _mm256_blend_pd(ab[5], ab[4], 0b1010);

        let a4b2_a5b2_a6b0_a7b0 = _mm256_blend_pd(ab[6], ab[7], 0b1010);
        let a4b3_a5b3_a6b1_a7b1 = _mm256_blend_pd(ab[7], ab[6], 0b1010);

        let a0b0_a1b0_a2b0_a3b0 = _mm256_permute2f128_pd(
            a0b0_a1b0_a2b2_a3b2,
            a0b2_a1b2_a2b0_a3b0,
            0x30
        );
        let a0b2_a1b2_a2b2_a3b2 = _mm256_permute2f128_pd(
            a0b0_a1b0_a2b2_a3b2,
            a0b2_a1b2_a2b0_a3b0,
            0x12,
        );
        let a0b1_a1b1_a2b1_a3b1 = _mm256_permute2f128_pd(
            a0b1_a1b1_a2b3_a3b3,
            a0b3_a1b3_a2b1_a3b1,
            0x30
        );
        let a0b3_a1b3_a2b3_a3b3 = _mm256_permute2f128_pd(
            a0b1_a1b1_a2b3_a3b3,
            a0b3_a1b3_a2b1_a3b1,
            0x12
        );

        let a4b0_a5b0_a6b0_a7b0 = _mm256_permute2f128_pd(
            a4b0_a5b0_a6b2_a7b2,
            a4b2_a5b2_a6b0_a7b0,
            0x30
        );
        let a4b2_a5b2_a6b2_a7b2 = _mm256_permute2f128_pd(
            a4b0_a5b0_a6b2_a7b2,
            a4b2_a5b2_a6b0_a7b0,
            0x12,
        );
        let a4b1_a5b1_a6b1_a7b1 = _mm256_permute2f128_pd(
            a4b1_a5b1_a6b3_a7b3,
            a4b3_a5b3_a6b1_a7b1,
            0x30
        );
        let a4b3_a5b3_a6b3_a7b3 = _mm256_permute2f128_pd(
            a4b1_a5b1_a6b3_a7b3,
            a4b3_a5b3_a6b1_a7b1,
            0x12
        );

        ab[0] = a0b0_a1b0_a2b0_a3b0;
        ab[1] = a0b1_a1b1_a2b1_a3b1;
        ab[2] = a0b2_a1b2_a2b2_a3b2;
        ab[3] = a0b3_a1b3_a2b3_a3b3;

        ab[4] = a4b0_a5b0_a6b0_a7b0;
        ab[5] = a4b1_a5b1_a6b1_a7b1;
        ab[6] = a4b2_a5b2_a6b2_a7b2;
        ab[7] = a4b3_a5b3_a6b3_a7b3;
    }

    let alphav = _mm256_broadcast_sd(&alpha);
    if !MA::IS_FUSED {
        loop_m!(i, ab[i] = _mm256_mul_pd(alphav, ab[i]));
    }

    macro_rules! c {
        ($i:expr, $j:expr) =>
            (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    let mut cv = [_mm256_setzero_pd(); MR];

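    // Read C only when beta != 0: with beta == 0 the contents of C may be
    // uninitialized and must not be loaded.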
    if beta != 0. {
        if rsc == 1 {
            loop4!(i, cv[i] = _mm256_loadu_pd(c![0, i]));
            loop4!(i, cv[i + 4] = _mm256_loadu_pd(c![4, i]));
        } else if csc == 1 {
            loop4!(i, cv[i] = _mm256_loadu_pd(c![i, 0]));
            loop4!(i, cv[i + 4] = _mm256_loadu_pd(c![i + 4, 0]));
        } else {
            loop4!(i, cv[i] = _mm256_setr_pd(
                *c![0, i],
                *c![1, i],
                *c![2, i],
                *c![3, i]
            ));
            loop4!(i, cv[i + 4] = _mm256_setr_pd(
                *c![4, i],
                *c![5, i],
                *c![6, i],
                *c![7, i]
            ));
        }
        let beta_v = _mm256_broadcast_sd(&beta);
        loop_m!(i, cv[i] = _mm256_mul_pd(cv[i], beta_v));
    }

    if !MA::IS_FUSED {
        loop_m!(i, cv[i] = _mm256_add_pd(cv[i], ab[i]));
    } else {
        loop_m!(i, cv[i] = MA::multiply_add(alphav, ab[i], cv[i]));
    }

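    // Write the block back to C: contiguous columns (rsc == 1), contiguous
    // rows (csc == 1), or a fully strided scatter through 128-bit halves.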
    if rsc == 1 {
        loop4!(i, _mm256_storeu_pd(c![0, i], cv[i]));
        loop4!(i, _mm256_storeu_pd(c![4, i], cv[i + 4]));
    } else if csc == 1 {
        loop4!(i, _mm256_storeu_pd(c![i, 0], cv[i]));
        loop4!(i, _mm256_storeu_pd(c![i + 4, 0], cv[i + 4]));
    } else {
        loop4!(i, {
            let c_lo: __m128d = _mm256_extractf128_pd(cv[i], 0);
            let c_hi: __m128d = _mm256_extractf128_pd(cv[i], 1);

            _mm_storel_pd(c![0, i], c_lo);
            _mm_storeh_pd(c![1, i], c_lo);
            _mm_storel_pd(c![2, i], c_hi);
            _mm_storeh_pd(c![3, i], c_hi);

            let c_lo: __m128d = _mm256_extractf128_pd(cv[i + 4], 0);
            let c_hi: __m128d = _mm256_extractf128_pd(cv[i + 4], 1);

            _mm_storel_pd(c![4, i], c_lo);
            _mm_storeh_pd(c![5, i], c_lo);
            _mm_storel_pd(c![6, i], c_hi);
            _mm_storeh_pd(c![7, i], c_hi);
        });
    }
}

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
#[target_feature(enable="neon")]
unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
{
    use core::arch::aarch64::*;
    const MR: usize = KernelNeon::MR;
    const NR: usize = KernelNeon::NR;

    let (mut a, mut b) = (a, b);

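    // 8x4 block of C held as four groups of four f64x2 registers:
    // ab11 = rows 0..=3 x cols 0..=1, ab12 = rows 0..=3 x cols 2..=3,
    // ab21 = rows 4..=7 x cols 0..=1, ab22 = rows 4..=7 x cols 2..=3.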
    let mut ab11 = [vmovq_n_f64(0.); 4];
    let mut ab12 = [vmovq_n_f64(0.); 4];
    let mut ab21 = [vmovq_n_f64(0.); 4];
    let mut ab22 = [vmovq_n_f64(0.); 4];

    macro_rules! ab_ij_equals_ai_bj_12 {
        ($dest:ident, $av:expr, $bv:expr) => {
            $dest[0] = vfmaq_laneq_f64($dest[0], $bv, $av, 0);
            $dest[1] = vfmaq_laneq_f64($dest[1], $bv, $av, 1);
        }
    }

    macro_rules! ab_ij_equals_ai_bj_23 {
        ($dest:ident, $av:expr, $bv:expr) => {
            $dest[2] = vfmaq_laneq_f64($dest[2], $bv, $av, 0);
            $dest[3] = vfmaq_laneq_f64($dest[3], $bv, $av, 1);
        }
    }

    for _ in 0..k {
        let b1 = vld1q_f64(b);
        let b2 = vld1q_f64(b.add(2));

        let a1 = vld1q_f64(a);
        let a2 = vld1q_f64(a.add(2));

        ab_ij_equals_ai_bj_12!(ab11, a1, b1);
        ab_ij_equals_ai_bj_23!(ab11, a2, b1);
        ab_ij_equals_ai_bj_12!(ab12, a1, b2);
        ab_ij_equals_ai_bj_23!(ab12, a2, b2);

        let a3 = vld1q_f64(a.add(4));
        let a4 = vld1q_f64(a.add(6));

        ab_ij_equals_ai_bj_12!(ab21, a3, b1);
        ab_ij_equals_ai_bj_23!(ab21, a4, b1);
        ab_ij_equals_ai_bj_12!(ab22, a3, b2);
        ab_ij_equals_ai_bj_23!(ab22, a4, b2);

        a = a.add(MR);
        b = b.add(NR);
    }

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    loop4!(i, ab11[i] = vmulq_n_f64(ab11[i], alpha));
    loop4!(i, ab12[i] = vmulq_n_f64(ab12[i], alpha));
    loop4!(i, ab21[i] = vmulq_n_f64(ab21[i], alpha));
    loop4!(i, ab22[i] = vmulq_n_f64(ab22[i], alpha));

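    // Gather two strided C elements (one row, two adjacent columns) into a
    // single f64x2 register; used when columns of C are not contiguous.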
    macro_rules! loadq_from_pointers {
        ($p0:expr, $p1:expr) => (
            {
                let v = vld1q_dup_f64($p0);
                let v = vld1q_lane_f64($p1, v, 1);
                v
            }
        );
    }

    if beta != 0. {
        let mut c11 = [vmovq_n_f64(0.); 4];
        let mut c12 = [vmovq_n_f64(0.); 4];
        let mut c21 = [vmovq_n_f64(0.); 4];
        let mut c22 = [vmovq_n_f64(0.); 4];

        if csc == 1 {
            loop4!(i, c11[i] = vld1q_f64(c![i + 0, 0]));
            loop4!(i, c12[i] = vld1q_f64(c![i + 0, 2]));
            loop4!(i, c21[i] = vld1q_f64(c![i + 4, 0]));
            loop4!(i, c22[i] = vld1q_f64(c![i + 4, 2]));
        } else {
            loop4!(i, c11[i] = loadq_from_pointers!(c![i + 0, 0], c![i + 0, 1]));
            loop4!(i, c12[i] = loadq_from_pointers!(c![i + 0, 2], c![i + 0, 3]));
            loop4!(i, c21[i] = loadq_from_pointers!(c![i + 4, 0], c![i + 4, 1]));
            loop4!(i, c22[i] = loadq_from_pointers!(c![i + 4, 2], c![i + 4, 3]));
        }

        let betav = vmovq_n_f64(beta);

        loop4!(i, ab11[i] = vfmaq_f64(ab11[i], c11[i], betav));
        loop4!(i, ab12[i] = vfmaq_f64(ab12[i], c12[i], betav));
        loop4!(i, ab21[i] = vfmaq_f64(ab21[i], c21[i], betav));
        loop4!(i, ab22[i] = vfmaq_f64(ab22[i], c22[i], betav));
    }

    if csc == 1 {
        loop4!(i, vst1q_f64(c![i + 0, 0], ab11[i]));
        loop4!(i, vst1q_f64(c![i + 0, 2], ab12[i]));
        loop4!(i, vst1q_f64(c![i + 4, 0], ab21[i]));
        loop4!(i, vst1q_f64(c![i + 4, 2], ab22[i]));
    } else {
        loop4!(i, vst1q_lane_f64(c![i + 0, 0], ab11[i], 0));
        loop4!(i, vst1q_lane_f64(c![i + 0, 1], ab11[i], 1));

        loop4!(i, vst1q_lane_f64(c![i + 0, 2], ab12[i], 0));
        loop4!(i, vst1q_lane_f64(c![i + 0, 3], ab12[i], 1));

        loop4!(i, vst1q_lane_f64(c![i + 4, 0], ab21[i], 0));
        loop4!(i, vst1q_lane_f64(c![i + 4, 1], ab21[i], 1));

        loop4!(i, vst1q_lane_f64(c![i + 4, 2], ab22[i], 0));
        loop4!(i, vst1q_lane_f64(c![i + 4, 3], ab22[i], 1));
    }
}

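// Scalar 4x4 fallback kernel, also used as the body of the SSE2 kernel.
// Callers always invoke it with beta == 0 (see the debug_assert below), so it
// only has to write alpha * A * B into C.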
#[inline]
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
                               beta: T, c: *mut T, rsc: isize, csc: isize)
{
    const MR: usize = KernelFallback::MR;
    const NR: usize = KernelFallback::NR;
    let mut ab: [[T; NR]; MR] = [[0.; NR]; MR];
    let mut a = a;
    let mut b = b;
    debug_assert_eq!(beta, 0., "Beta must be 0 or is not masked");

    unroll_by!(4 => k, {
        loop4!(i, loop4!(j, ab[i][j] += at(a, i) * at(b, j)));

        a = a.offset(MR as isize);
        b = b.offset(NR as isize);
    });

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    loop4!(j, loop4!(i, *c![i, j] = alpha * ab[i][j]));
}

#[inline(always)]
unsafe fn at(ptr: *const T, i: usize) -> T {
    *ptr.offset(i as isize)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::test::test_a_kernel;

    #[test]
    fn test_kernel_fallback_impl() {
        test_a_kernel::<KernelFallback, _>("kernel");
    }

    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    #[test]
    fn test_loop_m_n() {
        let mut m = [[0; 4]; KernelAvx::MR];
        loop_m!(i, loop4!(j, m[i][j] += 1));
        for arr in &m[..] {
            for elt in &arr[..] {
                assert_eq!(*elt, 1);
            }
        }
    }

    #[cfg(any(target_arch="aarch64"))]
    #[cfg(has_aarch64_simd)]
    mod test_kernel_aarch64 {
        use super::test_a_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;

        macro_rules! test_arch_kernels_aarch64 {
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
                $(
                    #[test]
                    fn $name() {
                        if is_aarch64_feature_detected_!($feature_name) {
                            test_a_kernel::<$kernel_ty, _>(stringify!($name));
                        } else {
                            #[cfg(feature = "std")]
                            println!("Skipping, host does not have feature: {:?}", $feature_name);
                        }
                    }
                )*
            }
        }

        test_arch_kernels_aarch64! {
            "neon", neon, KernelNeon
        }
    }

    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    mod test_kernel_x86 {
        use super::test_a_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;
        macro_rules! test_arch_kernels_x86 {
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
                $(
                    #[test]
                    fn $name() {
                        if is_x86_feature_detected_!($feature_name) {
                            test_a_kernel::<$kernel_ty, _>(stringify!($name));
                        } else {
                            #[cfg(feature = "std")]
                            println!("Skipping, host does not have feature: {:?}", $feature_name);
                        }
                    }
                )*
            }
        }

        test_arch_kernels_x86! {
            "fma", fma, KernelFma,
            "avx", avx, KernelAvx,
            "sse2", sse2, KernelSse2
        }
    }
}