matrixmultiply/packing.rs
// Copyright 2016 - 2023 Ulrik Sverdrup "bluss"
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use rawpointer::PointerExt;

use core::ptr::copy_nonoverlapping;

use crate::kernel::ConstNum;
use crate::kernel::Element;

/// Pack matrix into `pack`
///
/// + kc: length of the micropanel
/// + mc: number of rows/columns in the matrix to be packed
/// + pack: packing buffer
/// + a: matrix
/// + rsa: row stride
/// + csa: column stride
///
/// + MR: kernel rows/columns that we round up to
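///
/// The matrix is copied into `pack` in panels of `MR` rows (or columns): within
/// each panel, the `MR` elements of the first of the `kc` columns are stored
/// contiguously, then the next column, and so on. A trailing partial panel
/// (when `mc % MR != 0`) is padded with zeros so every panel has exactly `MR`
/// rows.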
// If one of pack and a is of a reference type, it gets a noalias annotation which
// gives benefits to optimization. The packing buffer is contiguous so it can be passed as a slice
// here.
pub(crate) unsafe fn pack<MR, T>(kc: usize, mc: usize, pack: &mut [T],
                                 a: *const T, rsa: isize, csa: isize)
    where T: Element,
          MR: ConstNum,
{
    pack_impl::<MR, T>(kc, mc, pack, a, rsa, csa)
}

/// Specialized for AVX2
/// Safety: Requires AVX2
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="avx2")]
pub(crate) unsafe fn pack_avx2<MR, T>(kc: usize, mc: usize, pack: &mut [T],
                                      a: *const T, rsa: isize, csa: isize)
    where T: Element,
          MR: ConstNum,
{
    pack_impl::<MR, T>(kc, mc, pack, a, rsa, csa)
}

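// A minimal dispatch sketch (not how this crate actually selects kernels): a
// caller that can use `std` could pick the AVX2-specialized entry point at
// runtime with `std::is_x86_feature_detected!`, which satisfies the safety
// requirement of `pack_avx2` above. The names in the snippet besides `pack`
// and `pack_avx2` are assumed to be in scope at the call site.
//
//     if is_x86_feature_detected!("avx2") {
//         unsafe { pack_avx2::<MR, T>(kc, mc, pack, a, rsa, csa) }
//     } else {
//         unsafe { pack::<MR, T>(kc, mc, pack, a, rsa, csa) }
//     }
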
/// Pack implementation, see pack above for docs.
///
/// Uses inline(always) so that it can be instantiated for different target features.
#[inline(always)]
unsafe fn pack_impl<MR, T>(kc: usize, mc: usize, pack: &mut [T],
                           a: *const T, rsa: isize, csa: isize)
    where T: Element,
          MR: ConstNum,
{
    let pack = pack.as_mut_ptr();
    let mr = MR::VALUE;
    let mut p = 0; // offset into pack

    if rsa == 1 {
        // if the matrix is contiguous in the same direction we are packing,
        // copy a kernel row at a time.
        for ir in 0..mc/mr {
            let row_offset = ir * mr;
            for j in 0..kc {
                let a_row = a.stride_offset(rsa, row_offset)
                             .stride_offset(csa, j);
                copy_nonoverlapping(a_row, pack.add(p), mr);
                p += mr;
            }
        }
    } else {
        // general layout case
        for ir in 0..mc/mr {
            let row_offset = ir * mr;
            for j in 0..kc {
                for i in 0..mr {
                    let a_elt = a.stride_offset(rsa, i + row_offset)
                                 .stride_offset(csa, j);
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
                    p += 1;
                }
            }
        }
    }

    let zero = <_>::zero();

    // Pad with zeros to multiple of kernel size (uneven mc)
    let rest = mc % mr;
    if rest > 0 {
        let row_offset = (mc/mr) * mr;
        for j in 0..kc {
            for i in 0..mr {
                if i < rest {
                    let a_elt = a.stride_offset(rsa, i + row_offset)
                                 .stride_offset(csa, j);
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
                } else {
                    *pack.add(p) = zero;
                }
                p += 1;
            }
        }
    }
}
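
// A minimal test sketch of the packing layout, assuming `f32: Element` as in
// the rest of the crate; the local `ConstNum` impl (`U4Test`) and the test
// name are illustrative only, not part of the crate's test suite.
#[cfg(test)]
mod pack_layout_sketch {
    use super::*;

    struct U4Test;
    impl ConstNum for U4Test {
        const VALUE: usize = 4;
    }

    #[test]
    fn pack_pads_final_panel_with_zeros() {
        // 5 x 3 column-major matrix: element (i, j) lives at a[i + 5 * j].
        let a: Vec<f32> = (0..15).map(|x| x as f32).collect();
        let (mc, kc) = (5, 3);
        let mr = 4;
        // One full 4-row panel plus one zero-padded panel for the leftover row.
        let mut buf = vec![0.0f32; 2 * mr * kc];
        unsafe {
            pack::<U4Test, f32>(kc, mc, &mut buf, a.as_ptr(), 1, 5);
        }
        // First panel: rows 0..4 of column 0, then column 1, and so on.
        assert_eq!(&buf[..mr], &[0.0f32, 1.0, 2.0, 3.0]);
        assert_eq!(&buf[mr..2 * mr], &[5.0f32, 6.0, 7.0, 8.0]);
        // Second panel: row 4 of each column, followed by three zeros of padding.
        assert_eq!(&buf[mr * kc..mr * kc + mr], &[4.0f32, 0.0, 0.0, 0.0]);
    }
}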