#[cfg(feature = "std")]
use core::cell::UnsafeCell;
use core::cmp::min;
use core::mem::size_of;
use core::slice;

use crate::aligned_alloc::Alloc;

use crate::ptr::Ptr;
use crate::util::range_chunk;
use crate::util::round_up_to;

use crate::kernel::Element;
use crate::kernel::GemmKernel;
use crate::kernel::GemmSelect;
#[cfg(feature = "cgemm")]
use crate::kernel::{c32, c64};
use crate::threading::{get_thread_pool, ThreadPoolCtx, LoopThreadConfig};
use crate::sgemm_kernel;
use crate::dgemm_kernel;
#[cfg(feature = "cgemm")]
use crate::cgemm_kernel;
#[cfg(feature = "cgemm")]
use crate::zgemm_kernel;
use rawpointer::PointerExt;

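/// General matrix multiplication (f32)
///
/// Computes C ← α A B + β C, where A is m × k, B is k × n and C is m × n;
/// `rsx`/`csx` are the row and column strides (in elements) of matrix *x*.
///
/// A usage sketch (assuming this function is re-exported at the crate root
/// as `matrixmultiply::sgemm`), multiplying two row-major 2 × 2 matrices:
///
/// ```
/// let a = [1.0_f32, 2., 3., 4.];
/// let b = [5.0_f32, 6., 7., 8.];
/// let mut c = [0.0_f32; 4];
/// unsafe {
///     matrixmultiply::sgemm(2, 2, 2, 1.0,
///                           a.as_ptr(), 2, 1,
///                           b.as_ptr(), 2, 1,
///                           0.0,
///                           c.as_mut_ptr(), 2, 1);
/// }
/// assert_eq!(c, [19., 22., 43., 50.]);
/// ```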
pub unsafe fn sgemm(
    m: usize, k: usize, n: usize,
    alpha: f32,
    a: *const f32, rsa: isize, csa: isize,
    b: *const f32, rsb: isize, csb: isize,
    beta: f32,
    c: *mut f32, rsc: isize, csc: isize)
{
    sgemm_kernel::detect(GemmParameters { m, k, n,
                                          alpha,
                                          a, rsa, csa,
                                          b, rsb, csb,
                                          beta,
                                          c, rsc, csc})
}

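/// General matrix multiplication (f64)
///
/// Computes C ← α A B + β C; parameters are as for `sgemm`, with `f64`
/// elements.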
pub unsafe fn dgemm(
    m: usize, k: usize, n: usize,
    alpha: f64,
    a: *const f64, rsa: isize, csa: isize,
    b: *const f64, rsb: isize, csb: isize,
    beta: f64,
    c: *mut f64, rsc: isize, csc: isize)
{
    dgemm_kernel::detect(GemmParameters { m, k, n,
                                          alpha,
                                          a, rsa, csa,
                                          b, rsb, csb,
                                          beta,
                                          c, rsc, csc})
}

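/// Operand description flag for `cgemm`/`zgemm`.
///
/// `Standard` passes the operand through unchanged; the enum is
/// `#[non_exhaustive]` so that further options can be added later. Note that
/// in this module the flags are currently accepted but have no effect.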
#[cfg(feature = "cgemm")]
#[non_exhaustive]
#[derive(Copy, Clone, Debug)]
pub enum CGemmOption {
    Standard,
}

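/// General matrix multiplication (complex f32)
///
/// Computes C ← α A B + β C for `c32` elements; `flaga` and `flagb` describe
/// the A and B operands (only `CGemmOption::Standard` exists at present).
/// Otherwise the parameters are as for `sgemm`.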
#[cfg(feature = "cgemm")]
pub unsafe fn cgemm(
    flaga: CGemmOption, flagb: CGemmOption,
    m: usize, k: usize, n: usize,
    alpha: c32,
    a: *const c32, rsa: isize, csa: isize,
    b: *const c32, rsb: isize, csb: isize,
    beta: c32,
    c: *mut c32, rsc: isize, csc: isize)
{
    let _ = (flaga, flagb);
    cgemm_kernel::detect(GemmParameters { m, k, n,
                                          alpha,
                                          a, rsa, csa,
                                          b, rsb, csb,
                                          beta,
                                          c, rsc, csc})
}

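/// General matrix multiplication (complex f64)
///
/// Computes C ← α A B + β C for `c64` elements; parameters are as for
/// `cgemm`.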
#[cfg(feature = "cgemm")]
pub unsafe fn zgemm(
    flaga: CGemmOption, flagb: CGemmOption,
    m: usize, k: usize, n: usize,
    alpha: c64,
    a: *const c64, rsa: isize, csa: isize,
    b: *const c64, rsb: isize, csb: isize,
    beta: c64,
    c: *mut c64, rsc: isize, csc: isize)
{
    let _ = (flaga, flagb);
    zgemm_kernel::detect(GemmParameters { m, k, n,
                                          alpha,
                                          a, rsa, csa,
                                          b, rsb, csb,
                                          beta,
                                          c, rsc, csc})
}

struct GemmParameters<T> {
    m: usize, k: usize, n: usize,
    alpha: T,
    a: *const T, rsa: isize, csa: isize,
    beta: T,
    b: *const T, rsb: isize, csb: isize,
    c: *mut T, rsc: isize, csc: isize,
}

impl<T> GemmSelect<T> for GemmParameters<T> {
    fn select<K>(self, _kernel: K)
        where K: GemmKernel<Elem=T>,
              T: Element,
    {
        let GemmParameters {
            m, k, n,
            alpha,
            a, rsa, csa,
            b, rsb, csb,
            beta,
            c, rsc, csc} = self;

        unsafe {
            gemm_loop::<K>(
                m, k, n,
                alpha,
                a, rsa, csa,
                b, rsb, csb,
                beta,
                c, rsc, csc)
        }
    }
}


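// Sanity-check the kernel's compile-time parameters before running the
// blocked loops: the MR × NR micro-tile must fit in the mask buffer
// (KERNEL_MAX_SIZE bytes at KERNEL_MAX_ALIGN alignment) and the blocking
// sizes must nest as MR ≤ mc ≤ kc ≤ nc.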
#[inline(always)]
fn ensure_kernel_params<K>()
    where K: GemmKernel
{
    let mr = K::MR;
    let nr = K::NR;
    assert!(mr > 0 && mr <= 8);
    assert!(nr > 0 && nr <= 8);
    assert!(mr * nr * size_of::<K::Elem>() <= 8 * 4 * 8);
    assert!(K::align_to() <= 32);
    let max_align = size_of::<K::Elem>() * min(mr, nr);
    assert!(K::align_to() <= max_align);

    assert!(K::MR <= K::mc());
    assert!(K::mc() <= K::kc());
    assert!(K::kc() <= K::nc());
    assert!(K::nc() <= 65536);
}

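// Blocked GEMM driver (loops 5 through 3 of the packed algorithm):
//
// - LOOP 5 splits n into chunks of K::nc() (columns of B and C)
// - LOOP 4 splits k into chunks of K::kc() and packs the B panel into B~
// - LOOP 3 splits m into chunks of K::mc(), packs the A panel into A~
//   (one packing buffer per thread) and calls gemm_packed, which runs
//   loops 2 and 1 and the micro-kernel.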
#[inline(never)]
unsafe fn gemm_loop<K>(
    m: usize, k: usize, n: usize,
    alpha: K::Elem,
    a: *const K::Elem, rsa: isize, csa: isize,
    b: *const K::Elem, rsb: isize, csb: isize,
    beta: K::Elem,
    c: *mut K::Elem, rsc: isize, csc: isize)
    where K: GemmKernel
{
    debug_assert!(m <= 1 || n == 0 || rsc != 0);
    debug_assert!(m == 0 || n <= 1 || csc != 0);

    // If A or B have no elements, compute C ← βC and return
    if m == 0 || k == 0 || n == 0 {
        return c_to_beta_c(m, n, beta, c, rsc, csc);
    }

    let knc = K::nc();
    let kkc = K::kc();
    let kmc = K::mc();
    ensure_kernel_params::<K>();

    let a = Ptr(a);
    let b = Ptr(b);
    let c = Ptr(c);

    let (nthreads, tp) = get_thread_pool();
    let thread_config = LoopThreadConfig::new::<K>(m, k, n, nthreads);
    let nap = thread_config.num_pack_a();

    let (mut packing_buffer, ap_size, bp_size) = make_packing_buffer::<K>(m, k, n, nap);
    let app = Ptr(packing_buffer.ptr_mut());
    let bpp = app.add(ap_size * nap);

    // LOOP 5: split n into nc-sized chunks (B, C)
    for (l5, nc) in range_chunk(n, knc) {
        dprint!("LOOP 5, {}, nc={}", l5, nc);
        let b = b.stride_offset(csb, knc * l5);
        let c = c.stride_offset(csc, knc * l5);

        // LOOP 4: split k into kc-sized chunks (A, B)
        for (l4, kc) in range_chunk(k, kkc) {
            dprint!("LOOP 4, {}, kc={}", l4, kc);
            let b = b.stride_offset(rsb, kkc * l4);
            let a = a.stride_offset(csa, kkc * l4);

            // Pack B into B~ (shared by all loop-3 threads)
            K::pack_nr(kc, nc, slice::from_raw_parts_mut(bpp.ptr(), bp_size),
                       b.ptr(), csb, rsb);

            // First time writing to this part of C, use the user's `beta`;
            // after that, accumulate.
            let betap = if l4 == 0 { beta } else { <_>::one() };

            // LOOP 3: split m into mc-sized chunks (A, C)
            range_chunk(m, kmc)
                .parallel(thread_config.loop3, tp)
                .thread_local(move |i, _nt| {
                    // one A~ packing buffer per thread
                    debug_assert!(i < nap);
                    app.add(ap_size * i)
                })
                .for_each(move |tp, &mut app, l3, mc| {
                    dprint!("LOOP 3, {}, mc={}", l3, mc);
                    let a = a.stride_offset(rsa, kmc * l3);
                    let c = c.stride_offset(rsc, kmc * l3);

                    // Pack A into A~ (per thread)
                    K::pack_mr(kc, mc, slice::from_raw_parts_mut(app.ptr(), ap_size),
                               a.ptr(), rsa, csa);

                    // Loops 2 and 1 run inside gemm_packed
                    gemm_packed::<K>(nc, kc, mc,
                                     alpha,
                                     app.to_const(), bpp.to_const(),
                                     betap,
                                     c, rsc, csc,
                                     tp, thread_config);
                });
        }
    }
}

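// The mask buffer must hold one full micro-tile: at most KERNEL_MAX_SIZE
// bytes (see ensure_kernel_params), at an alignment of at most
// KERNEL_MAX_ALIGN. MASK_BUF_SIZE adds KERNEL_MAX_ALIGN - 1 bytes of slack
// so that align_ptr can align the start of the buffer manually.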
const KERNEL_MAX_SIZE: usize = 8 * 8 * 4;
const KERNEL_MAX_ALIGN: usize = 32;
const MASK_BUF_SIZE: usize = KERNEL_MAX_SIZE + KERNEL_MAX_ALIGN - 1;

#[cfg_attr(
    not(any(
        target_os = "macos",
        all(
            target_arch = "x86",
            target_vendor = "win7",
            target_os = "windows",
            target_env = "msvc"
        )
    )),
    repr(align(16))
)]
struct MaskBuffer {
    buffer: [u8; MASK_BUF_SIZE],
}

#[cfg(feature = "std")]
thread_local! {
    static MASK_BUF: UnsafeCell<MaskBuffer> =
        UnsafeCell::new(MaskBuffer { buffer: [0; MASK_BUF_SIZE] });
}

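// Loops 2 and 1 of the algorithm: iterate over the micro-panels of the
// packed operands A~ and B~ and invoke the micro-kernel on each MR × NR
// tile of C. Full tiles go straight to K::kernel; partial tiles (and
// kernels that are always masked) are computed into a per-thread mask
// buffer first and then copied into C.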
unsafe fn gemm_packed<K>(nc: usize, kc: usize, mc: usize,
                         alpha: K::Elem,
                         app: Ptr<*const K::Elem>, bpp: Ptr<*const K::Elem>,
                         beta: K::Elem,
                         c: Ptr<*mut K::Elem>, rsc: isize, csc: isize,
                         tp: ThreadPoolCtx, thread_config: LoopThreadConfig)
    where K: GemmKernel,
{
    let mr = K::MR;
    let nr = K::NR;
    // the mask buffer must be able to hold a full MR × NR tile at the
    // kernel's required alignment
    assert!(mr * nr * size_of::<K::Elem>() <= KERNEL_MAX_SIZE && K::align_to() <= KERNEL_MAX_ALIGN);

    #[cfg(not(feature = "std"))]
    let mut mask_buf = MaskBuffer { buffer: [0; MASK_BUF_SIZE] };

    // LOOP 2: iterate through micro-panels of packed B~ (B~, C)
    range_chunk(nc, nr)
        .parallel(thread_config.loop2, tp)
        .thread_local(|_i, _nt| {
            // Without std the mask buffer lives on the stack (single thread);
            // with std each thread uses its thread-local MASK_BUF.
            let mut ptr;
            #[cfg(not(feature = "std"))]
            {
                debug_assert_eq!(_nt, 1);
                ptr = mask_buf.buffer.as_mut_ptr();
            }
            #[cfg(feature = "std")]
            {
                ptr = MASK_BUF.with(|buf| (*buf.get()).buffer.as_mut_ptr());
            }
            ptr = align_ptr(K::align_to(), ptr);
            slice::from_raw_parts_mut(ptr as *mut K::Elem, KERNEL_MAX_SIZE / size_of::<K::Elem>())
        })
        .for_each(move |_tp, mask_buf, l2, nr_| {
            let bpp = bpp.stride_offset(1, kc * nr * l2);
            let c = c.stride_offset(csc, nr * l2);

            // LOOP 1: iterate through micro-panels of packed A~ while
            // multiplying (A~, B~, C)
            for (l1, mr_) in range_chunk(mc, mr) {
                let app = app.stride_offset(1, kc * mr * l1);
                let c = c.stride_offset(rsc, mr * l1);

                // Partial tiles (and always-masked kernels) go through the
                // mask buffer; full tiles call the kernel directly.
                if K::always_masked() || nr_ < nr || mr_ < mr {
                    masked_kernel::<_, K>(kc, alpha, app.ptr(), bpp.ptr(),
                                          beta, c.ptr(), rsc, csc,
                                          mr_, nr_, mask_buf);
                    continue;
                } else {
                    K::kernel(kc, alpha, app.ptr(), bpp.ptr(), beta, c.ptr(), rsc, csc);
                }
            }
        });
}

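// Allocate one aligned buffer that holds `na` copies of the packed A panel
// (one per packing thread) followed by a single packed B panel. Panel sizes
// are capped at one blocking step (mc × kc and kc × nc) and rounded up to
// whole MR/NR strips. Returns (allocation, A panel size, B panel size), with
// sizes in elements.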
unsafe fn make_packing_buffer<K>(m: usize, k: usize, n: usize, na: usize)
    -> (Alloc<K::Elem>, usize, usize)
    where K: GemmKernel,
{
    let m = min(m, K::mc());
    let k = min(k, K::kc());
    let n = min(n, K::nc());
    debug_assert_ne!(na, 0);
    debug_assert!(na <= 128);
    let apack_size = k * round_up_to(m, K::MR);
    let bpack_size = k * round_up_to(n, K::NR);
    let nelem = apack_size * na + bpack_size;

    dprint!("packed nelem={}, apack={}, bpack={},
             m={} k={} n={}, na={}",
            nelem, apack_size, bpack_size,
            m, k, n, na);

    (Alloc::new(nelem, K::align_to()), apack_size, bpack_size)
}

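// Bump `ptr` forward to the next `align_to`-byte boundary (the mask buffer
// carries KERNEL_MAX_ALIGN - 1 bytes of slack for this). Illustrative
// arithmetic, assuming align_to = 32, T = f32 and a pointer at address
// 0x1008: cur_align = 8, so the pointer advances (32 - 8) / 4 = 6 elements
// (24 bytes) to reach 0x1020.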
#[inline]
unsafe fn align_ptr<T>(mut align_to: usize, mut ptr: *mut T) -> *mut T {
    // always ensure a minimal alignment on macOS
    if cfg!(target_os = "macos") {
        align_to = Ord::max(align_to, 8);
    }

    if align_to != 0 {
        let cur_align = ptr as usize % align_to;
        if cur_align != 0 {
            ptr = ptr.offset(((align_to - cur_align) / size_of::<T>()) as isize);
        }
    }
    ptr
}

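// Run the micro-kernel for a masked tile: compute the full MR × NR product
// into `mask_buf` (column major, beta = 0), then let c_to_masked_ab_beta_c
// copy only the `rows` × `cols` entries that actually exist in C, applying
// `beta` there.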
#[inline(never)]
unsafe fn masked_kernel<T, K>(k: usize, alpha: T,
                              a: *const T,
                              b: *const T,
                              beta: T,
                              c: *mut T, rsc: isize, csc: isize,
                              rows: usize, cols: usize,
                              mask_buf: &mut [T])
    where K: GemmKernel<Elem=T>, T: Element,
{
    // use column major order for `mask_buf`
    K::kernel(k, alpha, a, b, T::zero(), mask_buf.as_mut_ptr(), 1, K::MR as isize);
    c_to_masked_ab_beta_c::<_, K>(beta, c, rsc, csc, rows, cols, &*mask_buf);
}

#[inline]
unsafe fn c_to_masked_ab_beta_c<T, K>(beta: T,
                                      c: *mut T, rsc: isize, csc: isize,
                                      rows: usize, cols: usize,
                                      mask_buf: &[T])
    where K: GemmKernel<Elem=T>, T: Element,
{
    // copy the `rows` × `cols` part of the column-major mask buffer into C
    let mr = K::MR;
    let nr = K::NR;
    let mut ab = mask_buf.as_ptr();
    for j in 0..nr {
        for i in 0..mr {
            if i < rows && j < cols {
                let cptr = c.stride_offset(rsc, i)
                            .stride_offset(csc, j);
                if beta.is_zero() {
                    *cptr = *ab; // initialize C
                } else {
                    (*cptr).mul_assign(beta);
                    (*cptr).add_assign(*ab);
                }
            }
            ab.inc();
        }
    }
}

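// Handle the degenerate case (m, k or n is zero): there is no A B product,
// so just scale C by beta, or zero-initialize it when beta is zero.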
#[inline(never)]
unsafe fn c_to_beta_c<T>(m: usize, n: usize, beta: T,
                         c: *mut T, rsc: isize, csc: isize)
    where T: Element
{
    for i in 0..m {
        for j in 0..n {
            let cptr = c.stride_offset(rsc, i)
                        .stride_offset(csc, j);
            if beta.is_zero() {
                *cptr = T::zero();
            } else {
                (*cptr).mul_assign(beta);
            }
        }
    }
}