matrixmultiply/dgemm_kernel.rs

// Copyright 2016 - 2023 Ulrik Sverdrup "bluss"
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::kernel::GemmKernel;
use crate::kernel::GemmSelect;
#[allow(unused)]
use crate::kernel::{U4, U8};
use crate::archparam;

#[cfg(target_arch="x86")]
use core::arch::x86::*;
#[cfg(target_arch="x86_64")]
use core::arch::x86_64::*;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
use crate::x86::{FusedMulAdd, AvxMulAdd, DMultiplyAdd};

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFmaAvx2;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelSse2;

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
struct KernelNeon;

struct KernelFallback;

type T = f64;

/// Detect which implementation to use and select it using the selector's
/// .select(Kernel) method.
///
/// This function is called one or more times during a whole program's
/// execution; it may be called once per gemm kernel invocation or less often.
#[inline]
pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
    // dispatch to specific compiled versions
    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    {
        if is_x86_feature_detected_!("fma") {
            if is_x86_feature_detected_!("avx2") {
                return selector.select(KernelFmaAvx2);
            }
            return selector.select(KernelFma);
        } else if is_x86_feature_detected_!("avx") {
            return selector.select(KernelAvx);
        } else if is_x86_feature_detected_!("sse2") {
            return selector.select(KernelSse2);
        }
    }

    #[cfg(target_arch="aarch64")]
    #[cfg(has_aarch64_simd)]
    {
        if is_aarch64_feature_detected_!("neon") {
            return selector.select(KernelNeon);
        }
    }

    return selector.select(KernelFallback);
}
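
// Illustration only, added for clarity: `detect` never returns a kernel value;
// it hands a statically known kernel type to the selector. A `GemmSelect`
// implementor (see crate::kernel for the authoritative trait) looks roughly
// like the following sketch, with `MySelector` being a hypothetical name:
//
//     impl GemmSelect<f64> for MySelector {
//         fn select<K>(self, _kernel: K)
//             where K: GemmKernel<Elem = f64>
//         {
//             // the packing + gemm driver loop gets monomorphized for K here
//         }
//     }
//
// so each `selector.select(...)` call above dispatches to a separately
// compiled code path specialized for that kernel.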


#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m {
    ($i:ident, $e:expr) => { loop8!($i, $e) };
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelAvx {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 32 }

    #[inline(always)]
    fn always_masked() -> bool { false }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFma {
    type Elem = T;

    type MRTy = <KernelAvx as GemmKernel>::MRTy;
    type NRTy = <KernelAvx as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { KernelAvx::align_to() }

    #[inline(always)]
    fn always_masked() -> bool { KernelAvx::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFmaAvx2 {
    type Elem = T;

    type MRTy = <KernelAvx as GemmKernel>::MRTy;
    type NRTy = <KernelAvx as GemmKernel>::NRTy;

    #[inline(always)]
    fn align_to() -> usize { KernelAvx::align_to() }

    #[inline(always)]
    fn always_masked() -> bool { KernelAvx::always_masked() }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline]
    unsafe fn pack_mr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        // safety: Avx2 is enabled
        crate::packing::pack_avx2::<Self::MRTy, T>(kc, mc, pack, a, rsa, csa)
    }

    #[inline]
    unsafe fn pack_nr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        // safety: Avx2 is enabled
        crate::packing::pack_avx2::<Self::NRTy, T>(kc, mc, pack, a, rsa, csa)
    }


    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelSse2 {
    type Elem = T;

    type MRTy = U4;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 16 }

    #[inline(always)]
    fn always_masked() -> bool { true }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc)
    }
}

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
impl GemmKernel for KernelNeon {
    type Elem = T;

    type MRTy = U8;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 32 }

    #[inline(always)]
    fn always_masked() -> bool { false }

    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T, rsc: isize, csc: isize) {
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
    }
}

impl GemmKernel for KernelFallback {
    type Elem = T;

    type MRTy = U4;
    type NRTy = U4;

    #[inline(always)]
    fn align_to() -> usize { 0 }

    #[inline(always)]
    fn always_masked() -> bool { true }

    #[inline(always)]
    fn nc() -> usize { archparam::D_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::D_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::D_MC }

    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize)
    {
        kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
    }
}

// no inline for unmasked kernels
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="fma")]
unsafe fn kernel_target_fma(k: usize, alpha: T, a: *const T, b: *const T,
                            beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_x86_avx::<FusedMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}

// no inline for unmasked kernels
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="avx")]
unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
                            beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_x86_avx::<AvxMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}

#[inline]
#[target_feature(enable="sse2")]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T,
                                 beta: T, c: *mut T, rsc: isize, csc: isize)
{
    kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
}

#[inline(always)]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
    where MA: DMultiplyAdd
{
    const MR: usize = KernelAvx::MR;
    const NR: usize = KernelAvx::NR;

    debug_assert_ne!(k, 0);

    let mut ab = [_mm256_setzero_pd(); MR];

    let (mut a, mut b) = (a, b);

    // With MR=8, we load sets of 4 doubles from a
    let mut a_0123 = _mm256_load_pd(a);
    let mut a_4567 = _mm256_load_pd(a.add(4));

    // With NR=4, we load 4 doubles from b
    let mut b_0123 = _mm256_load_pd(b);

    unroll_by_with_last!(4 => k, is_last, {
        // We need to multiply each element of b with each element of a_0123
        // and a_4567. To do so, we need to generate all possible permutations
        // for the doubles in b, but without two permutations having the
        // same double at the same spot.
        //
        // So, if we are given the permutations (indices of the doubles
        // in the packed 4-vector):
        //
        // 0 1 2 3
        //
        // Then another valid permutation has to shuffle all elements
        // around without a single element remaining at the same index
        // it was before.
        //
        // A possible set of valid combinations is then:
        //
        // 0 1 2 3 (the original)
        // 1 0 3 2 (chosen because of _mm256_shuffle_pd)
        // 3 2 1 0 (chosen because of _mm256_permute2f128_pd)
        // 2 3 0 1 (chosen because of _mm256_shuffle_pd)
        let b_1032 = _mm256_shuffle_pd(b_0123, b_0123, 0b0101);

        // Both packed 4-vectors are the same, so one could also perform
        // the selection 0b0000_0001 or 0b0010_0001 or 0b0010_0011.
        // The confusing part is that of the lower 4 bits and upper 4 bits
        // only 2 bits are used in each. The same choice could have been
        // encoded in a nibble (4 bits) total, i.e. 0b1100, had the intrinsics
        // been defined differently. The highest bit in each nibble controls
        // zero-ing behaviour though.
        let b_3210 = _mm256_permute2f128_pd(b_1032, b_1032, 0b0011);
        let b_2301 = _mm256_shuffle_pd(b_3210, b_3210, 0b0101);
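        // Concretely, the lane contents after the shuffles above are
        // (lane 0 listed first):
        //
        //   b_0123 = b0 | b1 | b2 | b3
        //   b_1032 = b1 | b0 | b3 | b2   (swapped within each 128-bit half)
        //   b_3210 = b3 | b2 | b1 | b0   (128-bit halves swapped)
        //   b_2301 = b2 | b3 | b0 | b1   (swapped within each half again)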

        // The ideal distribution of a_i b_j pairs in the resulting panel of
        // c in order to have the matching products / sums of products in the
        // right places would look like this after the first iteration:
        //
        // ab_0 || a0 b0 | a0 b1 | a0 b2 | a0 b3
        // ab_1 || a1 b0 | a1 b1 | a1 b2 | a1 b3
        // ab_2 || a2 b0 | a2 b1 | a2 b2 | a2 b3
        // ab_3 || a3 b0 | a3 b1 | a3 b2 | a3 b3
        //      || -----------------------------
        // ab_4 || a4 b0 | a4 b1 | a4 b2 | a4 b3
        // ab_5 || a5 b0 | a5 b1 | a5 b2 | a5 b3
        // ab_6 || a6 b0 | a6 b1 | a6 b2 | a6 b3
        // ab_7 || a7 b0 | a7 b1 | a7 b2 | a7 b3
        //
        // As this is not possible / would require too many extra variables
        // and thus operations, we get the following configuration, and thus
        // have to be smart about putting the correct values into their
        // respective places at the end.
        //
        // ab_0 || a0 b0 | a1 b1 | a2 b2 | a3 b3
        // ab_1 || a0 b1 | a1 b0 | a2 b3 | a3 b2
        // ab_2 || a0 b2 | a1 b3 | a2 b0 | a3 b1
        // ab_3 || a0 b3 | a1 b2 | a2 b1 | a3 b0
        //      || -----------------------------
        // ab_4 || a4 b0 | a5 b1 | a6 b2 | a7 b3
        // ab_5 || a4 b1 | a5 b0 | a6 b3 | a7 b2
        // ab_6 || a4 b2 | a5 b3 | a6 b0 | a7 b1
        // ab_7 || a4 b3 | a5 b2 | a6 b1 | a7 b0

        // Add and multiply in one go
        ab[0] = MA::multiply_add(a_0123, b_0123, ab[0]);
        ab[1] = MA::multiply_add(a_0123, b_1032, ab[1]);
        ab[2] = MA::multiply_add(a_0123, b_2301, ab[2]);
        ab[3] = MA::multiply_add(a_0123, b_3210, ab[3]);

        ab[4] = MA::multiply_add(a_4567, b_0123, ab[4]);
        ab[5] = MA::multiply_add(a_4567, b_1032, ab[5]);
        ab[6] = MA::multiply_add(a_4567, b_2301, ab[6]);
        ab[7] = MA::multiply_add(a_4567, b_3210, ab[7]);

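        // Load the next a and b values only when this is not the last
        // iteration, so the kernel never reads past the end of the packed
        // panels.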
        if !is_last {
            a = a.add(MR);
            b = b.add(NR);

            a_0123 = _mm256_load_pd(a);
            a_4567 = _mm256_load_pd(a.add(4));
            b_0123 = _mm256_load_pd(b);
        }
    });

    // Our products/sums are currently stored according to the
    // table below. Each row corresponds to one packed simd
    // 4-vector.
    //
    // ab_0 || a0 b0 | a1 b1 | a2 b2 | a3 b3
    // ab_1 || a0 b1 | a1 b0 | a2 b3 | a3 b2
    // ab_2 || a0 b2 | a1 b3 | a2 b0 | a3 b1
    // ab_3 || a0 b3 | a1 b2 | a2 b1 | a3 b0
    //      || -----------------------------
    // ab_4 || a4 b0 | a5 b1 | a6 b2 | a7 b3
    // ab_5 || a4 b1 | a5 b0 | a6 b3 | a7 b2
    // ab_6 || a4 b2 | a5 b3 | a6 b0 | a7 b1
    // ab_7 || a4 b3 | a5 b2 | a6 b1 | a7 b0
    //
    // This is the final result, where each product is stored
    // in its proper location.
    //
    //      || a0 b0 | a0 b1 | a0 b2 | a0 b3
    //      || a1 b0 | a1 b1 | a1 b2 | a1 b3
    //      || a2 b0 | a2 b1 | a2 b2 | a2 b3
    //      || a3 b0 | a3 b1 | a3 b2 | a3 b3
    //      || -----------------------------
    //      || a4 b0 | a4 b1 | a4 b2 | a4 b3
    //      || a5 b0 | a5 b1 | a5 b2 | a5 b3
    //      || a6 b0 | a6 b1 | a6 b2 | a6 b3
    //      || a7 b0 | a7 b1 | a7 b2 | a7 b3
    //
    // Given the simd intrinsics available through avx, we have two
    // ways of achieving this format. By either:
    //
    // a) Creating packed 4-vectors of columns, or
    // b) creating packed 4-vectors of rows.
    //
    // ** We will use option a) because it has slightly cheaper throughput
    // characteristics (see below).
    //
    // # a) Creating packed 4-vectors of columns
    //
    // To create packed 4-vectors of columns, we make use of
    // _mm256_blend_pd operations, followed by _mm256_permute2f128_pd.
    //
    // The first operation has latency 1 (all architectures), and 0.33
    // throughput (Skylake, Broadwell, Haswell), or 0.5 (Ivy Bridge).
    //
    // The second operation has latency 3 (on Skylake, Broadwell, Haswell),
    // or latency 2 (on Ivy Bridge), and throughput 1 (all architectures).
    //
    // We start by applying _mm256_blend_pd on adjacent rows:
    //
    // Step 0.0
    // a0 b0 | a1 b1 | a2 b2 | a3 b3
    // a0 b1 | a1 b0 | a2 b3 | a3 b2
    // => _mm256_blend_pd with 0b1010
    // a0 b0 | a1 b0 | a2 b2 | a3 b2 (only columns 0 and 2)
    //
    // Step 0.1
    // a0 b1 | a1 b0 | a2 b3 | a3 b2 (flipped the order)
    // a0 b0 | a1 b1 | a2 b2 | a3 b3
    // => _mm256_blend_pd with 0b1010
    // a0 b1 | a1 b1 | a2 b3 | a3 b3 (only columns 1 and 3)
    //
    // Step 0.2
    // a0 b2 | a1 b3 | a2 b0 | a3 b1
    // a0 b3 | a1 b2 | a2 b1 | a3 b0
    // => _mm256_blend_pd with 0b1010
    // a0 b2 | a1 b2 | a2 b0 | a3 b0 (only columns 0 and 2)
    //
    // Step 0.3
    // a0 b3 | a1 b2 | a2 b1 | a3 b0 (flipped the order)
    // a0 b2 | a1 b3 | a2 b0 | a3 b1
    // => _mm256_blend_pd with 0b1010
    // a0 b3 | a1 b3 | a2 b1 | a3 b1 (only columns 1 and 3)
    //
    // Step 1.0 (combining steps 0.0 and 0.2)
    //
    // a0 b0 | a1 b0 | a2 b2 | a3 b2
    // a0 b2 | a1 b2 | a2 b0 | a3 b0
    // => _mm256_permute2f128_pd with 0x30 = 0b0011_0000
    // a0 b0 | a1 b0 | a2 b0 | a3 b0
    //
    // Step 1.1 (combining steps 0.0 and 0.2)
    //
    // a0 b0 | a1 b0 | a2 b2 | a3 b2
    // a0 b2 | a1 b2 | a2 b0 | a3 b0
    // => _mm256_permute2f128_pd with 0x12 = 0b0001_0010
    // a0 b2 | a1 b2 | a2 b2 | a3 b2
    //
    // Step 1.2 (combining steps 0.1 and 0.3)
    // a0 b1 | a1 b1 | a2 b3 | a3 b3
    // a0 b3 | a1 b3 | a2 b1 | a3 b1
    // => _mm256_permute2f128_pd with 0x30 = 0b0011_0000
    // a0 b1 | a1 b1 | a2 b1 | a3 b1
    //
    // Step 1.3 (combining steps 0.1 and 0.3)
    // a0 b1 | a1 b1 | a2 b3 | a3 b3
    // a0 b3 | a1 b3 | a2 b1 | a3 b1
    // => _mm256_permute2f128_pd with 0x12 = 0b0001_0010
    // a0 b3 | a1 b3 | a2 b3 | a3 b3
    //
    // # b) Creating packed 4-vectors of rows
    //
    // To create packed 4-vectors of rows, we make use of
    // _mm256_shuffle_pd operations followed by _mm256_permute2f128_pd.
    //
    // The first operation has latency 1 and throughput 1 (on Skylake,
    // Broadwell, Haswell, and Ivy Bridge).
    //
    // The second operation has latency 3 (on Skylake, Broadwell, Haswell),
    // or latency 2 (on Ivy Bridge), and throughput 1 (all architectures).
    //
    // To achieve this, we can execute a _mm256_shuffle_pd on
    // rows 0 and 1 stored in ab_0 and ab_1:
    //
    // Step 0.0
    // a0 b0 | a1 b1 | a2 b2 | a3 b3
    // a0 b1 | a1 b0 | a2 b3 | a3 b2
    // => _mm256_shuffle_pd with 0000
    // a0 b0 | a0 b1 | a2 b2 | a2 b3 (only rows 0 and 2)
    //
    // Step 0.1
    // a0 b1 | a1 b0 | a2 b3 | a3 b2 (flipped the order)
    // a0 b0 | a1 b1 | a2 b2 | a3 b3
    // => _mm256_shuffle_pd with 1111
    // a1 b0 | a1 b1 | a3 b2 | a3 b3 (only rows 1 and 3)
    //
    // Next, we perform the same operation on the other two rows:
    //
    // Step 0.2
    // a0 b2 | a1 b3 | a2 b0 | a3 b1
    // a0 b3 | a1 b2 | a2 b1 | a3 b0
    // => _mm256_shuffle_pd with 0000
    // a0 b2 | a0 b3 | a2 b0 | a2 b1 (only rows 0 and 2)
    //
    // Step 0.3
    // a0 b3 | a1 b2 | a2 b1 | a3 b0
    // a0 b2 | a1 b3 | a2 b0 | a3 b1
    // => _mm256_shuffle_pd with 1111
    // a1 b2 | a1 b3 | a3 b0 | a3 b1 (only rows 1 and 3)
    //
    // Next, we can apply _mm256_permute2f128_pd to select the
    // correct columns on the matching rows:
    //
    // Step 1.0 (combining Steps 0.0 and 0.2):
    // a0 b0 | a0 b1 | a2 b2 | a2 b3
    // a0 b2 | a0 b3 | a2 b0 | a2 b1
    // => _mm256_permute2f128_pd with 0x20 = 0b0010_0000
    // a0 b0 | a0 b1 | a0 b2 | a0 b3
    //
    // Step 1.1 (combining Steps 0.0 and 0.2):
    // a0 b0 | a0 b1 | a2 b2 | a2 b3
    // a0 b2 | a0 b3 | a2 b0 | a2 b1
    // => _mm256_permute2f128_pd with 0x13 = 0b0001_0011
    // a2 b0 | a2 b1 | a2 b2 | a2 b3
    //
    // Step 1.2 (combining Steps 0.1 and 0.3):
    // a1 b0 | a1 b1 | a3 b2 | a3 b3
    // a1 b2 | a1 b3 | a3 b0 | a3 b1
    // => _mm256_permute2f128_pd with 0x20 = 0b0010_0000
    // a1 b0 | a1 b1 | a1 b2 | a1 b3
    //
    // Step 1.3 (combining Steps 0.1 and 0.3):
    // a1 b0 | a1 b1 | a3 b2 | a3 b3
    // a1 b2 | a1 b3 | a3 b0 | a3 b1
    // => _mm256_permute2f128_pd with 0x13 = 0b0001_0011
    // a3 b0 | a3 b1 | a3 b2 | a3 b3

    // We use scheme a) as the default case, i.e. if c is column-major, rsc==1, or if
    // c is of general form. Row-major c matrices, csc==1, are treated using scheme b).
    if csc == 1 {
        // Scheme b), step 0.0
        // a0 b0 | a1 b1 | a2 b2 | a3 b3
        // a0 b1 | a1 b0 | a2 b3 | a3 b2
        let a0b0_a0b1_a2b2_a2b3 = _mm256_shuffle_pd(ab[0], ab[1], 0b0000);

        // Scheme b), step 0.1
        // a0 b1 | a1 b0 | a2 b3 | a3 b2 (flipped the order)
        // a0 b0 | a1 b1 | a2 b2 | a3 b3
        let a1b0_a1b1_a3b2_a3b3 = _mm256_shuffle_pd(ab[1], ab[0], 0b1111);

        // Scheme b), step 0.2
        // a0 b2 | a1 b3 | a2 b0 | a3 b1
        // a0 b3 | a1 b2 | a2 b1 | a3 b0
        let a0b2_a0b3_a2b0_a2b1 = _mm256_shuffle_pd(ab[2], ab[3], 0b0000);

        // Scheme b), step 0.3
        // a0 b3 | a1 b2 | a2 b1 | a3 b0 (flipped the order)
        // a0 b2 | a1 b3 | a2 b0 | a3 b1
        let a1b2_a1b3_a3b0_a3b1 = _mm256_shuffle_pd(ab[3], ab[2], 0b1111);

        let a4b0_a4b1_a6b2_a6b3 = _mm256_shuffle_pd(ab[4], ab[5], 0b0000);
        let a5b0_a5b1_a7b2_a7b3 = _mm256_shuffle_pd(ab[5], ab[4], 0b1111);

        let a4b2_a4b3_a6b0_a6b1 = _mm256_shuffle_pd(ab[6], ab[7], 0b0000);
        let a5b2_a5b3_a7b0_a7b1 = _mm256_shuffle_pd(ab[7], ab[6], 0b1111);

        // Next, we can apply _mm256_permute2f128_pd to select the
        // correct columns on the matching rows:
        //
        // Step 1.0 (combining Steps 0.0 and 0.2):
        // a0 b0 | a0 b1 | a2 b2 | a2 b3
        // a0 b2 | a0 b3 | a2 b0 | a2 b1
        // => _mm256_permute2f128_pd with 0x20 = 0b0010_0000
        // a0 b0 | a0 b1 | a0 b2 | a0 b3
        //
        // Step 1.1 (combining Steps 0.0 and 0.2):
        // a0 b0 | a0 b1 | a2 b2 | a2 b3
        // a0 b2 | a0 b3 | a2 b0 | a2 b1
        // => _mm256_permute2f128_pd with 0x13 = 0b0001_0011
        // a2 b0 | a2 b1 | a2 b2 | a2 b3
        //
        // Step 1.2 (combining Steps 0.1 and 0.3):
        // a1 b0 | a1 b1 | a3 b2 | a3 b3
        // a1 b2 | a1 b3 | a3 b0 | a3 b1
        // => _mm256_permute2f128_pd with 0x20 = 0b0010_0000
        // a1 b0 | a1 b1 | a1 b2 | a1 b3
        //
        // Step 1.3 (combining Steps 0.1 and 0.3):
        // a1 b0 | a1 b1 | a3 b2 | a3 b3
        // a1 b2 | a1 b3 | a3 b0 | a3 b1
        // => _mm256_permute2f128_pd with 0x13 = 0b0001_0011
        // a3 b0 | a3 b1 | a3 b2 | a3 b3

        // Scheme b), step 1.0
        let a0b0_a0b1_a0b2_a0b3 = _mm256_permute2f128_pd(
            a0b0_a0b1_a2b2_a2b3,
            a0b2_a0b3_a2b0_a2b1,
            0x20
        );
        // Scheme b), step 1.1
        let a2b0_a2b1_a2b2_a2b3 = _mm256_permute2f128_pd(
            a0b0_a0b1_a2b2_a2b3,
            a0b2_a0b3_a2b0_a2b1,
            0x13
        );
        // Scheme b) step 1.2
        let a1b0_a1b1_a1b2_a1b3 = _mm256_permute2f128_pd(
            a1b0_a1b1_a3b2_a3b3,
            a1b2_a1b3_a3b0_a3b1,
            0x20
        );
        // Scheme b) step 1.3
        let a3b0_a3b1_a3b2_a3b3 = _mm256_permute2f128_pd(
            a1b0_a1b1_a3b2_a3b3,
            a1b2_a1b3_a3b0_a3b1,
            0x13
        );

        // As above, but for ab[4..7]
        let a4b0_a4b1_a4b2_a4b3 = _mm256_permute2f128_pd(
            a4b0_a4b1_a6b2_a6b3,
            a4b2_a4b3_a6b0_a6b1,
            0x20
        );

        let a6b0_a6b1_a6b2_a6b3 = _mm256_permute2f128_pd(
            a4b0_a4b1_a6b2_a6b3,
            a4b2_a4b3_a6b0_a6b1,
            0x13
        );

        let a5b0_a5b1_a5b2_a5b3 = _mm256_permute2f128_pd(
            a5b0_a5b1_a7b2_a7b3,
            a5b2_a5b3_a7b0_a7b1,
            0x20
        );

        let a7b0_a7b1_a7b2_a7b3 = _mm256_permute2f128_pd(
            a5b0_a5b1_a7b2_a7b3,
            a5b2_a5b3_a7b0_a7b1,
            0x13
        );

        ab[0] = a0b0_a0b1_a0b2_a0b3;
        ab[1] = a1b0_a1b1_a1b2_a1b3;
        ab[2] = a2b0_a2b1_a2b2_a2b3;
        ab[3] = a3b0_a3b1_a3b2_a3b3;

        ab[4] = a4b0_a4b1_a4b2_a4b3;
        ab[5] = a5b0_a5b1_a5b2_a5b3;
        ab[6] = a6b0_a6b1_a6b2_a6b3;
        ab[7] = a7b0_a7b1_a7b2_a7b3;

    //  rsc == 1 and general matrix orders
    } else {
        // Scheme a), step 0.0
        // ab[0] = a0 b0 | a1 b1 | a2 b2 | a3 b3
        // ab[1] = a0 b1 | a1 b0 | a2 b3 | a3 b2
        let a0b0_a1b0_a2b2_a3b2 = _mm256_blend_pd(ab[0], ab[1], 0b1010);
        // Scheme a), step 0.1
        let a0b1_a1b1_a2b3_a3b3 = _mm256_blend_pd(ab[1], ab[0], 0b1010);

        // Scheme a), step 0.2
        // ab[2] = a0 b2 | a1 b3 | a2 b0 | a3 b1
        // ab[3] = a0 b3 | a1 b2 | a2 b1 | a3 b0
        let a0b2_a1b2_a2b0_a3b0 = _mm256_blend_pd(ab[2], ab[3], 0b1010);
        // Scheme a), step 0.3
        let a0b3_a1b3_a2b1_a3b1 = _mm256_blend_pd(ab[3], ab[2], 0b1010);

        // ab[4] = a4 b0 | a5 b1 | a6 b2 | a7 b3
        // ab[5] = a4 b1 | a5 b0 | a6 b3 | a7 b2
        let a4b0_a5b0_a6b2_a7b2 = _mm256_blend_pd(ab[4], ab[5], 0b1010);
        let a4b1_a5b1_a6b3_a7b3 = _mm256_blend_pd(ab[5], ab[4], 0b1010);

        // ab[6] = a4 b2 | a5 b3 | a6 b0 | a7 b1
        // ab[7] = a4 b3 | a5 b2 | a6 b1 | a7 b0
        let a4b2_a5b2_a6b0_a7b0 = _mm256_blend_pd(ab[6], ab[7], 0b1010);
        let a4b3_a5b3_a6b1_a7b1 = _mm256_blend_pd(ab[7], ab[6], 0b1010);

        // Scheme a), step 1.0
        let a0b0_a1b0_a2b0_a3b0 = _mm256_permute2f128_pd(
            a0b0_a1b0_a2b2_a3b2,
            a0b2_a1b2_a2b0_a3b0,
            0x30
        );
        // Scheme a), step 1.1
        let a0b2_a1b2_a2b2_a3b2 = _mm256_permute2f128_pd(
            a0b0_a1b0_a2b2_a3b2,
            a0b2_a1b2_a2b0_a3b0,
            0x12,
        );
        // Scheme a) step 1.2
        let a0b1_a1b1_a2b1_a3b1 = _mm256_permute2f128_pd(
            a0b1_a1b1_a2b3_a3b3,
            a0b3_a1b3_a2b1_a3b1,
            0x30
        );
        // Scheme a) step 1.3
        let a0b3_a1b3_a2b3_a3b3 = _mm256_permute2f128_pd(
            a0b1_a1b1_a2b3_a3b3,
            a0b3_a1b3_a2b1_a3b1,
            0x12
        );

        // As above, but for ab[4..7]
        let a4b0_a5b0_a6b0_a7b0 = _mm256_permute2f128_pd(
            a4b0_a5b0_a6b2_a7b2,
            a4b2_a5b2_a6b0_a7b0,
            0x30
        );
        let a4b2_a5b2_a6b2_a7b2 = _mm256_permute2f128_pd(
            a4b0_a5b0_a6b2_a7b2,
            a4b2_a5b2_a6b0_a7b0,
            0x12,
        );
        let a4b1_a5b1_a6b1_a7b1 = _mm256_permute2f128_pd(
            a4b1_a5b1_a6b3_a7b3,
            a4b3_a5b3_a6b1_a7b1,
            0x30
        );
        let a4b3_a5b3_a6b3_a7b3 = _mm256_permute2f128_pd(
            a4b1_a5b1_a6b3_a7b3,
            a4b3_a5b3_a6b1_a7b1,
            0x12
        );

        ab[0] = a0b0_a1b0_a2b0_a3b0;
        ab[1] = a0b1_a1b1_a2b1_a3b1;
        ab[2] = a0b2_a1b2_a2b2_a3b2;
        ab[3] = a0b3_a1b3_a2b3_a3b3;

        ab[4] = a4b0_a5b0_a6b0_a7b0;
        ab[5] = a4b1_a5b1_a6b1_a7b1;
        ab[6] = a4b2_a5b2_a6b2_a7b2;
        ab[7] = a4b3_a5b3_a6b3_a7b3;
    }

    // Compute α (A B)
    // Compute here if we don't have fma, else pick up α further down

    let alphav = _mm256_broadcast_sd(&alpha);
    if !MA::IS_FUSED {
        loop_m!(i, ab[i] = _mm256_mul_pd(alphav, ab[i]));
    }

    macro_rules! c {
        ($i:expr, $j:expr) =>
            (c.offset(rsc * $i as isize + csc * $j as isize));
    }
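
    // Example (added for clarity): with a column-major C, i.e. rsc == 1 and
    // csc equal to C's column stride, c![2, 3] resolves to
    // c.offset(2 + 3 * csc), that is row 2, column 3 of the 8x4 micro-panel.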

    // C ← α A B + β C
    let mut cv = [_mm256_setzero_pd(); MR];

    if beta != 0. {
        // Read C
        if rsc == 1 {
            loop4!(i, cv[i] = _mm256_loadu_pd(c![0, i]));
            loop4!(i, cv[i + 4] = _mm256_loadu_pd(c![4, i]));
        } else if csc == 1 {
            loop4!(i, cv[i] = _mm256_loadu_pd(c![i, 0]));
            loop4!(i, cv[i+4] = _mm256_loadu_pd(c![i+4, 0]));
        } else {
            loop4!(i, cv[i] = _mm256_setr_pd(
                    *c![0, i],
                    *c![1, i],
                    *c![2, i],
                    *c![3, i]
            ));
            loop4!(i, cv[i + 4] = _mm256_setr_pd(
                    *c![4, i],
                    *c![5, i],
                    *c![6, i],
                    *c![7, i]
            ));
        }
        // Compute β C
        // _mm256_set1_pd and _mm256_broadcast_sd seem to achieve the same thing.
        let beta_v = _mm256_broadcast_sd(&beta);
        loop_m!(i, cv[i] = _mm256_mul_pd(cv[i], beta_v));
    }

    // Compute (α A B) + (β C)
    if !MA::IS_FUSED {
        loop_m!(i, cv[i] = _mm256_add_pd(cv[i], ab[i]));
    } else {
        loop_m!(i, cv[i] = MA::multiply_add(alphav, ab[i], cv[i]));
    }

    if rsc == 1 {
        loop4!(i, _mm256_storeu_pd(c![0, i], cv[i]));
        loop4!(i, _mm256_storeu_pd(c![4, i], cv[i + 4]));
    } else if csc == 1 {
        loop4!(i, _mm256_storeu_pd(c![i, 0], cv[i]));
        loop4!(i, _mm256_storeu_pd(c![i+4, 0], cv[i + 4]));
    } else {
        // Extract the 128-bit halves of each vector and store the elements one by one
        loop4!(i, {
            // E.g. c_0_lo = a0b0 | a1b0
            let c_lo: __m128d = _mm256_extractf128_pd(cv[i], 0);
            // E.g. c_0_hi = a2b0 | a3b0
            let c_hi: __m128d = _mm256_extractf128_pd(cv[i], 1);

            _mm_storel_pd(c![0, i], c_lo);
            _mm_storeh_pd(c![1, i], c_lo);
            _mm_storel_pd(c![2, i], c_hi);
            _mm_storeh_pd(c![3, i], c_hi);

            // E.g. c_4_lo = a4b0 | a5b0
            let c_lo: __m128d = _mm256_extractf128_pd(cv[i+4], 0);
            // E.g. c_4_hi = a6b0 | a7b0
            let c_hi: __m128d = _mm256_extractf128_pd(cv[i+4], 1);

            _mm_storel_pd(c![4, i], c_lo);
            _mm_storeh_pd(c![5, i], c_lo);
            _mm_storel_pd(c![6, i], c_hi);
            _mm_storeh_pd(c![7, i], c_hi);
        });
    }
}

#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
#[target_feature(enable="neon")]
unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T,
                             beta: T, c: *mut T, rsc: isize, csc: isize)
{
    use core::arch::aarch64::*;
    const MR: usize = KernelNeon::MR;
    const NR: usize = KernelNeon::NR;

    let (mut a, mut b) = (a, b);

    // Kernel 8 x 4 (a x b)
    // Four quadrants of 4 x 2
    let mut ab11 = [vmovq_n_f64(0.); 4];
    let mut ab12 = [vmovq_n_f64(0.); 4];
    let mut ab21 = [vmovq_n_f64(0.); 4];
    let mut ab22 = [vmovq_n_f64(0.); 4];

    // Compute
    // ab_ij = a_i * b_j for all i, j
    macro_rules! ab_ij_equals_ai_bj_12 {
        ($dest:ident, $av:expr, $bv:expr) => {
            $dest[0] = vfmaq_laneq_f64($dest[0], $bv, $av, 0);
            $dest[1] = vfmaq_laneq_f64($dest[1], $bv, $av, 1);
        }
    }

    macro_rules! ab_ij_equals_ai_bj_23 {
        ($dest:ident, $av:expr, $bv:expr) => {
            $dest[2] = vfmaq_laneq_f64($dest[2], $bv, $av, 0);
            $dest[3] = vfmaq_laneq_f64($dest[3], $bv, $av, 1);
        }
    }
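
    // Note, added for clarity: vfmaq_laneq_f64(acc, x, y, LANE) computes
    // acc + x * y[LANE], so each macro call above accumulates into two rows of
    // a 4 x 2 quadrant, each row adding the two b values scaled by one element
    // of a.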

    for _ in 0..k {
        let b1 = vld1q_f64(b);
        let b2 = vld1q_f64(b.add(2));

        let a1 = vld1q_f64(a);
        let a2 = vld1q_f64(a.add(2));

        ab_ij_equals_ai_bj_12!(ab11, a1, b1);
        ab_ij_equals_ai_bj_23!(ab11, a2, b1);
        ab_ij_equals_ai_bj_12!(ab12, a1, b2);
        ab_ij_equals_ai_bj_23!(ab12, a2, b2);

        let a3 = vld1q_f64(a.add(4));
        let a4 = vld1q_f64(a.add(6));

        ab_ij_equals_ai_bj_12!(ab21, a3, b1);
        ab_ij_equals_ai_bj_23!(ab21, a4, b1);
        ab_ij_equals_ai_bj_12!(ab22, a3, b2);
        ab_ij_equals_ai_bj_23!(ab22, a4, b2);

        a = a.add(MR);
        b = b.add(NR);
    }

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    // ab *= alpha
    loop4!(i, ab11[i] = vmulq_n_f64(ab11[i], alpha));
    loop4!(i, ab12[i] = vmulq_n_f64(ab12[i], alpha));
    loop4!(i, ab21[i] = vmulq_n_f64(ab21[i], alpha));
    loop4!(i, ab22[i] = vmulq_n_f64(ab22[i], alpha));

    // load one float64x2_t from two pointers
    macro_rules! loadq_from_pointers {
        ($p0:expr, $p1:expr) => (
            {
                let v = vld1q_dup_f64($p0);
                let v = vld1q_lane_f64($p1, v, 1);
                v
            }
        );
    }

    if beta != 0. {
        // load existing value in C
        let mut c11 = [vmovq_n_f64(0.); 4];
        let mut c12 = [vmovq_n_f64(0.); 4];
        let mut c21 = [vmovq_n_f64(0.); 4];
        let mut c22 = [vmovq_n_f64(0.); 4];

        if csc == 1 {
            loop4!(i, c11[i] = vld1q_f64(c![i + 0, 0]));
            loop4!(i, c12[i] = vld1q_f64(c![i + 0, 2]));
            loop4!(i, c21[i] = vld1q_f64(c![i + 4, 0]));
            loop4!(i, c22[i] = vld1q_f64(c![i + 4, 2]));
        } else {
            loop4!(i, c11[i] = loadq_from_pointers!(c![i + 0, 0], c![i + 0, 1]));
            loop4!(i, c12[i] = loadq_from_pointers!(c![i + 0, 2], c![i + 0, 3]));
            loop4!(i, c21[i] = loadq_from_pointers!(c![i + 4, 0], c![i + 4, 1]));
            loop4!(i, c22[i] = loadq_from_pointers!(c![i + 4, 2], c![i + 4, 3]));
        }

        let betav = vmovq_n_f64(beta);

        // ab += β C
        loop4!(i, ab11[i] = vfmaq_f64(ab11[i], c11[i], betav));
        loop4!(i, ab12[i] = vfmaq_f64(ab12[i], c12[i], betav));
        loop4!(i, ab21[i] = vfmaq_f64(ab21[i], c21[i], betav));
        loop4!(i, ab22[i] = vfmaq_f64(ab22[i], c22[i], betav));
    }

    // c <- ab
    // which is in full
    //   C <- α A B (+ β C)
    if csc == 1 {
        loop4!(i, vst1q_f64(c![i + 0, 0], ab11[i]));
        loop4!(i, vst1q_f64(c![i + 0, 2], ab12[i]));
        loop4!(i, vst1q_f64(c![i + 4, 0], ab21[i]));
        loop4!(i, vst1q_f64(c![i + 4, 2], ab22[i]));
    } else {
        loop4!(i, vst1q_lane_f64(c![i + 0, 0], ab11[i], 0));
        loop4!(i, vst1q_lane_f64(c![i + 0, 1], ab11[i], 1));

        loop4!(i, vst1q_lane_f64(c![i + 0, 2], ab12[i], 0));
        loop4!(i, vst1q_lane_f64(c![i + 0, 3], ab12[i], 1));

        loop4!(i, vst1q_lane_f64(c![i + 4, 0], ab21[i], 0));
        loop4!(i, vst1q_lane_f64(c![i + 4, 1], ab21[i], 1));

        loop4!(i, vst1q_lane_f64(c![i + 4, 2], ab22[i], 0));
        loop4!(i, vst1q_lane_f64(c![i + 4, 3], ab22[i], 1));
    }
}

#[inline]
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
                                   beta: T, c: *mut T, rsc: isize, csc: isize)
{
    const MR: usize = KernelFallback::MR;
    const NR: usize = KernelFallback::NR;
    let mut ab: [[T; NR]; MR] = [[0.; NR]; MR];
    let mut a = a;
    let mut b = b;
    debug_assert_eq!(beta, 0., "Beta must be 0 or is not masked");

    // Compute matrix multiplication into ab[i][j]
    unroll_by!(4 => k, {
        loop4!(i, loop4!(j, ab[i][j] += at(a, i) * at(b, j)));

        a = a.offset(MR as isize);
        b = b.offset(NR as isize);
    });

    macro_rules! c {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    // set C = α A B
    loop4!(j, loop4!(i, *c![i, j] = alpha * ab[i][j]));
}

#[inline(always)]
unsafe fn at(ptr: *const T, i: usize) -> T {
    *ptr.offset(i as isize)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::test::test_a_kernel;

    #[test]
    fn test_kernel_fallback_impl() {
        test_a_kernel::<KernelFallback, _>("kernel");
    }
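
    // A small extra check, added as a sketch next to the original test: it
    // drives `kernel_fallback_impl` directly with hand-packed panels (MR
    // values of `a` and NR values of `b` per k-step, mirroring how the kernel
    // above reads its inputs) and compares the result against a naive loop.
    #[test]
    fn test_kernel_fallback_impl_against_naive() {
        const MR: usize = KernelFallback::MR;
        const NR: usize = KernelFallback::NR;
        const K: usize = 4;

        // Hand-packed panels and a row-major 4x4 C (rsc = NR, csc = 1)
        let mut a = [0.; K * MR];
        let mut b = [0.; K * NR];
        for (i, elt) in a.iter_mut().enumerate() { *elt = i as T * 0.25; }
        for (i, elt) in b.iter_mut().enumerate() { *elt = i as T * 0.5 - 1.; }
        let mut c = [0.; MR * NR];

        unsafe {
            kernel_fallback_impl(K, 1., a.as_ptr(), b.as_ptr(), 0.,
                                 c.as_mut_ptr(), NR as isize, 1);
        }

        for i in 0..MR {
            for j in 0..NR {
                let mut expect = 0.;
                for kk in 0..K {
                    expect += a[kk * MR + i] * b[kk * NR + j];
                }
                assert_eq!(c[i * NR + j], expect);
            }
        }
    }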

    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    #[test]
    fn test_loop_m_n() {
        let mut m = [[0; 4]; KernelAvx::MR];
        loop_m!(i, loop4!(j, m[i][j] += 1));
        for arr in &m[..] {
            for elt in &arr[..] {
                assert_eq!(*elt, 1);
            }
        }
    }

    #[cfg(any(target_arch="aarch64"))]
    #[cfg(has_aarch64_simd)]
    mod test_kernel_aarch64 {
        use super::test_a_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;

        macro_rules! test_arch_kernels_aarch64 {
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
                $(
                #[test]
                fn $name() {
                    if is_aarch64_feature_detected_!($feature_name) {
                        test_a_kernel::<$kernel_ty, _>(stringify!($name));
                    } else {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have feature: {:?}", $feature_name);
                    }
                }
                )*
            }
        }

        test_arch_kernels_aarch64! {
            "neon", neon, KernelNeon
        }
    }

    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
    mod test_kernel_x86 {
        use super::test_a_kernel;
        use super::super::*;
        #[cfg(feature = "std")]
        use std::println;
        macro_rules! test_arch_kernels_x86 {
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
                $(
                #[test]
                fn $name() {
                    if is_x86_feature_detected_!($feature_name) {
                        test_a_kernel::<$kernel_ty, _>(stringify!($name));
                    } else {
                        #[cfg(feature = "std")]
                        println!("Skipping, host does not have feature: {:?}", $feature_name);
                    }
                }
                )*
            }
        }

        test_arch_kernels_x86! {
            "fma", fma, KernelFma,
            "avx", avx, KernelAvx,
            "sse2", sse2, KernelSse2
        }
    }
}