1#![allow(clippy::suspicious_operation_groupings)]
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use approx::AbsDiffEq;
6use num_complex::Complex as NumComplex;
7use simba::scalar::{ComplexField, RealField};
8use std::cmp;
9
10use crate::allocator::Allocator;
11use crate::base::dimension::{Const, Dim, DimDiff, DimSub, Dynamic, U1, U2};
12use crate::base::storage::Storage;
13use crate::base::{DefaultAllocator, OMatrix, OVector, SquareMatrix, Unit, Vector2, Vector3};
14
15use crate::geometry::Reflection;
16use crate::linalg::givens::GivensRotation;
17use crate::linalg::householder;
18use crate::linalg::Hessenberg;
19use crate::{Matrix, UninitVector};
20use std::mem::MaybeUninit;
21
#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
/// Schur decomposition of a square matrix.
///
/// Stores the accumulated similarity transform `q` and the quasi-upper-triangular factor `t`
/// computed by `Schur::try_new`. The diagonal of `t` consists of 1x1 blocks and of 2x2 blocks
/// that could not be split into two real eigenvalues. `Schur::unpack` returns `(q, t)`.
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(serialize = "DefaultAllocator: Allocator<T, D, D>,
         OMatrix<T, D, D>: Serialize"))
)]
#[cfg_attr(
    feature = "serde-serialize-no-std",
    serde(bound(deserialize = "DefaultAllocator: Allocator<T, D, D>,
         OMatrix<T, D, D>: Deserialize<'de>"))
)]
#[derive(Clone, Debug)]
pub struct Schur<T: ComplexField, D: Dim>
where
    DefaultAllocator: Allocator<T, D, D>,
{
    // Accumulated transform: the Hessenberg `Q` multiplied by every reflection / rotation
    // applied to `t` during the QR iteration.
    q: OMatrix<T, D, D>,
    // Quasi-upper-triangular Schur factor.
    t: OMatrix<T, D, D>,
}
44
45impl<T: ComplexField, D: Dim> Copy for Schur<T, D>
46where
47 DefaultAllocator: Allocator<T, D, D>,
48 OMatrix<T, D, D>: Copy,
49{
50}
51
52impl<T: ComplexField, D: Dim> Schur<T, D>
53where
54 D: DimSub<U1>, DefaultAllocator: Allocator<T, D, DimDiff<D, U1>>
56 + Allocator<T, DimDiff<D, U1>>
57 + Allocator<T, D, D>
58 + Allocator<T, D>,
59{
60 pub fn new(m: OMatrix<T, D, D>) -> Self {
62 Self::try_new(m, T::RealField::default_epsilon(), 0).unwrap()
63 }
64
65 pub fn try_new(m: OMatrix<T, D, D>, eps: T::RealField, max_niter: usize) -> Option<Self> {
77 let mut work = Matrix::zeros_generic(m.shape_generic().0, Const::<1>);
78
79 Self::do_decompose(m, &mut work, eps, max_niter, true)
80 .map(|(q, t)| Schur { q: q.unwrap(), t })
81 }
82
83 fn do_decompose(
84 mut m: OMatrix<T, D, D>,
85 work: &mut OVector<T, D>,
86 eps: T::RealField,
87 max_niter: usize,
88 compute_q: bool,
89 ) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)> {
90 assert!(
91 m.is_square(),
92 "Unable to compute the eigenvectors and eigenvalues of a non-square matrix."
93 );
94
95 let dim = m.shape_generic().0;
96
97 if dim.value() == 0 {
99 let vecs = Some(OMatrix::from_element_generic(dim, dim, T::zero()));
100 let vals = OMatrix::from_element_generic(dim, dim, T::zero());
101 return Some((vecs, vals));
102 } else if dim.value() == 1 {
103 if compute_q {
104 let q = OMatrix::from_element_generic(dim, dim, T::one());
105 return Some((Some(q), m));
106 } else {
107 return Some((None, m));
108 }
109 } else if dim.value() == 2 {
110 return decompose_2x2(m, compute_q);
111 }
112
113 let amax_m = m.camax();
114 m.unscale_mut(amax_m.clone());
115
116 let hess = Hessenberg::new_with_workspace(m, work);
117 let mut q;
118 let mut t;
119
120 if compute_q {
121 let (vecs, vals) = hess.unpack();
124 q = Some(vecs);
125 t = vals;
126 } else {
127 q = None;
128 t = hess.unpack_h()
129 }
130
131 let mut niter = 0;
133 let (mut start, mut end) = Self::delimit_subproblem(&mut t, eps.clone(), dim.value() - 1);
134
135 while end != start {
136 let subdim = end - start + 1;
137
138 if subdim > 2 {
139 let m = end - 1;
140 let n = end;
141
142 let h11 = t[(start, start)].clone();
143 let h12 = t[(start, start + 1)].clone();
144 let h21 = t[(start + 1, start)].clone();
145 let h22 = t[(start + 1, start + 1)].clone();
146 let h32 = t[(start + 2, start + 1)].clone();
147
148 let hnn = t[(n, n)].clone();
149 let hmm = t[(m, m)].clone();
150 let hnm = t[(n, m)].clone();
151 let hmn = t[(m, n)].clone();
152
153 let tra = hnn.clone() + hmm.clone();
154 let det = hnn * hmm - hnm * hmn;
155
156 let mut axis = Vector3::new(
157 h11.clone() * h11.clone() + h12 * h21.clone() - tra.clone() * h11.clone() + det,
158 h21.clone() * (h11 + h22 - tra),
159 h21 * h32,
160 );
161
162 for k in start..n - 1 {
163 let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
164
165 if not_zero {
166 if k > start {
167 t[(k, k - 1)] = norm;
168 t[(k + 1, k - 1)] = T::zero();
169 t[(k + 2, k - 1)] = T::zero();
170 }
171
172 let refl = Reflection::new(Unit::new_unchecked(axis.clone()), T::zero());
173
174 {
175 let krows = cmp::min(k + 4, end + 1);
176 let mut work = work.rows_mut(0, krows);
177 refl.reflect(&mut t.generic_slice_mut(
178 (k, k),
179 (Const::<3>, Dynamic::new(dim.value() - k)),
180 ));
181 refl.reflect_rows(
182 &mut t.generic_slice_mut((0, k), (Dynamic::new(krows), Const::<3>)),
183 &mut work,
184 );
185 }
186
187 if let Some(ref mut q) = q {
188 refl.reflect_rows(
189 &mut q.generic_slice_mut((0, k), (dim, Const::<3>)),
190 work,
191 );
192 }
193 }
194
195 axis.x = t[(k + 1, k)].clone();
196 axis.y = t[(k + 2, k)].clone();
197
198 if k < n - 2 {
199 axis.z = t[(k + 3, k)].clone();
200 }
201 }
202
203 let mut axis = Vector2::new(axis.x.clone(), axis.y.clone());
204 let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
205
206 if not_zero {
207 let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
208
209 t[(m, m - 1)] = norm;
210 t[(n, m - 1)] = T::zero();
211
212 {
213 let mut work = work.rows_mut(0, end + 1);
214 refl.reflect(&mut t.generic_slice_mut(
215 (m, m),
216 (Const::<2>, Dynamic::new(dim.value() - m)),
217 ));
218 refl.reflect_rows(
219 &mut t.generic_slice_mut((0, m), (Dynamic::new(end + 1), Const::<2>)),
220 &mut work,
221 );
222 }
223
224 if let Some(ref mut q) = q {
225 refl.reflect_rows(
226 &mut q.generic_slice_mut((0, m), (dim, Const::<2>)),
227 work,
228 );
229 }
230 }
231 } else {
232 if let Some(rot) = compute_2x2_basis(&t.fixed_slice::<2, 2>(start, start)) {
234 let inv_rot = rot.inverse();
235 inv_rot.rotate(&mut t.generic_slice_mut(
236 (start, start),
237 (Const::<2>, Dynamic::new(dim.value() - start)),
238 ));
239 rot.rotate_rows(
240 &mut t.generic_slice_mut((0, start), (Dynamic::new(end + 1), Const::<2>)),
241 );
242 t[(end, start)] = T::zero();
243
244 if let Some(ref mut q) = q {
245 rot.rotate_rows(&mut q.generic_slice_mut((0, start), (dim, Const::<2>)));
246 }
247 }
248
249 if end > 2 {
251 end -= 2;
252 } else {
253 break;
254 }
255 }
256
257 let sub = Self::delimit_subproblem(&mut t, eps.clone(), end);
258
259 start = sub.0;
260 end = sub.1;
261
262 niter += 1;
263 if niter == max_niter {
264 return None;
265 }
266 }
267
268 t.scale_mut(amax_m);
269
270 Some((q, t))
271 }
272
273 fn do_eigenvalues(t: &OMatrix<T, D, D>, out: &mut OVector<T, D>) -> bool {
275 let dim = t.nrows();
276 let mut m = 0;
277
278 while m < dim - 1 {
279 let n = m + 1;
280
281 if t[(n, m)].is_zero() {
282 out[m] = t[(m, m)].clone();
283 m += 1;
284 } else {
285 return false;
287 }
288 }
289
290 if m == dim - 1 {
291 out[m] = t[(m, m)].clone();
292 }
293
294 true
295 }
296
297 fn do_complex_eigenvalues(t: &OMatrix<T, D, D>, out: &mut UninitVector<NumComplex<T>, D>)
299 where
300 T: RealField,
301 DefaultAllocator: Allocator<NumComplex<T>, D>,
302 {
303 let dim = t.nrows();
304 let mut m = 0;
305
306 while m < dim - 1 {
307 let n = m + 1;
308
309 if t[(n, m)].is_zero() {
310 out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
311 m += 1;
312 } else {
313 let hmm = t[(m, m)].clone();
315 let hnm = t[(n, m)].clone();
316 let hmn = t[(m, n)].clone();
317 let hnn = t[(n, n)].clone();
318
319 let val = (hmm.clone() - hnn.clone()) * crate::convert(0.5);
321 let discr = hnm * hmn + val.clone() * val;
322
323 let sqrt_discr = NumComplex::new(T::zero(), (-discr).sqrt());
326
327 let half_tra = (hnn + hmm) * crate::convert(0.5);
328 out[m] = MaybeUninit::new(
329 NumComplex::new(half_tra.clone(), T::zero()) + sqrt_discr.clone(),
330 );
331 out[m + 1] =
332 MaybeUninit::new(NumComplex::new(half_tra, T::zero()) - sqrt_discr.clone());
333
334 m += 2;
335 }
336 }
337
338 if m == dim - 1 {
339 out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
340 }
341 }
342
343 fn delimit_subproblem(t: &mut OMatrix<T, D, D>, eps: T::RealField, end: usize) -> (usize, usize)
344 where
345 D: DimSub<U1>,
346 DefaultAllocator: Allocator<T, DimDiff<D, U1>>,
347 {
348 let mut n = end;
349
350 while n > 0 {
351 let m = n - 1;
352
353 if t[(n, m)].clone().norm1()
354 <= eps.clone() * (t[(n, n)].clone().norm1() + t[(m, m)].clone().norm1())
355 {
356 t[(n, m)] = T::zero();
357 } else {
358 break;
359 }
360
361 n -= 1;
362 }
363
364 if n == 0 {
365 return (0, 0);
366 }
367
368 let mut new_start = n - 1;
369 while new_start > 0 {
370 let m = new_start - 1;
371
372 let off_diag = t[(new_start, m)].clone();
373 if off_diag.is_zero()
374 || off_diag.norm1()
375 <= eps.clone()
376 * (t[(new_start, new_start)].clone().norm1() + t[(m, m)].clone().norm1())
377 {
378 t[(new_start, m)] = T::zero();
379 break;
380 }
381
382 new_start -= 1;
383 }
384
385 (new_start, n)
386 }
387
388 pub fn unpack(self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
391 (self.q, self.t)
392 }
393
394 #[must_use]
398 pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
399 let mut out = Matrix::zeros_generic(self.t.shape_generic().0, Const::<1>);
400 if Self::do_eigenvalues(&self.t, &mut out) {
401 Some(out)
402 } else {
403 None
404 }
405 }
406
407 #[must_use]
409 pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
410 where
411 T: RealField,
412 DefaultAllocator: Allocator<NumComplex<T>, D>,
413 {
414 let mut out = Matrix::uninit(self.t.shape_generic().0, Const::<1>);
415 Self::do_complex_eigenvalues(&self.t, &mut out);
416 unsafe { out.assume_init() }
418 }
419}
420
421fn decompose_2x2<T: ComplexField, D: Dim>(
422 mut m: OMatrix<T, D, D>,
423 compute_q: bool,
424) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)>
425where
426 DefaultAllocator: Allocator<T, D, D>,
427{
428 let dim = m.shape_generic().0;
429 let mut q = None;
430 match compute_2x2_basis(&m.fixed_slice::<2, 2>(0, 0)) {
431 Some(rot) => {
432 let mut m = m.fixed_slice_mut::<2, 2>(0, 0);
433 let inv_rot = rot.inverse();
434 inv_rot.rotate(&mut m);
435 rot.rotate_rows(&mut m);
436 m[(1, 0)] = T::zero();
437
438 if compute_q {
439 let c = T::from_real(rot.c());
442 q = Some(OMatrix::from_column_slice_generic(
443 dim,
444 dim,
445 &[c.clone(), rot.s(), -rot.s().conjugate(), c],
446 ));
447 }
448 }
449 None => {
450 if compute_q {
451 q = Some(OMatrix::identity_generic(dim, dim));
452 }
453 }
454 };
455
456 Some((q, m))
457}
458
459fn compute_2x2_eigvals<T: ComplexField, S: Storage<T, U2, U2>>(
460 m: &SquareMatrix<T, U2, S>,
461) -> Option<(T, T)> {
462 let h00 = m[(0, 0)].clone();
464 let h10 = m[(1, 0)].clone();
465 let h01 = m[(0, 1)].clone();
466 let h11 = m[(1, 1)].clone();
467
468 let val = (h00.clone() - h11.clone()) * crate::convert(0.5);
472 let discr = h10 * h01 + val.clone() * val;
473
474 discr.try_sqrt().map(|sqrt_discr| {
475 let half_tra = (h00 + h11) * crate::convert(0.5);
476 (half_tra.clone() + sqrt_discr.clone(), half_tra - sqrt_discr)
477 })
478}
479
480fn compute_2x2_basis<T: ComplexField, S: Storage<T, U2, U2>>(
486 m: &SquareMatrix<T, U2, S>,
487) -> Option<GivensRotation<T>> {
488 let h10 = m[(1, 0)].clone();
489
490 if h10.is_zero() {
491 return None;
492 }
493
494 if let Some((eigval1, eigval2)) = compute_2x2_eigvals(m) {
495 let x1 = eigval1 - m[(1, 1)].clone();
496 let x2 = eigval2 - m[(1, 1)].clone();
497
498 if x1.clone().norm1() > x2.clone().norm1() {
502 Some(GivensRotation::new(x1, h10).0)
503 } else {
504 Some(GivensRotation::new(x2, h10).0)
505 }
506 } else {
507 None
508 }
509}
510
511impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S>
512where
513 D: DimSub<U1>, DefaultAllocator: Allocator<T, D, DimDiff<D, U1>>
515 + Allocator<T, DimDiff<D, U1>>
516 + Allocator<T, D, D>
517 + Allocator<T, D>,
518{
519 #[must_use]
521 pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
522 assert!(
523 self.is_square(),
524 "Unable to compute eigenvalues of a non-square matrix."
525 );
526
527 let mut work = Matrix::zeros_generic(self.shape_generic().0, Const::<1>);
528
529 if self.nrows() == 2 {
531 let me = self.fixed_slice::<2, 2>(0, 0);
534 return match compute_2x2_eigvals(&me) {
535 Some((a, b)) => {
536 work[0] = a;
537 work[1] = b;
538 Some(work)
539 }
540 None => None,
541 };
542 }
543
544 let schur = Schur::do_decompose(
546 self.clone_owned(),
547 &mut work,
548 T::RealField::default_epsilon(),
549 0,
550 false,
551 )
552 .unwrap();
553
554 if Schur::do_eigenvalues(&schur.1, &mut work) {
555 Some(work)
556 } else {
557 None
558 }
559 }
560
561 #[must_use]
563 pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
564 where
566 T: RealField,
567 DefaultAllocator: Allocator<NumComplex<T>, D>,
568 {
569 let dim = self.shape_generic().0;
570 let mut work = Matrix::zeros_generic(dim, Const::<1>);
571
572 let schur = Schur::do_decompose(
573 self.clone_owned(),
574 &mut work,
575 T::default_epsilon(),
576 0,
577 false,
578 )
579 .unwrap();
580 let mut eig = Matrix::uninit(dim, Const::<1>);
581 Schur::do_complex_eigenvalues(&schur.1, &mut eig);
582 unsafe { eig.assume_init() }
584 }
585}