use crate::archparam;
use crate::packing::pack;

pub(crate) trait GemmKernel {
    /// The element type the kernel operates on.
    type Elem: Element;

    /// Kernel rows (MR) and columns (NR): the dimensions of the register
    /// block that one kernel invocation computes.
    const MR: usize = Self::MRTy::VALUE;
    const NR: usize = Self::NRTy::VALUE;
    /// MR and NR as type-level constants, used to parameterize packing.
    type MRTy: ConstNum;
    type NRTy: ConstNum;

    /// Required alignment (in bytes) for the packing buffers.
    fn align_to() -> usize;

    /// Whether to always use the masked wrapper around the kernel.
    fn always_masked() -> bool;

    // Cache-blocking parameters (column, depth and row block sizes),
    // taken from the per-arch defaults in `archparam`.
    #[inline(always)]
    fn nc() -> usize { archparam::S_NC }
    #[inline(always)]
    fn kc() -> usize { archparam::S_KC }
    #[inline(always)]
    fn mc() -> usize { archparam::S_MC }

    /// Pack an `mc` × `kc` block of A into `pack_buf`, in panels of MR rows.
    ///
    /// `rsa`, `csa`: row and column strides of A.
    #[inline]
    unsafe fn pack_mr(kc: usize, mc: usize, pack_buf: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        pack::<Self::MRTy, _>(kc, mc, pack_buf, a, rsa, csa)
    }

    /// Pack a block of B into `pack_buf`, in panels of NR columns.
    ///
    /// `rsa`, `csa`: row and column strides of B.
    #[inline]
    unsafe fn pack_nr(kc: usize, mc: usize, pack_buf: &mut [Self::Elem],
                      a: *const Self::Elem, rsa: isize, csa: isize)
    {
        pack::<Self::NRTy, _>(kc, mc, pack_buf, a, rsa, csa)
    }

    /// Matrix multiplication kernel for one register block:
    ///
    /// C ← α A B + β C
    ///
    /// + `k`: depth of the packed panels `a` and `b`
    /// + `a`, `b`: pointers to packed operand panels
    /// + `c`: output block with row stride `rsc` and column stride `csc`;
    ///   if `beta` is zero, `c` does not need to be initialized.
    unsafe fn kernel(
        k: usize,
        alpha: Self::Elem,
        a: *const Self::Elem,
        b: *const Self::Elem,
        beta: Self::Elem,
        c: *mut Self::Elem, rsc: isize, csc: isize);
}
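
// Added illustration (not part of the original crate): a minimal sketch of
// what a `GemmKernel` implementation looks like, assuming a hypothetical
// scalar 4×4 f32 kernel. `ExampleKernel4x4` is an invented name; the real
// kernels are per-arch SIMD implementations. The sketch relies on the
// packed layout produced by `pack`: at depth step `l`, the panel holds MR
// (for A) or NR (for B) consecutive elements.
#[cfg(test)]
mod example_kernel {
    use super::*;

    #[allow(dead_code)]
    struct ExampleKernel4x4;

    impl GemmKernel for ExampleKernel4x4 {
        type Elem = f32;
        type MRTy = U4;
        type NRTy = U4;

        fn align_to() -> usize { 16 }
        fn always_masked() -> bool { true }

        unsafe fn kernel(k: usize, alpha: f32, a: *const f32, b: *const f32,
                         beta: f32, c: *mut f32, rsc: isize, csc: isize)
        {
            for i in 0..4 {
                for j in 0..4 {
                    // Dot product along the packed panels of A and B.
                    let mut acc = 0.;
                    for l in 0..k {
                        acc += *a.add(4 * l + i) * *b.add(4 * l + j);
                    }
                    // When beta is zero, c may be uninitialized: don't read it.
                    let celt = c.offset(i as isize * rsc + j as isize * csc);
                    *celt = if beta == 0. { alpha * acc }
                            else { alpha * acc + beta * *celt };
                }
            }
        }
    }
}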

/// A scalar element type supported by the GEMM implementations.
pub(crate) trait Element : Copy + Send + Sync {
    fn zero() -> Self;
    fn one() -> Self;
    #[cfg_attr(not(test), allow(unused))]
    fn test_value() -> Self;
    fn is_zero(&self) -> bool;
    fn add_assign(&mut self, rhs: Self);
    fn mul_assign(&mut self, rhs: Self);
}

impl Element for f32 {
    fn zero() -> Self { 0. }
    fn one() -> Self { 1. }
    fn test_value() -> Self { 1. }
    fn is_zero(&self) -> bool { *self == 0. }
    fn add_assign(&mut self, rhs: Self) { *self += rhs; }
    fn mul_assign(&mut self, rhs: Self) { *self *= rhs; }
}

impl Element for f64 {
    fn zero() -> Self { 0. }
    fn one() -> Self { 1. }
    fn test_value() -> Self { 1. }
    fn is_zero(&self) -> bool { *self == 0. }
    fn add_assign(&mut self, rhs: Self) { *self += rhs; }
    fn mul_assign(&mut self, rhs: Self) { *self *= rhs; }
}

/// Selector that is called back with the kernel chosen for the current
/// configuration (e.g. after runtime CPU feature detection).
pub(crate) trait GemmSelect<T> {
    fn select<K>(self, kernel: K)
        where K: GemmKernel<Elem=T>,
              T: Element;
}
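
// Added illustration (not part of the original crate): how the callback
// selection works. A selection function picks a concrete kernel type and
// hands it to the `GemmSelect` implementation; the receiver below
// (`InspectKernel`) is an invented stand-in.
#[cfg(test)]
#[allow(dead_code)]
struct InspectKernel;

#[cfg(test)]
impl GemmSelect<f32> for InspectKernel {
    fn select<K>(self, _kernel: K)
        where K: GemmKernel<Elem = f32>
    {
        // Here K is a concrete kernel type; its block shape is available
        // through the associated constants.
        println!("selected kernel: MR={} NR={}", K::MR, K::NR);
    }
}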

/// Complex number (f32), stored as `[re, im]`.
#[cfg(feature = "cgemm")]
#[allow(non_camel_case_types)]
pub(crate) type c32 = [f32; 2];

/// Complex number (f64), stored as `[re, im]`.
#[cfg(feature = "cgemm")]
#[allow(non_camel_case_types)]
pub(crate) type c64 = [f64; 2];

#[cfg(feature = "cgemm")]
impl Element for c32 {
    fn zero() -> Self { [0., 0.] }
    fn one() -> Self { [1., 0.] }
    fn test_value() -> Self { [2., 1.] }
    fn is_zero(&self) -> bool { *self == [0., 0.] }

    #[inline(always)]
    fn add_assign(&mut self, y: Self) {
        self[0] += y[0];
        self[1] += y[1];
    }

    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        *self = c32_mul(*self, rhs);
    }
}

#[cfg(feature = "cgemm")]
impl Element for c64 {
    fn zero() -> Self { [0., 0.] }
    fn one() -> Self { [1., 0.] }
    fn test_value() -> Self { [2., 1.] }
    fn is_zero(&self) -> bool { *self == [0., 0.] }

    #[inline(always)]
    fn add_assign(&mut self, y: Self) {
        self[0] += y[0];
        self[1] += y[1];
    }

    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        *self = c64_mul(*self, rhs);
    }
}

// Complex multiplication: (a + b·i)(c + d·i) = (ac − bd) + (bc + ad)·i
#[cfg(feature = "cgemm")]
#[inline(always)]
pub(crate) fn c32_mul(x: c32, y: c32) -> c32 {
    let [a, b] = x;
    let [c, d] = y;
    [a * c - b * d, b * c + a * d]
}

#[cfg(feature = "cgemm")]
#[inline(always)]
pub(crate) fn c64_mul(x: c64, y: c64) -> c64 {
    let [a, b] = x;
    let [c, d] = y;
    [a * c - b * d, b * c + a * d]
}
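
// Added sanity check (not in the original): the product above follows the
// usual complex multiplication rule, so e.g. i·i = −1 and `one()` is the
// multiplicative identity.
#[cfg(all(test, feature = "cgemm"))]
#[test]
fn complex_mul_matches_definition() {
    // i * i == -1
    assert_eq!(c32_mul([0., 1.], [0., 1.]), [-1., 0.]);
    // (1 + 2i)(3 + 4i) = 3 + 10i + 8i² = -5 + 10i
    assert_eq!(c32_mul([1., 2.], [3., 4.]), [-5., 10.]);
    // Multiplying by one() leaves the value unchanged.
    assert_eq!(c64_mul([3., -2.], [1., 0.]), [3., -2.]);
}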

/// A type-level unsigned integer, used to parameterize MR/NR at the
/// type level.
pub(crate) trait ConstNum {
    const VALUE: usize;
}

#[cfg(feature = "cgemm")]
pub(crate) struct U2;
pub(crate) struct U4;
pub(crate) struct U8;

#[cfg(feature = "cgemm")]
impl ConstNum for U2 { const VALUE: usize = 2; }
impl ConstNum for U4 { const VALUE: usize = 4; }
impl ConstNum for U8 { const VALUE: usize = 8; }
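
// Added check (not in the original): the type-level constants map to the
// values used for MR/NR via `ConstNum::VALUE`.
#[cfg(test)]
#[test]
fn const_num_values() {
    assert_eq!(U4::VALUE, 4);
    assert_eq!(U8::VALUE, 8);
}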

#[cfg(test)]
pub(crate) mod test {
    use std::fmt;

    use super::GemmKernel;
    use super::Element;
    use crate::aligned_alloc::Alloc;

    /// Allocate `n` elements, aligned as required by kernel `K`, all
    /// initialized to `elt`.
    pub(crate) fn aligned_alloc<K>(elt: K::Elem, n: usize) -> Alloc<K::Elem>
        where K: GemmKernel,
              K::Elem: Copy,
    {
        unsafe {
            Alloc::new(n, K::align_to()).init_with(elt)
        }
    }

    /// Test the kernel of `K` by multiplying with (partial) identity
    /// matrices, in two stride configurations of C, so that the expected
    /// result equals the other operand.
    pub(crate) fn test_a_kernel<K, T>(_name: &str)
    where
        K: GemmKernel<Elem = T>,
        T: Element + fmt::Debug + PartialEq,
    {
        // `K` the const (the depth of the multiplication) is distinct from
        // `K` the type parameter: consts and type parameters live in
        // different namespaces.
        const K: usize = 16;
        let mr = K::MR;
        let nr = K::NR;

        // 1. Compute A · I == A (B is a partial identity matrix)
        let mut a = aligned_alloc::<K>(T::zero(), mr * K);
        let mut b = aligned_alloc::<K>(T::zero(), nr * K);

        let mut count = 1;
        for i in 0..mr {
            for j in 0..K {
                for _ in 0..count {
                    a[i * K + j].add_assign(T::test_value());
                }
                count += 1;
            }
        }

        for i in 0..Ord::min(K, nr) {
            b[i + i * nr] = T::one();
        }

        let mut c = vec![T::zero(); mr * nr];
        unsafe {
            // column-major C
            K::kernel(K, T::one(), a.as_ptr(), b.as_ptr(), T::zero(), c.as_mut_ptr(), 1, mr as isize);
        }
        let common_len = Ord::min(a.len(), c.len());
        assert_eq!(&a[..common_len], &c[..common_len]);

        // 2. Compute I · B == B (A is a partial identity matrix)
        let mut a = aligned_alloc::<K>(T::zero(), mr * K);
        let mut b = aligned_alloc::<K>(T::zero(), nr * K);

        for i in 0..Ord::min(K, mr) {
            a[i + i * mr] = T::one();
        }

        let mut count = 1;
        for i in 0..K {
            for j in 0..nr {
                for _ in 0..count {
                    b[i * nr + j].add_assign(T::test_value());
                }
                count += 1;
            }
        }

        let mut c = vec![T::zero(); mr * nr];
        unsafe {
            // row-major C
            K::kernel(K, T::one(), a.as_ptr(), b.as_ptr(), T::zero(), c.as_mut_ptr(), nr as isize, 1);
        }
        let common_len = Ord::min(b.len(), c.len());
        assert_eq!(&b[..common_len], &c[..common_len]);
    }
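
    // Added usage note (hypothetical names): a concrete kernel module would
    // drive this helper from its own tests, e.g.
    //
    //     #[test]
    //     fn test_my_kernel() {
    //         test_a_kernel::<MyKernelType, f32>("my-kernel");
    //     }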

    /// Test a cgemm kernel: pack A and a (partial identity) B with
    /// `pack_complex`, then check that the kernel reproduces A.
    #[cfg(feature = "cgemm")]
    pub(crate) fn test_complex_packed_kernel<K, T, TReal>(_name: &str)
    where
        K: GemmKernel<Elem = T>,
        T: Element + fmt::Debug + PartialEq,
        TReal: Element + fmt::Debug + PartialEq,
    {
        use crate::cgemm_common::pack_complex;

        const K: usize = 16;
        let mr = K::MR;
        let nr = K::NR;

        // Compute A · I == A (B is a partial identity matrix)
        let mut a = aligned_alloc::<K>(T::zero(), mr * K);
        let mut apack = aligned_alloc::<K>(T::zero(), mr * K);
        let mut b = aligned_alloc::<K>(T::zero(), nr * K);
        let mut bpack = aligned_alloc::<K>(T::zero(), nr * K);

        let mut count = 1;
        for i in 0..mr {
            for j in 0..K {
                for _ in 0..count {
                    a[i * K + j].add_assign(T::test_value());
                }
                count += 1;
            }
        }

        for i in 0..Ord::min(K, nr) {
            b[i + i * nr] = T::one();
        }

        unsafe {
            // Pack both operands into the interleaved layout the cgemm
            // kernels expect.
            pack_complex::<K::MRTy, T, TReal>(K, mr, &mut apack[..], a.ptr_mut(), 1, mr as isize);
            pack_complex::<K::NRTy, T, TReal>(nr, K, &mut bpack[..], b.ptr_mut(), nr as isize, 1);
        }

        let mut c = vec![T::zero(); mr * nr];
        unsafe {
            // column-major C
            K::kernel(K, T::one(), apack.as_ptr(), bpack.as_ptr(), T::zero(), c.as_mut_ptr(), 1, mr as isize);
        }
        let common_len = Ord::min(a.len(), c.len());
        assert_eq!(&a[..common_len], &c[..common_len]);
    }
}