1#[cfg(feature = "std")]
12use matrixmultiply;
13use num::{One, Zero};
14use simba::scalar::{ClosedAdd, ClosedMul};
15#[cfg(feature = "std")]
16use std::mem;
17
18use crate::base::constraint::{
19 AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
20};
21use crate::base::dimension::{Dim, Dynamic, U1};
22use crate::base::storage::{RawStorage, RawStorageMut};
23use crate::base::uninit::InitStatus;
24use crate::base::{Matrix, Scalar, Vector};
25use std::any::TypeId;
26
27#[allow(clippy::too_many_arguments)]
31unsafe fn array_axcpy<Status, T>(
32 _: Status,
33 y: &mut [Status::Value],
34 a: T,
35 x: &[T],
36 c: T,
37 beta: T,
38 stride1: usize,
39 stride2: usize,
40 len: usize,
41) where
42 Status: InitStatus<T>,
43 T: Scalar + Zero + ClosedAdd + ClosedMul,
44{
45 for i in 0..len {
46 let y = Status::assume_init_mut(y.get_unchecked_mut(i * stride1));
47 *y =
48 a.clone() * x.get_unchecked(i * stride2).clone() * c.clone() + beta.clone() * y.clone();
49 }
50}
51
52fn array_axc<Status, T>(
53 _: Status,
54 y: &mut [Status::Value],
55 a: T,
56 x: &[T],
57 c: T,
58 stride1: usize,
59 stride2: usize,
60 len: usize,
61) where
62 Status: InitStatus<T>,
63 T: Scalar + Zero + ClosedAdd + ClosedMul,
64{
65 for i in 0..len {
66 unsafe {
67 Status::init(
68 y.get_unchecked_mut(i * stride1),
69 a.clone() * x.get_unchecked(i * stride2).clone() * c.clone(),
70 );
71 }
72 }
73}
74
75#[inline(always)]
82#[allow(clippy::many_single_char_names)]
83pub unsafe fn axcpy_uninit<Status, T, D1: Dim, D2: Dim, SA, SB>(
84 status: Status,
85 y: &mut Vector<Status::Value, D1, SA>,
86 a: T,
87 x: &Vector<T, D2, SB>,
88 c: T,
89 b: T,
90) where
91 T: Scalar + Zero + ClosedAdd + ClosedMul,
92 SA: RawStorageMut<Status::Value, D1>,
93 SB: RawStorage<T, D2>,
94 ShapeConstraint: DimEq<D1, D2>,
95 Status: InitStatus<T>,
96{
97 assert_eq!(y.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
98
99 let rstride1 = y.strides().0;
100 let rstride2 = x.strides().0;
101
102 let y = y.data.as_mut_slice_unchecked();
105 let x = x.data.as_slice_unchecked();
106
107 if !b.is_zero() {
108 array_axcpy(status, y, a, x, c, b, rstride1, rstride2, x.len());
109 } else {
110 array_axc(status, y, a, x, c, rstride1, rstride2, x.len());
111 }
112}
113
114#[inline(always)]
122pub unsafe fn gemv_uninit<Status, T, D1: Dim, R2: Dim, C2: Dim, D3: Dim, SA, SB, SC>(
123 status: Status,
124 y: &mut Vector<Status::Value, D1, SA>,
125 alpha: T,
126 a: &Matrix<T, R2, C2, SB>,
127 x: &Vector<T, D3, SC>,
128 beta: T,
129) where
130 Status: InitStatus<T>,
131 T: Scalar + Zero + One + ClosedAdd + ClosedMul,
132 SA: RawStorageMut<Status::Value, D1>,
133 SB: RawStorage<T, R2, C2>,
134 SC: RawStorage<T, D3>,
135 ShapeConstraint: DimEq<D1, R2> + AreMultipliable<R2, C2, D3, U1>,
136{
137 let dim1 = y.nrows();
138 let (nrows2, ncols2) = a.shape();
139 let dim3 = x.nrows();
140
141 assert!(
142 ncols2 == dim3 && dim1 == nrows2,
143 "Gemv: dimensions mismatch."
144 );
145
146 if ncols2 == 0 {
147 if beta.is_zero() {
148 y.apply(|e| Status::init(e, T::zero()));
149 } else {
150 y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
152 }
153 return;
154 }
155
156 let col2 = a.column(0);
158 let val = x.vget_unchecked(0).clone();
159
160 axcpy_uninit(status, y, alpha.clone(), &col2, val, beta);
162
163 for j in 1..ncols2 {
164 let col2 = a.column(j);
165 let val = x.vget_unchecked(j).clone();
166
167 axcpy_uninit(status, y, alpha.clone(), &col2, val, T::one());
169 }
170}
171
172#[inline(always)]
180pub unsafe fn gemm_uninit<
181 Status,
182 T,
183 R1: Dim,
184 C1: Dim,
185 R2: Dim,
186 C2: Dim,
187 R3: Dim,
188 C3: Dim,
189 SA,
190 SB,
191 SC,
192>(
193 status: Status,
194 y: &mut Matrix<Status::Value, R1, C1, SA>,
195 alpha: T,
196 a: &Matrix<T, R2, C2, SB>,
197 b: &Matrix<T, R3, C3, SC>,
198 beta: T,
199) where
200 Status: InitStatus<T>,
201 T: Scalar + Zero + One + ClosedAdd + ClosedMul,
202 SA: RawStorageMut<Status::Value, R1, C1>,
203 SB: RawStorage<T, R2, C2>,
204 SC: RawStorage<T, R3, C3>,
205 ShapeConstraint:
206 SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C3> + AreMultipliable<R2, C2, R3, C3>,
207{
208 let ncols1 = y.ncols();
209
210 #[cfg(feature = "std")]
211 {
212 if R1::is::<Dynamic>()
217 || C1::is::<Dynamic>()
218 || R2::is::<Dynamic>()
219 || C2::is::<Dynamic>()
220 || R3::is::<Dynamic>()
221 || C3::is::<Dynamic>()
222 {
223 let nrows1 = y.nrows();
225 let (nrows2, ncols2) = a.shape();
226 let (nrows3, ncols3) = b.shape();
227
228 const SMALL_DIM: usize = 5;
230
231 if nrows1 > SMALL_DIM && ncols1 > SMALL_DIM && nrows2 > SMALL_DIM && ncols2 > SMALL_DIM
232 {
233 assert_eq!(
234 ncols2, nrows3,
235 "gemm: dimensions mismatch for multiplication."
236 );
237 assert_eq!(
238 (nrows1, ncols1),
239 (nrows2, ncols3),
240 "gemm: dimensions mismatch for addition."
241 );
242
243 if ncols2 == 0 {
248 if beta.is_zero() {
252 y.apply(|e| Status::init(e, T::zero()));
253 } else {
254 y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
256 }
257 return;
258 }
259
260 if TypeId::of::<T>() == TypeId::of::<f32>() {
261 let (rsa, csa) = a.strides();
262 let (rsb, csb) = b.strides();
263 let (rsc, csc) = y.strides();
264
265 matrixmultiply::sgemm(
266 nrows2,
267 ncols2,
268 ncols3,
269 mem::transmute_copy(&alpha),
270 a.data.ptr() as *const f32,
271 rsa as isize,
272 csa as isize,
273 b.data.ptr() as *const f32,
274 rsb as isize,
275 csb as isize,
276 mem::transmute_copy(&beta),
277 y.data.ptr_mut() as *mut f32,
278 rsc as isize,
279 csc as isize,
280 );
281 return;
282 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
283 let (rsa, csa) = a.strides();
284 let (rsb, csb) = b.strides();
285 let (rsc, csc) = y.strides();
286
287 matrixmultiply::dgemm(
288 nrows2,
289 ncols2,
290 ncols3,
291 mem::transmute_copy(&alpha),
292 a.data.ptr() as *const f64,
293 rsa as isize,
294 csa as isize,
295 b.data.ptr() as *const f64,
296 rsb as isize,
297 csb as isize,
298 mem::transmute_copy(&beta),
299 y.data.ptr_mut() as *mut f64,
300 rsc as isize,
301 csc as isize,
302 );
303 return;
304 }
305 }
306 }
307 }
308
309 for j1 in 0..ncols1 {
310 gemv_uninit(
313 status,
314 &mut y.column_mut(j1),
315 alpha.clone(),
316 a,
317 &b.column(j1),
318 beta.clone(),
319 );
320 }
321}