nalgebra/base/blas.rs
1use crate::{RawStorage, SimdComplexField};
2use num::{One, Zero};
3use simba::scalar::{ClosedAdd, ClosedMul};
4
5use crate::base::allocator::Allocator;
6use crate::base::blas_uninit::{axcpy_uninit, gemm_uninit, gemv_uninit};
7use crate::base::constraint::{
8 AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
9};
10use crate::base::dimension::{Const, Dim, Dynamic, U1, U2, U3, U4};
11use crate::base::storage::{Storage, StorageMut};
12use crate::base::uninit::Init;
13use crate::base::{
14 DVectorSlice, DefaultAllocator, Matrix, Scalar, SquareMatrix, Vector, VectorSlice,
15};
16
17/// # Dot/scalar product
18impl<T, R: Dim, C: Dim, S: RawStorage<T, R, C>> Matrix<T, R, C, S>
19where
20 T: Scalar + Zero + ClosedAdd + ClosedMul,
21{
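    /// Shared implementation of `dot` and `dotc`: `conjugate` is applied to each component of
    /// `self` before it is multiplied with the corresponding component of `rhs` (the identity
    /// for `dot`, complex conjugation for `dotc`).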
22 #[inline(always)]
23 fn dotx<R2: Dim, C2: Dim, SB>(
24 &self,
25 rhs: &Matrix<T, R2, C2, SB>,
26 conjugate: impl Fn(T) -> T,
27 ) -> T
28 where
29 SB: RawStorage<T, R2, C2>,
30 ShapeConstraint: DimEq<R, R2> + DimEq<C, C2>,
31 {
32 assert!(
33 self.nrows() == rhs.nrows(),
34 "Dot product dimensions mismatch for shapes {:?} and {:?}: left rows != right rows.",
35 self.shape(),
36 rhs.shape(),
37 );
38
39 assert!(
40 self.ncols() == rhs.ncols(),
41 "Dot product dimensions mismatch for shapes {:?} and {:?}: left cols != right cols.",
42 self.shape(),
43 rhs.shape(),
44 );
45
        // Special-case the common statically-sized vectors of dimension 2, 3, and 4, because
        // the generic `for` loop below is not very efficient for them.
48 if (R::is::<U2>() || R2::is::<U2>()) && (C::is::<U1>() || C2::is::<U1>()) {
49 unsafe {
50 let a = conjugate(self.get_unchecked((0, 0)).clone())
51 * rhs.get_unchecked((0, 0)).clone();
52 let b = conjugate(self.get_unchecked((1, 0)).clone())
53 * rhs.get_unchecked((1, 0)).clone();
54
55 return a + b;
56 }
57 }
58 if (R::is::<U3>() || R2::is::<U3>()) && (C::is::<U1>() || C2::is::<U1>()) {
59 unsafe {
60 let a = conjugate(self.get_unchecked((0, 0)).clone())
61 * rhs.get_unchecked((0, 0)).clone();
62 let b = conjugate(self.get_unchecked((1, 0)).clone())
63 * rhs.get_unchecked((1, 0)).clone();
64 let c = conjugate(self.get_unchecked((2, 0)).clone())
65 * rhs.get_unchecked((2, 0)).clone();
66
67 return a + b + c;
68 }
69 }
70 if (R::is::<U4>() || R2::is::<U4>()) && (C::is::<U1>() || C2::is::<U1>()) {
71 unsafe {
72 let mut a = conjugate(self.get_unchecked((0, 0)).clone())
73 * rhs.get_unchecked((0, 0)).clone();
74 let mut b = conjugate(self.get_unchecked((1, 0)).clone())
75 * rhs.get_unchecked((1, 0)).clone();
76 let c = conjugate(self.get_unchecked((2, 0)).clone())
77 * rhs.get_unchecked((2, 0)).clone();
78 let d = conjugate(self.get_unchecked((3, 0)).clone())
79 * rhs.get_unchecked((3, 0)).clone();
80
81 a += c;
82 b += d;
83
84 return a + b;
85 }
86 }
87
        // All this is inspired by the "unrolled version" discussed in:
89 // https://blog.theincredibleholk.org/blog/2012/12/10/optimizing-dot-product/
90 //
91 // And this comment from bluss:
92 // https://users.rust-lang.org/t/how-to-zip-two-slices-efficiently/2048/12
93 let mut res = T::zero();
94
        // The accumulators have to be declared outside of the loop (and not at their first
        // assignment inside it), otherwise vectorization does not kick in. Using eight
        // independent accumulators also breaks the serial dependency chain of a single running
        // sum, which helps the compiler vectorize the multiply-adds.
97 let mut acc0;
98 let mut acc1;
99 let mut acc2;
100 let mut acc3;
101 let mut acc4;
102 let mut acc5;
103 let mut acc6;
104 let mut acc7;
105
106 for j in 0..self.ncols() {
107 let mut i = 0;
108
109 acc0 = T::zero();
110 acc1 = T::zero();
111 acc2 = T::zero();
112 acc3 = T::zero();
113 acc4 = T::zero();
114 acc5 = T::zero();
115 acc6 = T::zero();
116 acc7 = T::zero();
117
118 while self.nrows() - i >= 8 {
119 acc0 += unsafe {
120 conjugate(self.get_unchecked((i, j)).clone())
121 * rhs.get_unchecked((i, j)).clone()
122 };
123 acc1 += unsafe {
124 conjugate(self.get_unchecked((i + 1, j)).clone())
125 * rhs.get_unchecked((i + 1, j)).clone()
126 };
127 acc2 += unsafe {
128 conjugate(self.get_unchecked((i + 2, j)).clone())
129 * rhs.get_unchecked((i + 2, j)).clone()
130 };
131 acc3 += unsafe {
132 conjugate(self.get_unchecked((i + 3, j)).clone())
133 * rhs.get_unchecked((i + 3, j)).clone()
134 };
135 acc4 += unsafe {
136 conjugate(self.get_unchecked((i + 4, j)).clone())
137 * rhs.get_unchecked((i + 4, j)).clone()
138 };
139 acc5 += unsafe {
140 conjugate(self.get_unchecked((i + 5, j)).clone())
141 * rhs.get_unchecked((i + 5, j)).clone()
142 };
143 acc6 += unsafe {
144 conjugate(self.get_unchecked((i + 6, j)).clone())
145 * rhs.get_unchecked((i + 6, j)).clone()
146 };
147 acc7 += unsafe {
148 conjugate(self.get_unchecked((i + 7, j)).clone())
149 * rhs.get_unchecked((i + 7, j)).clone()
150 };
151 i += 8;
152 }
153
154 res += acc0 + acc4;
155 res += acc1 + acc5;
156 res += acc2 + acc6;
157 res += acc3 + acc7;
158
159 for k in i..self.nrows() {
160 res += unsafe {
161 conjugate(self.get_unchecked((k, j)).clone())
162 * rhs.get_unchecked((k, j)).clone()
163 }
164 }
165 }
166
167 res
168 }
169
170 /// The dot product between two vectors or matrices (seen as vectors).
171 ///
172 /// This is equal to `self.transpose() * rhs`. For the sesquilinear complex dot product, use
173 /// `self.dotc(rhs)`.
174 ///
175 /// Note that this is **not** the matrix multiplication as in, e.g., numpy. For matrix
    /// multiplication, use one of `.gemm`, `.mul_to`, `.mul`, or the `*` operator.
177 ///
178 /// # Examples:
179 ///
180 /// ```
181 /// # use nalgebra::{Vector3, Matrix2x3};
182 /// let vec1 = Vector3::new(1.0, 2.0, 3.0);
183 /// let vec2 = Vector3::new(0.1, 0.2, 0.3);
184 /// assert_eq!(vec1.dot(&vec2), 1.4);
185 ///
186 /// let mat1 = Matrix2x3::new(1.0, 2.0, 3.0,
187 /// 4.0, 5.0, 6.0);
188 /// let mat2 = Matrix2x3::new(0.1, 0.2, 0.3,
189 /// 0.4, 0.5, 0.6);
190 /// assert_eq!(mat1.dot(&mat2), 9.1);
191 /// ```
192 ///
193 #[inline]
194 #[must_use]
195 pub fn dot<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> T
196 where
197 SB: RawStorage<T, R2, C2>,
198 ShapeConstraint: DimEq<R, R2> + DimEq<C, C2>,
199 {
200 self.dotx(rhs, |e| e)
201 }
202
203 /// The conjugate-linear dot product between two vectors or matrices (seen as vectors).
204 ///
205 /// This is equal to `self.adjoint() * rhs`.
206 /// For real vectors, this is identical to `self.dot(&rhs)`.
207 /// Note that this is **not** the matrix multiplication as in, e.g., numpy. For matrix
    /// multiplication, use one of `.gemm`, `.mul_to`, `.mul`, or the `*` operator.
209 ///
210 /// # Examples:
211 ///
212 /// ```
213 /// # use nalgebra::{Vector2, Complex};
214 /// let vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
215 /// let vec2 = Vector2::new(Complex::new(0.4, 0.3), Complex::new(0.2, 0.1));
216 /// assert_eq!(vec1.dotc(&vec2), Complex::new(2.0, -1.0));
217 ///
218 /// // Note that for complex vectors, we generally have:
    /// // vec1.dotc(&vec2) != vec1.dot(&vec2)
220 /// assert_ne!(vec1.dotc(&vec2), vec1.dot(&vec2));
221 /// ```
222 #[inline]
223 #[must_use]
224 pub fn dotc<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> T
225 where
226 T: SimdComplexField,
227 SB: RawStorage<T, R2, C2>,
228 ShapeConstraint: DimEq<R, R2> + DimEq<C, C2>,
229 {
230 self.dotx(rhs, T::simd_conjugate)
231 }
232
233 /// The dot product between the transpose of `self` and `rhs`.
234 ///
235 /// # Examples:
236 ///
237 /// ```
238 /// # use nalgebra::{Vector3, RowVector3, Matrix2x3, Matrix3x2};
239 /// let vec1 = Vector3::new(1.0, 2.0, 3.0);
240 /// let vec2 = RowVector3::new(0.1, 0.2, 0.3);
241 /// assert_eq!(vec1.tr_dot(&vec2), 1.4);
242 ///
243 /// let mat1 = Matrix2x3::new(1.0, 2.0, 3.0,
244 /// 4.0, 5.0, 6.0);
245 /// let mat2 = Matrix3x2::new(0.1, 0.4,
246 /// 0.2, 0.5,
247 /// 0.3, 0.6);
248 /// assert_eq!(mat1.tr_dot(&mat2), 9.1);
249 /// ```
250 #[inline]
251 #[must_use]
252 pub fn tr_dot<R2: Dim, C2: Dim, SB>(&self, rhs: &Matrix<T, R2, C2, SB>) -> T
253 where
254 SB: RawStorage<T, R2, C2>,
255 ShapeConstraint: DimEq<C, R2> + DimEq<R, C2>,
256 {
257 let (nrows, ncols) = self.shape();
258 assert_eq!(
259 (ncols, nrows),
260 rhs.shape(),
261 "Transposed dot product dimension mismatch."
262 );
263
264 let mut res = T::zero();
265
266 for j in 0..self.nrows() {
267 for i in 0..self.ncols() {
268 res += unsafe {
269 self.get_unchecked((j, i)).clone() * rhs.get_unchecked((i, j)).clone()
270 }
271 }
272 }
273
274 res
275 }
276}
277
278/// # BLAS functions
279impl<T, D: Dim, S> Vector<T, D, S>
280where
281 T: Scalar + Zero + ClosedAdd + ClosedMul,
282 S: StorageMut<T, D>,
283{
284 /// Computes `self = a * x * c + b * self`.
285 ///
286 /// If `b` is zero, `self` is never read from.
287 ///
288 /// # Examples:
289 ///
290 /// ```
291 /// # use nalgebra::Vector3;
292 /// let mut vec1 = Vector3::new(1.0, 2.0, 3.0);
293 /// let vec2 = Vector3::new(0.1, 0.2, 0.3);
294 /// vec1.axcpy(5.0, &vec2, 2.0, 5.0);
295 /// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0));
296 /// ```
297 #[inline]
298 #[allow(clippy::many_single_char_names)]
299 pub fn axcpy<D2: Dim, SB>(&mut self, a: T, x: &Vector<T, D2, SB>, c: T, b: T)
300 where
301 SB: Storage<T, D2>,
302 ShapeConstraint: DimEq<D, D2>,
303 {
304 unsafe { axcpy_uninit(Init, self, a, x, c, b) };
305 }
306
307 /// Computes `self = a * x + b * self`.
308 ///
309 /// If `b` is zero, `self` is never read from.
310 ///
311 /// # Examples:
312 ///
313 /// ```
314 /// # use nalgebra::Vector3;
315 /// let mut vec1 = Vector3::new(1.0, 2.0, 3.0);
316 /// let vec2 = Vector3::new(0.1, 0.2, 0.3);
317 /// vec1.axpy(10.0, &vec2, 5.0);
318 /// assert_eq!(vec1, Vector3::new(6.0, 12.0, 18.0));
319 /// ```
320 #[inline]
321 pub fn axpy<D2: Dim, SB>(&mut self, a: T, x: &Vector<T, D2, SB>, b: T)
322 where
323 T: One,
324 SB: Storage<T, D2>,
325 ShapeConstraint: DimEq<D, D2>,
326 {
327 assert_eq!(self.nrows(), x.nrows(), "Axpy: mismatched vector shapes.");
328 self.axcpy(a, x, T::one(), b)
329 }
330
331 /// Computes `self = alpha * a * x + beta * self`, where `a` is a matrix, `x` a vector, and
332 /// `alpha, beta` two scalars.
333 ///
334 /// If `beta` is zero, `self` is never read.
335 ///
336 /// # Examples:
337 ///
338 /// ```
339 /// # use nalgebra::{Matrix2, Vector2};
340 /// let mut vec1 = Vector2::new(1.0, 2.0);
341 /// let vec2 = Vector2::new(0.1, 0.2);
342 /// let mat = Matrix2::new(1.0, 2.0,
343 /// 3.0, 4.0);
344 /// vec1.gemv(10.0, &mat, &vec2, 5.0);
345 /// assert_eq!(vec1, Vector2::new(10.0, 21.0));
346 /// ```
347 #[inline]
348 pub fn gemv<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
349 &mut self,
350 alpha: T,
351 a: &Matrix<T, R2, C2, SB>,
352 x: &Vector<T, D3, SC>,
353 beta: T,
354 ) where
355 T: One,
356 SB: Storage<T, R2, C2>,
357 SC: Storage<T, D3>,
358 ShapeConstraint: DimEq<D, R2> + AreMultipliable<R2, C2, D3, U1>,
359 {
360 // Safety: this is safe because we are passing Status == Init.
361 unsafe { gemv_uninit(Init, self, alpha, a, x, beta) }
362 }
363
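    /// Shared implementation of `sygemv` and `hegemv`: computes `self = alpha * a * x + beta * self`
    /// reading only the lower-triangular part of `a`, with `dot` selecting either the plain
    /// (`dot`) or the conjugate-linear (`dotc`) dot product.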
364 #[inline(always)]
365 fn xxgemv<D2: Dim, D3: Dim, SB, SC>(
366 &mut self,
367 alpha: T,
368 a: &SquareMatrix<T, D2, SB>,
369 x: &Vector<T, D3, SC>,
370 beta: T,
371 dot: impl Fn(
372 &DVectorSlice<'_, T, SB::RStride, SB::CStride>,
373 &DVectorSlice<'_, T, SC::RStride, SC::CStride>,
374 ) -> T,
375 ) where
376 T: One,
377 SB: Storage<T, D2, D2>,
378 SC: Storage<T, D3>,
379 ShapeConstraint: DimEq<D, D2> + AreMultipliable<D2, D2, D3, U1>,
380 {
381 let dim1 = self.nrows();
382 let dim2 = a.nrows();
383 let dim3 = x.nrows();
384
        assert!(
            a.is_square(),
            "Symmetric/hermitian gemv: the input matrix must be square."
        );
        assert!(
            dim2 == dim3 && dim1 == dim2,
            "Symmetric/hermitian gemv: dimensions mismatch."
        );
393
394 if dim2 == 0 {
395 return;
396 }
397
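        // Traverse the lower triangle column by column: for each j, the entries a[(j.., j)]
        // provide both the row-j contribution (through the `dot` closure, using symmetry) and
        // the below-diagonal column-j contribution (through `axpy`). Column 0 is handled
        // separately so that `beta * self` is folded into the first `axpy`.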
398 // TODO: avoid bound checks.
399 let col2 = a.column(0);
400 let val = unsafe { x.vget_unchecked(0).clone() };
401 self.axpy(alpha.clone() * val, &col2, beta);
402 self[0] += alpha.clone() * dot(&a.slice_range(1.., 0), &x.rows_range(1..));
403
404 for j in 1..dim2 {
405 let col2 = a.column(j);
406 let dot = dot(&col2.rows_range(j..), &x.rows_range(j..));
407
408 let val;
409 unsafe {
410 val = x.vget_unchecked(j).clone();
411 *self.vget_unchecked_mut(j) += alpha.clone() * dot;
412 }
413 self.rows_range_mut(j + 1..).axpy(
414 alpha.clone() * val,
415 &col2.rows_range(j + 1..),
416 T::one(),
417 );
418 }
419 }
420
421 /// Computes `self = alpha * a * x + beta * self`, where `a` is a **symmetric** matrix, `x` a
422 /// vector, and `alpha, beta` two scalars.
423 ///
424 /// For hermitian matrices, use `.hegemv` instead.
    /// If `beta` is zero, `self` is never read. Only the lower-triangular part of `a`
    /// (including the diagonal) is read; its strictly upper-triangular part is never accessed.
427 ///
428 /// # Examples:
429 ///
430 /// ```
431 /// # use nalgebra::{Matrix2, Vector2};
432 /// let mat = Matrix2::new(1.0, 2.0,
433 /// 2.0, 4.0);
434 /// let mut vec1 = Vector2::new(1.0, 2.0);
435 /// let vec2 = Vector2::new(0.1, 0.2);
436 /// vec1.sygemv(10.0, &mat, &vec2, 5.0);
437 /// assert_eq!(vec1, Vector2::new(10.0, 20.0));
438 ///
439 ///
440 /// // The matrix upper-triangular elements can be garbage because it is never
441 /// // read by this method. Therefore, it is not necessary for the caller to
442 /// // fill the matrix struct upper-triangle.
443 /// let mat = Matrix2::new(1.0, 9999999.9999999,
444 /// 2.0, 4.0);
445 /// let mut vec1 = Vector2::new(1.0, 2.0);
446 /// vec1.sygemv(10.0, &mat, &vec2, 5.0);
447 /// assert_eq!(vec1, Vector2::new(10.0, 20.0));
448 /// ```
449 #[inline]
450 pub fn sygemv<D2: Dim, D3: Dim, SB, SC>(
451 &mut self,
452 alpha: T,
453 a: &SquareMatrix<T, D2, SB>,
454 x: &Vector<T, D3, SC>,
455 beta: T,
456 ) where
457 T: One,
458 SB: Storage<T, D2, D2>,
459 SC: Storage<T, D3>,
460 ShapeConstraint: DimEq<D, D2> + AreMultipliable<D2, D2, D3, U1>,
461 {
462 self.xxgemv(alpha, a, x, beta, |a, b| a.dot(b))
463 }
464
465 /// Computes `self = alpha * a * x + beta * self`, where `a` is an **hermitian** matrix, `x` a
466 /// vector, and `alpha, beta` two scalars.
467 ///
    /// If `beta` is zero, `self` is never read. Only the lower-triangular part of `a`
    /// (including the diagonal) is read; its strictly upper-triangular part is never accessed.
470 ///
471 /// # Examples:
472 ///
473 /// ```
474 /// # use nalgebra::{Matrix2, Vector2, Complex};
475 /// let mat = Matrix2::new(Complex::new(1.0, 0.0), Complex::new(2.0, -0.1),
476 /// Complex::new(2.0, 1.0), Complex::new(4.0, 0.0));
477 /// let mut vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
478 /// let vec2 = Vector2::new(Complex::new(0.1, 0.2), Complex::new(0.3, 0.4));
479 /// vec1.sygemv(Complex::new(10.0, 20.0), &mat, &vec2, Complex::new(5.0, 15.0));
480 /// assert_eq!(vec1, Vector2::new(Complex::new(-48.0, 44.0), Complex::new(-75.0, 110.0)));
481 ///
482 ///
483 /// // The matrix upper-triangular elements can be garbage because it is never
484 /// // read by this method. Therefore, it is not necessary for the caller to
485 /// // fill the matrix struct upper-triangle.
486 ///
487 /// let mat = Matrix2::new(Complex::new(1.0, 0.0), Complex::new(99999999.9, 999999999.9),
488 /// Complex::new(2.0, 1.0), Complex::new(4.0, 0.0));
489 /// let mut vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
490 /// let vec2 = Vector2::new(Complex::new(0.1, 0.2), Complex::new(0.3, 0.4));
491 /// vec1.sygemv(Complex::new(10.0, 20.0), &mat, &vec2, Complex::new(5.0, 15.0));
492 /// assert_eq!(vec1, Vector2::new(Complex::new(-48.0, 44.0), Complex::new(-75.0, 110.0)));
493 /// ```
494 #[inline]
495 pub fn hegemv<D2: Dim, D3: Dim, SB, SC>(
496 &mut self,
497 alpha: T,
498 a: &SquareMatrix<T, D2, SB>,
499 x: &Vector<T, D3, SC>,
500 beta: T,
501 ) where
502 T: SimdComplexField,
503 SB: Storage<T, D2, D2>,
504 SC: Storage<T, D3>,
505 ShapeConstraint: DimEq<D, D2> + AreMultipliable<D2, D2, D3, U1>,
506 {
507 self.xxgemv(alpha, a, x, beta, |a, b| a.dotc(b))
508 }
509
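    /// Shared implementation of `gemv_tr` and `gemv_ad`: each component `self[j]` becomes
    /// `alpha * dot(a.column(j), x) + beta * self[j]`, with `dot` selecting either the plain
    /// (`gemv_tr`) or the conjugate-linear (`gemv_ad`) dot product.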
510 #[inline(always)]
511 fn gemv_xx<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
512 &mut self,
513 alpha: T,
514 a: &Matrix<T, R2, C2, SB>,
515 x: &Vector<T, D3, SC>,
516 beta: T,
517 dot: impl Fn(&VectorSlice<'_, T, R2, SB::RStride, SB::CStride>, &Vector<T, D3, SC>) -> T,
518 ) where
519 T: One,
520 SB: Storage<T, R2, C2>,
521 SC: Storage<T, D3>,
522 ShapeConstraint: DimEq<D, C2> + AreMultipliable<C2, R2, D3, U1>,
523 {
524 let dim1 = self.nrows();
525 let (nrows2, ncols2) = a.shape();
526 let dim3 = x.nrows();
527
528 assert!(
529 nrows2 == dim3 && dim1 == ncols2,
530 "Gemv: dimensions mismatch."
531 );
532
533 if ncols2 == 0 {
534 return;
535 }
536
537 if beta.is_zero() {
538 for j in 0..ncols2 {
539 let val = unsafe { self.vget_unchecked_mut(j) };
540 *val = alpha.clone() * dot(&a.column(j), x)
541 }
542 } else {
543 for j in 0..ncols2 {
544 let val = unsafe { self.vget_unchecked_mut(j) };
545 *val = alpha.clone() * dot(&a.column(j), x) + beta.clone() * val.clone();
546 }
547 }
548 }
549
550 /// Computes `self = alpha * a.transpose() * x + beta * self`, where `a` is a matrix, `x` a vector, and
551 /// `alpha, beta` two scalars.
552 ///
553 /// If `beta` is zero, `self` is never read.
554 ///
555 /// # Examples:
556 ///
557 /// ```
558 /// # use nalgebra::{Matrix2, Vector2};
559 /// let mat = Matrix2::new(1.0, 3.0,
560 /// 2.0, 4.0);
561 /// let mut vec1 = Vector2::new(1.0, 2.0);
562 /// let vec2 = Vector2::new(0.1, 0.2);
563 /// let expected = mat.transpose() * vec2 * 10.0 + vec1 * 5.0;
564 ///
565 /// vec1.gemv_tr(10.0, &mat, &vec2, 5.0);
566 /// assert_eq!(vec1, expected);
567 /// ```
568 #[inline]
569 pub fn gemv_tr<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
570 &mut self,
571 alpha: T,
572 a: &Matrix<T, R2, C2, SB>,
573 x: &Vector<T, D3, SC>,
574 beta: T,
575 ) where
576 T: One,
577 SB: Storage<T, R2, C2>,
578 SC: Storage<T, D3>,
579 ShapeConstraint: DimEq<D, C2> + AreMultipliable<C2, R2, D3, U1>,
580 {
581 self.gemv_xx(alpha, a, x, beta, |a, b| a.dot(b))
582 }
583
584 /// Computes `self = alpha * a.adjoint() * x + beta * self`, where `a` is a matrix, `x` a vector, and
585 /// `alpha, beta` two scalars.
586 ///
587 /// For real matrices, this is the same as `.gemv_tr`.
588 /// If `beta` is zero, `self` is never read.
589 ///
590 /// # Examples:
591 ///
592 /// ```
593 /// # use nalgebra::{Matrix2, Vector2, Complex};
594 /// let mat = Matrix2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0),
595 /// Complex::new(5.0, 6.0), Complex::new(7.0, 8.0));
596 /// let mut vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
597 /// let vec2 = Vector2::new(Complex::new(0.1, 0.2), Complex::new(0.3, 0.4));
598 /// let expected = mat.adjoint() * vec2 * Complex::new(10.0, 20.0) + vec1 * Complex::new(5.0, 15.0);
599 ///
600 /// vec1.gemv_ad(Complex::new(10.0, 20.0), &mat, &vec2, Complex::new(5.0, 15.0));
601 /// assert_eq!(vec1, expected);
602 /// ```
603 #[inline]
604 pub fn gemv_ad<R2: Dim, C2: Dim, D3: Dim, SB, SC>(
605 &mut self,
606 alpha: T,
607 a: &Matrix<T, R2, C2, SB>,
608 x: &Vector<T, D3, SC>,
609 beta: T,
610 ) where
611 T: SimdComplexField,
612 SB: Storage<T, R2, C2>,
613 SC: Storage<T, D3>,
614 ShapeConstraint: DimEq<D, C2> + AreMultipliable<C2, R2, D3, U1>,
615 {
616 self.gemv_xx(alpha, a, x, beta, |a, b| a.dotc(b))
617 }
618}
619
620impl<T, R1: Dim, C1: Dim, S: StorageMut<T, R1, C1>> Matrix<T, R1, C1, S>
621where
622 T: Scalar + Zero + ClosedAdd + ClosedMul,
623{
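    /// Shared implementation of `ger` and `gerc`: performs the rank-1 update column by column,
    /// applying `conjugate` to the components of `y`.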
624 #[inline(always)]
625 fn gerx<D2: Dim, D3: Dim, SB, SC>(
626 &mut self,
627 alpha: T,
628 x: &Vector<T, D2, SB>,
629 y: &Vector<T, D3, SC>,
630 beta: T,
631 conjugate: impl Fn(T) -> T,
632 ) where
633 T: One,
634 SB: Storage<T, D2>,
635 SC: Storage<T, D3>,
636 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
637 {
638 let (nrows1, ncols1) = self.shape();
639 let dim2 = x.nrows();
640 let dim3 = y.nrows();
641
642 assert!(
643 nrows1 == dim2 && ncols1 == dim3,
644 "ger: dimensions mismatch."
645 );
646
647 for j in 0..ncols1 {
648 // TODO: avoid bound checks.
649 let val = unsafe { conjugate(y.vget_unchecked(j).clone()) };
650 self.column_mut(j)
651 .axpy(alpha.clone() * val, x, beta.clone());
652 }
653 }
654
655 /// Computes `self = alpha * x * y.transpose() + beta * self`.
656 ///
657 /// If `beta` is zero, `self` is never read.
658 ///
659 /// # Examples:
660 ///
661 /// ```
662 /// # use nalgebra::{Matrix2x3, Vector2, Vector3};
663 /// let mut mat = Matrix2x3::repeat(4.0);
664 /// let vec1 = Vector2::new(1.0, 2.0);
665 /// let vec2 = Vector3::new(0.1, 0.2, 0.3);
666 /// let expected = vec1 * vec2.transpose() * 10.0 + mat * 5.0;
667 ///
668 /// mat.ger(10.0, &vec1, &vec2, 5.0);
669 /// assert_eq!(mat, expected);
670 /// ```
671 #[inline]
672 pub fn ger<D2: Dim, D3: Dim, SB, SC>(
673 &mut self,
674 alpha: T,
675 x: &Vector<T, D2, SB>,
676 y: &Vector<T, D3, SC>,
677 beta: T,
678 ) where
679 T: One,
680 SB: Storage<T, D2>,
681 SC: Storage<T, D3>,
682 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
683 {
684 self.gerx(alpha, x, y, beta, |e| e)
685 }
686
687 /// Computes `self = alpha * x * y.adjoint() + beta * self`.
688 ///
689 /// If `beta` is zero, `self` is never read.
690 ///
691 /// # Examples:
692 ///
693 /// ```
694 /// # #[macro_use] extern crate approx;
695 /// # use nalgebra::{Matrix2x3, Vector2, Vector3, Complex};
696 /// let mut mat = Matrix2x3::repeat(Complex::new(4.0, 5.0));
697 /// let vec1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
698 /// let vec2 = Vector3::new(Complex::new(0.6, 0.5), Complex::new(0.4, 0.5), Complex::new(0.2, 0.1));
699 /// let expected = vec1 * vec2.adjoint() * Complex::new(10.0, 20.0) + mat * Complex::new(5.0, 15.0);
700 ///
701 /// mat.gerc(Complex::new(10.0, 20.0), &vec1, &vec2, Complex::new(5.0, 15.0));
702 /// assert_eq!(mat, expected);
703 /// ```
704 #[inline]
705 pub fn gerc<D2: Dim, D3: Dim, SB, SC>(
706 &mut self,
707 alpha: T,
708 x: &Vector<T, D2, SB>,
709 y: &Vector<T, D3, SC>,
710 beta: T,
711 ) where
712 T: SimdComplexField,
713 SB: Storage<T, D2>,
714 SC: Storage<T, D3>,
715 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
716 {
717 self.gerx(alpha, x, y, beta, SimdComplexField::simd_conjugate)
718 }
719
720 /// Computes `self = alpha * a * b + beta * self`, where `a, b, self` are matrices.
    /// `alpha` and `beta` are scalars.
722 ///
723 /// If `beta` is zero, `self` is never read.
724 ///
725 /// # Examples:
726 ///
727 /// ```
728 /// # #[macro_use] extern crate approx;
729 /// # use nalgebra::{Matrix2x3, Matrix3x4, Matrix2x4};
730 /// let mut mat1 = Matrix2x4::identity();
731 /// let mat2 = Matrix2x3::new(1.0, 2.0, 3.0,
732 /// 4.0, 5.0, 6.0);
733 /// let mat3 = Matrix3x4::new(0.1, 0.2, 0.3, 0.4,
734 /// 0.5, 0.6, 0.7, 0.8,
735 /// 0.9, 1.0, 1.1, 1.2);
736 /// let expected = mat2 * mat3 * 10.0 + mat1 * 5.0;
737 ///
738 /// mat1.gemm(10.0, &mat2, &mat3, 5.0);
739 /// assert_relative_eq!(mat1, expected);
740 /// ```
741 #[inline]
742 pub fn gemm<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
743 &mut self,
744 alpha: T,
745 a: &Matrix<T, R2, C2, SB>,
746 b: &Matrix<T, R3, C3, SC>,
747 beta: T,
748 ) where
749 T: One,
750 SB: Storage<T, R2, C2>,
751 SC: Storage<T, R3, C3>,
752 ShapeConstraint: SameNumberOfRows<R1, R2>
753 + SameNumberOfColumns<C1, C3>
754 + AreMultipliable<R2, C2, R3, C3>,
755 {
756 // SAFETY: this is valid because our matrices are initialized and
757 // we are using status = Init.
758 unsafe { gemm_uninit(Init, self, alpha, a, b, beta) }
759 }
760
761 /// Computes `self = alpha * a.transpose() * b + beta * self`, where `a, b, self` are matrices.
    /// `alpha` and `beta` are scalars.
763 ///
764 /// If `beta` is zero, `self` is never read.
765 ///
766 /// # Examples:
767 ///
768 /// ```
769 /// # #[macro_use] extern crate approx;
770 /// # use nalgebra::{Matrix3x2, Matrix3x4, Matrix2x4};
771 /// let mut mat1 = Matrix2x4::identity();
772 /// let mat2 = Matrix3x2::new(1.0, 4.0,
773 /// 2.0, 5.0,
774 /// 3.0, 6.0);
775 /// let mat3 = Matrix3x4::new(0.1, 0.2, 0.3, 0.4,
776 /// 0.5, 0.6, 0.7, 0.8,
777 /// 0.9, 1.0, 1.1, 1.2);
778 /// let expected = mat2.transpose() * mat3 * 10.0 + mat1 * 5.0;
779 ///
780 /// mat1.gemm_tr(10.0, &mat2, &mat3, 5.0);
781 /// assert_eq!(mat1, expected);
782 /// ```
783 #[inline]
784 pub fn gemm_tr<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
785 &mut self,
786 alpha: T,
787 a: &Matrix<T, R2, C2, SB>,
788 b: &Matrix<T, R3, C3, SC>,
789 beta: T,
790 ) where
791 T: One,
792 SB: Storage<T, R2, C2>,
793 SC: Storage<T, R3, C3>,
794 ShapeConstraint: SameNumberOfRows<R1, C2>
795 + SameNumberOfColumns<C1, C3>
796 + AreMultipliable<C2, R2, R3, C3>,
797 {
798 let (nrows1, ncols1) = self.shape();
799 let (nrows2, ncols2) = a.shape();
800 let (nrows3, ncols3) = b.shape();
801
802 assert_eq!(
803 nrows2, nrows3,
804 "gemm: dimensions mismatch for multiplication."
805 );
806 assert_eq!(
807 (nrows1, ncols1),
808 (ncols2, ncols3),
809 "gemm: dimensions mismatch for addition."
810 );
811
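        // Column j1 of `a.transpose() * b` is `a.transpose() * b.column(j1)`, so each column
        // of `self` is updated with a single `gemv_tr`.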
812 for j1 in 0..ncols1 {
813 // TODO: avoid bound checks.
814 self.column_mut(j1)
815 .gemv_tr(alpha.clone(), a, &b.column(j1), beta.clone());
816 }
817 }
818
819 /// Computes `self = alpha * a.adjoint() * b + beta * self`, where `a, b, self` are matrices.
    /// `alpha` and `beta` are scalars.
821 ///
822 /// If `beta` is zero, `self` is never read.
823 ///
824 /// # Examples:
825 ///
826 /// ```
827 /// # #[macro_use] extern crate approx;
828 /// # use nalgebra::{Matrix3x2, Matrix3x4, Matrix2x4, Complex};
829 /// let mut mat1 = Matrix2x4::identity();
830 /// let mat2 = Matrix3x2::new(Complex::new(1.0, 4.0), Complex::new(7.0, 8.0),
831 /// Complex::new(2.0, 5.0), Complex::new(9.0, 10.0),
832 /// Complex::new(3.0, 6.0), Complex::new(11.0, 12.0));
833 /// let mat3 = Matrix3x4::new(Complex::new(0.1, 1.3), Complex::new(0.2, 1.4), Complex::new(0.3, 1.5), Complex::new(0.4, 1.6),
834 /// Complex::new(0.5, 1.7), Complex::new(0.6, 1.8), Complex::new(0.7, 1.9), Complex::new(0.8, 2.0),
835 /// Complex::new(0.9, 2.1), Complex::new(1.0, 2.2), Complex::new(1.1, 2.3), Complex::new(1.2, 2.4));
836 /// let expected = mat2.adjoint() * mat3 * Complex::new(10.0, 20.0) + mat1 * Complex::new(5.0, 15.0);
837 ///
838 /// mat1.gemm_ad(Complex::new(10.0, 20.0), &mat2, &mat3, Complex::new(5.0, 15.0));
839 /// assert_eq!(mat1, expected);
840 /// ```
841 #[inline]
842 pub fn gemm_ad<R2: Dim, C2: Dim, R3: Dim, C3: Dim, SB, SC>(
843 &mut self,
844 alpha: T,
845 a: &Matrix<T, R2, C2, SB>,
846 b: &Matrix<T, R3, C3, SC>,
847 beta: T,
848 ) where
849 T: SimdComplexField,
850 SB: Storage<T, R2, C2>,
851 SC: Storage<T, R3, C3>,
852 ShapeConstraint: SameNumberOfRows<R1, C2>
853 + SameNumberOfColumns<C1, C3>
854 + AreMultipliable<C2, R2, R3, C3>,
855 {
856 let (nrows1, ncols1) = self.shape();
857 let (nrows2, ncols2) = a.shape();
858 let (nrows3, ncols3) = b.shape();
859
860 assert_eq!(
861 nrows2, nrows3,
862 "gemm: dimensions mismatch for multiplication."
863 );
864 assert_eq!(
865 (nrows1, ncols1),
866 (ncols2, ncols3),
867 "gemm: dimensions mismatch for addition."
868 );
869
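        // Same column-by-column scheme as `gemm_tr`, but using the adjoint through `gemv_ad`.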
870 for j1 in 0..ncols1 {
871 // TODO: avoid bound checks.
872 self.column_mut(j1)
873 .gemv_ad(alpha.clone(), a, &b.column(j1), beta.clone());
874 }
875 }
876}
877
878impl<T, R1: Dim, C1: Dim, S: StorageMut<T, R1, C1>> Matrix<T, R1, C1, S>
879where
880 T: Scalar + Zero + ClosedAdd + ClosedMul,
881{
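    /// Shared implementation of `syger` and `hegerc`: performs the rank-1 update on the
    /// lower-triangular part of `self` only, applying `conjugate` to the components of `y`.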
882 #[inline(always)]
883 fn xxgerx<D2: Dim, D3: Dim, SB, SC>(
884 &mut self,
885 alpha: T,
886 x: &Vector<T, D2, SB>,
887 y: &Vector<T, D3, SC>,
888 beta: T,
889 conjugate: impl Fn(T) -> T,
890 ) where
891 T: One,
892 SB: Storage<T, D2>,
893 SC: Storage<T, D3>,
894 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
895 {
896 let dim1 = self.nrows();
897 let dim2 = x.nrows();
898 let dim3 = y.nrows();
899
900 assert!(
901 self.is_square(),
902 "Symmetric ger: the input matrix must be square."
903 );
904 assert!(dim1 == dim2 && dim1 == dim3, "ger: dimensions mismatch.");
905
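        // For column j, only the entries at rows j.. are updated, i.e., exactly the
        // lower-triangular part (including the diagonal) of `self`.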
906 for j in 0..dim1 {
907 let val = unsafe { conjugate(y.vget_unchecked(j).clone()) };
908 let subdim = Dynamic::new(dim1 - j);
909 // TODO: avoid bound checks.
910 self.generic_slice_mut((j, j), (subdim, Const::<1>)).axpy(
911 alpha.clone() * val,
912 &x.rows_range(j..),
913 beta.clone(),
914 );
915 }
916 }
917
918 /// Computes `self = alpha * x * y.transpose() + beta * self`, where `self` is a **symmetric**
919 /// matrix.
920 ///
921 /// If `beta` is zero, `self` is never read. The result is symmetric. Only the lower-triangular
922 /// (including the diagonal) part of `self` is read/written.
923 ///
924 /// # Examples:
925 ///
926 /// ```
927 /// # use nalgebra::{Matrix2, Vector2};
928 /// let mut mat = Matrix2::identity();
929 /// let vec1 = Vector2::new(1.0, 2.0);
930 /// let vec2 = Vector2::new(0.1, 0.2);
931 /// let expected = vec1 * vec2.transpose() * 10.0 + mat * 5.0;
932 /// mat.m12 = 99999.99999; // This component is on the upper-triangular part and will not be read/written.
933 ///
934 /// mat.ger_symm(10.0, &vec1, &vec2, 5.0);
935 /// assert_eq!(mat.lower_triangle(), expected.lower_triangle());
    /// assert_eq!(mat.m12, 99999.99999); // This was untouched.
    /// ```
937 #[inline]
    #[deprecated(note = "This has been renamed `syger` to match the original BLAS terminology.")]
939 pub fn ger_symm<D2: Dim, D3: Dim, SB, SC>(
940 &mut self,
941 alpha: T,
942 x: &Vector<T, D2, SB>,
943 y: &Vector<T, D3, SC>,
944 beta: T,
945 ) where
946 T: One,
947 SB: Storage<T, D2>,
948 SC: Storage<T, D3>,
949 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
950 {
951 self.syger(alpha, x, y, beta)
952 }
953
954 /// Computes `self = alpha * x * y.transpose() + beta * self`, where `self` is a **symmetric**
955 /// matrix.
956 ///
957 /// For hermitian complex matrices, use `.hegerc` instead.
958 /// If `beta` is zero, `self` is never read. The result is symmetric. Only the lower-triangular
959 /// (including the diagonal) part of `self` is read/written.
960 ///
961 /// # Examples:
962 ///
963 /// ```
964 /// # use nalgebra::{Matrix2, Vector2};
965 /// let mut mat = Matrix2::identity();
966 /// let vec1 = Vector2::new(1.0, 2.0);
967 /// let vec2 = Vector2::new(0.1, 0.2);
968 /// let expected = vec1 * vec2.transpose() * 10.0 + mat * 5.0;
969 /// mat.m12 = 99999.99999; // This component is on the upper-triangular part and will not be read/written.
970 ///
971 /// mat.syger(10.0, &vec1, &vec2, 5.0);
972 /// assert_eq!(mat.lower_triangle(), expected.lower_triangle());
    /// assert_eq!(mat.m12, 99999.99999); // This was untouched.
    /// ```
974 #[inline]
975 pub fn syger<D2: Dim, D3: Dim, SB, SC>(
976 &mut self,
977 alpha: T,
978 x: &Vector<T, D2, SB>,
979 y: &Vector<T, D3, SC>,
980 beta: T,
981 ) where
982 T: One,
983 SB: Storage<T, D2>,
984 SC: Storage<T, D3>,
985 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
986 {
987 self.xxgerx(alpha, x, y, beta, |e| e)
988 }
989
990 /// Computes `self = alpha * x * y.adjoint() + beta * self`, where `self` is an **hermitian**
991 /// matrix.
992 ///
    /// If `beta` is zero, `self` is never read. The result is hermitian. Only the
    /// lower-triangular part of `self` (including the diagonal) is read/written.
995 ///
996 /// # Examples:
997 ///
998 /// ```
999 /// # use nalgebra::{Matrix2, Vector2, Complex};
1000 /// let mut mat = Matrix2::identity();
1001 /// let vec1 = Vector2::new(Complex::new(1.0, 3.0), Complex::new(2.0, 4.0));
1002 /// let vec2 = Vector2::new(Complex::new(0.2, 0.4), Complex::new(0.1, 0.3));
1003 /// let expected = vec1 * vec2.adjoint() * Complex::new(10.0, 20.0) + mat * Complex::new(5.0, 15.0);
1004 /// mat.m12 = Complex::new(99999.99999, 88888.88888); // This component is on the upper-triangular part and will not be read/written.
1005 ///
1006 /// mat.hegerc(Complex::new(10.0, 20.0), &vec1, &vec2, Complex::new(5.0, 15.0));
1007 /// assert_eq!(mat.lower_triangle(), expected.lower_triangle());
    /// assert_eq!(mat.m12, Complex::new(99999.99999, 88888.88888)); // This was untouched.
    /// ```
1009 #[inline]
1010 pub fn hegerc<D2: Dim, D3: Dim, SB, SC>(
1011 &mut self,
1012 alpha: T,
1013 x: &Vector<T, D2, SB>,
1014 y: &Vector<T, D3, SC>,
1015 beta: T,
1016 ) where
1017 T: SimdComplexField,
1018 SB: Storage<T, D2>,
1019 SC: Storage<T, D3>,
1020 ShapeConstraint: DimEq<R1, D2> + DimEq<C1, D3>,
1021 {
1022 self.xxgerx(alpha, x, y, beta, SimdComplexField::simd_conjugate)
1023 }
1024}
1025
1026impl<T, D1: Dim, S: StorageMut<T, D1, D1>> SquareMatrix<T, D1, S>
1027where
1028 T: Scalar + Zero + One + ClosedAdd + ClosedMul,
1029{
1030 /// Computes the quadratic form `self = alpha * lhs * mid * lhs.transpose() + beta * self`.
1031 ///
1032 /// This uses the provided workspace `work` to avoid allocations for intermediate results.
1033 ///
1034 /// # Examples:
1035 ///
1036 /// ```
1037 /// # #[macro_use] extern crate approx;
1038 /// # use nalgebra::{DMatrix, DVector};
    /// // Note that this would also work with statically-sized matrices.
    /// // We use DMatrix/DVector here because dynamically-sized matrices are the only case
    /// // where pre-allocating the workspace is actually useful: it avoids repeated dynamic
    /// // allocations when the same workspace is re-used for several computations.
1043 /// let mut mat = DMatrix::identity(2, 2);
1044 /// let lhs = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0,
1045 /// 4.0, 5.0, 6.0]);
1046 /// let mid = DMatrix::from_row_slice(3, 3, &[0.1, 0.2, 0.3,
1047 /// 0.5, 0.6, 0.7,
1048 /// 0.9, 1.0, 1.1]);
    /// // The random initial values show that the content of the workspace does not
    /// // matter, as it will be overwritten.
1051 /// let mut workspace = DVector::new_random(2);
1052 /// let expected = &lhs * &mid * lhs.transpose() * 10.0 + &mat * 5.0;
1053 ///
1054 /// mat.quadform_tr_with_workspace(&mut workspace, 10.0, &lhs, &mid, 5.0);
    /// assert_relative_eq!(mat, expected);
    /// ```
1056 pub fn quadform_tr_with_workspace<D2, S2, R3, C3, S3, D4, S4>(
1057 &mut self,
1058 work: &mut Vector<T, D2, S2>,
1059 alpha: T,
1060 lhs: &Matrix<T, R3, C3, S3>,
1061 mid: &SquareMatrix<T, D4, S4>,
1062 beta: T,
1063 ) where
1064 D2: Dim,
1065 R3: Dim,
1066 C3: Dim,
1067 D4: Dim,
1068 S2: StorageMut<T, D2>,
1069 S3: Storage<T, R3, C3>,
1070 S4: Storage<T, D4, D4>,
1071 ShapeConstraint: DimEq<D1, D2> + DimEq<D1, R3> + DimEq<D2, R3> + DimEq<C3, D4>,
1072 {
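        // `lhs * mid * lhs.transpose()` is the sum over j of
        // `(lhs * mid.column(j)) * lhs.column(j).transpose()`, so the result is accumulated as
        // rank-1 updates (`ger`) with each matrix-vector product stored in `work`. The first
        // iteration folds `beta * self` into the update.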
1073 work.gemv(T::one(), lhs, &mid.column(0), T::zero());
1074 self.ger(alpha.clone(), work, &lhs.column(0), beta);
1075
1076 for j in 1..mid.ncols() {
1077 work.gemv(T::one(), lhs, &mid.column(j), T::zero());
1078 self.ger(alpha.clone(), work, &lhs.column(j), T::one());
1079 }
1080 }
1081
1082 /// Computes the quadratic form `self = alpha * lhs * mid * lhs.transpose() + beta * self`.
1083 ///
1084 /// This allocates a workspace vector of dimension D1 for intermediate results.
1085 /// If `D1` is a type-level integer, then the allocation is performed on the stack.
1086 /// Use `.quadform_tr_with_workspace(...)` instead to avoid allocations.
1087 ///
1088 /// # Examples:
1089 ///
1090 /// ```
1091 /// # #[macro_use] extern crate approx;
1092 /// # use nalgebra::{Matrix2, Matrix3, Matrix2x3, Vector2};
1093 /// let mut mat = Matrix2::identity();
1094 /// let lhs = Matrix2x3::new(1.0, 2.0, 3.0,
1095 /// 4.0, 5.0, 6.0);
1096 /// let mid = Matrix3::new(0.1, 0.2, 0.3,
1097 /// 0.5, 0.6, 0.7,
1098 /// 0.9, 1.0, 1.1);
1099 /// let expected = lhs * mid * lhs.transpose() * 10.0 + mat * 5.0;
1100 ///
1101 /// mat.quadform_tr(10.0, &lhs, &mid, 5.0);
    /// assert_relative_eq!(mat, expected);
    /// ```
1103 pub fn quadform_tr<R3, C3, S3, D4, S4>(
1104 &mut self,
1105 alpha: T,
1106 lhs: &Matrix<T, R3, C3, S3>,
1107 mid: &SquareMatrix<T, D4, S4>,
1108 beta: T,
1109 ) where
1110 R3: Dim,
1111 C3: Dim,
1112 D4: Dim,
1113 S3: Storage<T, R3, C3>,
1114 S4: Storage<T, D4, D4>,
1115 ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>,
1116 DefaultAllocator: Allocator<T, D1>,
1117 {
1118 // TODO: would it be useful to avoid the zero-initialization of the workspace data?
1119 let mut work = Matrix::zeros_generic(self.shape_generic().0, Const::<1>);
1120 self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta)
1121 }
1122
1123 /// Computes the quadratic form `self = alpha * rhs.transpose() * mid * rhs + beta * self`.
1124 ///
1125 /// This uses the provided workspace `work` to avoid allocations for intermediate results.
1126 ///
1127 /// ```
1128 /// # #[macro_use] extern crate approx;
1129 /// # use nalgebra::{DMatrix, DVector};
    /// // Note that this would also work with statically-sized matrices.
    /// // We use DMatrix/DVector here because dynamically-sized matrices are the only case
    /// // where pre-allocating the workspace is actually useful: it avoids repeated dynamic
    /// // allocations when the same workspace is re-used for several computations.
1134 /// let mut mat = DMatrix::identity(2, 2);
1135 /// let rhs = DMatrix::from_row_slice(3, 2, &[1.0, 2.0,
1136 /// 3.0, 4.0,
1137 /// 5.0, 6.0]);
1138 /// let mid = DMatrix::from_row_slice(3, 3, &[0.1, 0.2, 0.3,
1139 /// 0.5, 0.6, 0.7,
1140 /// 0.9, 1.0, 1.1]);
    /// // The random initial values show that the content of the workspace does not
    /// // matter, as it will be overwritten.
1143 /// let mut workspace = DVector::new_random(3);
1144 /// let expected = rhs.transpose() * &mid * &rhs * 10.0 + &mat * 5.0;
1145 ///
1146 /// mat.quadform_with_workspace(&mut workspace, 10.0, &mid, &rhs, 5.0);
    /// assert_relative_eq!(mat, expected);
    /// ```
1148 pub fn quadform_with_workspace<D2, S2, D3, S3, R4, C4, S4>(
1149 &mut self,
1150 work: &mut Vector<T, D2, S2>,
1151 alpha: T,
1152 mid: &SquareMatrix<T, D3, S3>,
1153 rhs: &Matrix<T, R4, C4, S4>,
1154 beta: T,
1155 ) where
1156 D2: Dim,
1157 D3: Dim,
1158 R4: Dim,
1159 C4: Dim,
1160 S2: StorageMut<T, D2>,
1161 S3: Storage<T, D3, D3>,
1162 S4: Storage<T, R4, C4>,
1163 ShapeConstraint:
1164 DimEq<D3, R4> + DimEq<D1, C4> + DimEq<D2, D3> + AreMultipliable<C4, R4, D2, U1>,
1165 {
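        // Column j of `rhs.transpose() * mid * rhs` is `rhs.transpose() * (mid * rhs.column(j))`:
        // a `gemv` into `work` followed by a `gemv_tr` into `self.column_mut(j)`.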
1166 work.gemv(T::one(), mid, &rhs.column(0), T::zero());
1167 self.column_mut(0)
1168 .gemv_tr(alpha.clone(), rhs, work, beta.clone());
1169
1170 for j in 1..rhs.ncols() {
1171 work.gemv(T::one(), mid, &rhs.column(j), T::zero());
1172 self.column_mut(j)
1173 .gemv_tr(alpha.clone(), rhs, work, beta.clone());
1174 }
1175 }
1176
1177 /// Computes the quadratic form `self = alpha * rhs.transpose() * mid * rhs + beta * self`.
1178 ///
1179 /// This allocates a workspace vector of dimension D2 for intermediate results.
1180 /// If `D2` is a type-level integer, then the allocation is performed on the stack.
1181 /// Use `.quadform_with_workspace(...)` instead to avoid allocations.
1182 ///
1183 /// ```
1184 /// # #[macro_use] extern crate approx;
1185 /// # use nalgebra::{Matrix2, Matrix3x2, Matrix3};
1186 /// let mut mat = Matrix2::identity();
1187 /// let rhs = Matrix3x2::new(1.0, 2.0,
1188 /// 3.0, 4.0,
1189 /// 5.0, 6.0);
1190 /// let mid = Matrix3::new(0.1, 0.2, 0.3,
1191 /// 0.5, 0.6, 0.7,
1192 /// 0.9, 1.0, 1.1);
1193 /// let expected = rhs.transpose() * mid * rhs * 10.0 + mat * 5.0;
1194 ///
1195 /// mat.quadform(10.0, &mid, &rhs, 5.0);
    /// assert_relative_eq!(mat, expected);
    /// ```
1197 pub fn quadform<D2, S2, R3, C3, S3>(
1198 &mut self,
1199 alpha: T,
1200 mid: &SquareMatrix<T, D2, S2>,
1201 rhs: &Matrix<T, R3, C3, S3>,
1202 beta: T,
1203 ) where
1204 D2: Dim,
1205 R3: Dim,
1206 C3: Dim,
1207 S2: Storage<T, D2, D2>,
1208 S3: Storage<T, R3, C3>,
1209 ShapeConstraint: DimEq<D2, R3> + DimEq<D1, C3> + AreMultipliable<C3, R3, D2, U1>,
1210 DefaultAllocator: Allocator<T, D2>,
1211 {
1212 // TODO: would it be useful to avoid the zero-initialization of the workspace data?
1213 let mut work = Vector::zeros_generic(mid.shape_generic().0, Const::<1>);
1214 self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta)
1215 }
1216}