nalgebra/linalg/
schur.rs

1#![allow(clippy::suspicious_operation_groupings)]
2#[cfg(feature = "serde-serialize-no-std")]
3use serde::{Deserialize, Serialize};
4
5use approx::AbsDiffEq;
6use num_complex::Complex as NumComplex;
7use simba::scalar::{ComplexField, RealField};
8use std::cmp;
9
10use crate::allocator::Allocator;
11use crate::base::dimension::{Const, Dim, DimDiff, DimSub, Dynamic, U1, U2};
12use crate::base::storage::Storage;
13use crate::base::{DefaultAllocator, OMatrix, OVector, SquareMatrix, Unit, Vector2, Vector3};
14
15use crate::geometry::Reflection;
16use crate::linalg::givens::GivensRotation;
17use crate::linalg::householder;
18use crate::linalg::Hessenberg;
19use crate::{Matrix, UninitVector};
20use std::mem::MaybeUninit;
21
22/// Schur decomposition of a square matrix.
23///
24/// If this is a real matrix, this will be a `RealField` Schur decomposition.
25#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
26#[cfg_attr(
27    feature = "serde-serialize-no-std",
28    serde(bound(serialize = "DefaultAllocator: Allocator<T, D, D>,
29         OMatrix<T, D, D>: Serialize"))
30)]
31#[cfg_attr(
32    feature = "serde-serialize-no-std",
33    serde(bound(deserialize = "DefaultAllocator: Allocator<T, D, D>,
34         OMatrix<T, D, D>: Deserialize<'de>"))
35)]
36#[derive(Clone, Debug)]
37pub struct Schur<T: ComplexField, D: Dim>
38where
39    DefaultAllocator: Allocator<T, D, D>,
40{
41    q: OMatrix<T, D, D>,
42    t: OMatrix<T, D, D>,
43}
44
45impl<T: ComplexField, D: Dim> Copy for Schur<T, D>
46where
47    DefaultAllocator: Allocator<T, D, D>,
48    OMatrix<T, D, D>: Copy,
49{
50}
51
52impl<T: ComplexField, D: Dim> Schur<T, D>
53where
54    D: DimSub<U1>, // For Hessenberg.
55    DefaultAllocator: Allocator<T, D, DimDiff<D, U1>>
56        + Allocator<T, DimDiff<D, U1>>
57        + Allocator<T, D, D>
58        + Allocator<T, D>,
59{
60    /// Computes the Schur decomposition of a square matrix.
61    pub fn new(m: OMatrix<T, D, D>) -> Self {
62        Self::try_new(m, T::RealField::default_epsilon(), 0).unwrap()
63    }
64
65    /// Attempts to compute the Schur decomposition of a square matrix.
66    ///
67    /// If only eigenvalues are needed, it is more efficient to call the matrix method
68    /// `.eigenvalues()` instead.
69    ///
70    /// # Arguments
71    ///
72    /// * `eps`       − tolerance used to determine when a value converged to 0.
73    /// * `max_niter` − maximum total number of iterations performed by the algorithm. If this
74    /// number of iteration is exceeded, `None` is returned. If `niter == 0`, then the algorithm
75    /// continues indefinitely until convergence.
76    pub fn try_new(m: OMatrix<T, D, D>, eps: T::RealField, max_niter: usize) -> Option<Self> {
77        let mut work = Matrix::zeros_generic(m.shape_generic().0, Const::<1>);
78
79        Self::do_decompose(m, &mut work, eps, max_niter, true)
80            .map(|(q, t)| Schur { q: q.unwrap(), t })
81    }
82
83    fn do_decompose(
84        mut m: OMatrix<T, D, D>,
85        work: &mut OVector<T, D>,
86        eps: T::RealField,
87        max_niter: usize,
88        compute_q: bool,
89    ) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)> {
90        assert!(
91            m.is_square(),
92            "Unable to compute the eigenvectors and eigenvalues of a non-square matrix."
93        );
94
95        let dim = m.shape_generic().0;
96
97        // Specialization would make this easier.
98        if dim.value() == 0 {
99            let vecs = Some(OMatrix::from_element_generic(dim, dim, T::zero()));
100            let vals = OMatrix::from_element_generic(dim, dim, T::zero());
101            return Some((vecs, vals));
102        } else if dim.value() == 1 {
103            if compute_q {
104                let q = OMatrix::from_element_generic(dim, dim, T::one());
105                return Some((Some(q), m));
106            } else {
107                return Some((None, m));
108            }
109        } else if dim.value() == 2 {
110            return decompose_2x2(m, compute_q);
111        }
112
113        let amax_m = m.camax();
114        m.unscale_mut(amax_m.clone());
115
116        let hess = Hessenberg::new_with_workspace(m, work);
117        let mut q;
118        let mut t;
119
120        if compute_q {
121            // TODO: could we work without unpacking? Using only the internal representation of
122            // hessenberg decomposition.
123            let (vecs, vals) = hess.unpack();
124            q = Some(vecs);
125            t = vals;
126        } else {
127            q = None;
128            t = hess.unpack_h()
129        }
130
131        // Implicit double-shift QR method.
132        let mut niter = 0;
133        let (mut start, mut end) = Self::delimit_subproblem(&mut t, eps.clone(), dim.value() - 1);
134
135        while end != start {
136            let subdim = end - start + 1;
137
138            if subdim > 2 {
139                let m = end - 1;
140                let n = end;
141
142                let h11 = t[(start, start)].clone();
143                let h12 = t[(start, start + 1)].clone();
144                let h21 = t[(start + 1, start)].clone();
145                let h22 = t[(start + 1, start + 1)].clone();
146                let h32 = t[(start + 2, start + 1)].clone();
147
148                let hnn = t[(n, n)].clone();
149                let hmm = t[(m, m)].clone();
150                let hnm = t[(n, m)].clone();
151                let hmn = t[(m, n)].clone();
152
153                let tra = hnn.clone() + hmm.clone();
154                let det = hnn * hmm - hnm * hmn;
155
156                let mut axis = Vector3::new(
157                    h11.clone() * h11.clone() + h12 * h21.clone() - tra.clone() * h11.clone() + det,
158                    h21.clone() * (h11 + h22 - tra),
159                    h21 * h32,
160                );
161
162                for k in start..n - 1 {
163                    let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
164
165                    if not_zero {
166                        if k > start {
167                            t[(k, k - 1)] = norm;
168                            t[(k + 1, k - 1)] = T::zero();
169                            t[(k + 2, k - 1)] = T::zero();
170                        }
171
172                        let refl = Reflection::new(Unit::new_unchecked(axis.clone()), T::zero());
173
174                        {
175                            let krows = cmp::min(k + 4, end + 1);
176                            let mut work = work.rows_mut(0, krows);
177                            refl.reflect(&mut t.generic_slice_mut(
178                                (k, k),
179                                (Const::<3>, Dynamic::new(dim.value() - k)),
180                            ));
181                            refl.reflect_rows(
182                                &mut t.generic_slice_mut((0, k), (Dynamic::new(krows), Const::<3>)),
183                                &mut work,
184                            );
185                        }
186
187                        if let Some(ref mut q) = q {
188                            refl.reflect_rows(
189                                &mut q.generic_slice_mut((0, k), (dim, Const::<3>)),
190                                work,
191                            );
192                        }
193                    }
194
195                    axis.x = t[(k + 1, k)].clone();
196                    axis.y = t[(k + 2, k)].clone();
197
198                    if k < n - 2 {
199                        axis.z = t[(k + 3, k)].clone();
200                    }
201                }
202
203                let mut axis = Vector2::new(axis.x.clone(), axis.y.clone());
204                let (norm, not_zero) = householder::reflection_axis_mut(&mut axis);
205
206                if not_zero {
207                    let refl = Reflection::new(Unit::new_unchecked(axis), T::zero());
208
209                    t[(m, m - 1)] = norm;
210                    t[(n, m - 1)] = T::zero();
211
212                    {
213                        let mut work = work.rows_mut(0, end + 1);
214                        refl.reflect(&mut t.generic_slice_mut(
215                            (m, m),
216                            (Const::<2>, Dynamic::new(dim.value() - m)),
217                        ));
218                        refl.reflect_rows(
219                            &mut t.generic_slice_mut((0, m), (Dynamic::new(end + 1), Const::<2>)),
220                            &mut work,
221                        );
222                    }
223
224                    if let Some(ref mut q) = q {
225                        refl.reflect_rows(
226                            &mut q.generic_slice_mut((0, m), (dim, Const::<2>)),
227                            work,
228                        );
229                    }
230                }
231            } else {
232                // Decouple the 2x2 block if it has real eigenvalues.
233                if let Some(rot) = compute_2x2_basis(&t.fixed_slice::<2, 2>(start, start)) {
234                    let inv_rot = rot.inverse();
235                    inv_rot.rotate(&mut t.generic_slice_mut(
236                        (start, start),
237                        (Const::<2>, Dynamic::new(dim.value() - start)),
238                    ));
239                    rot.rotate_rows(
240                        &mut t.generic_slice_mut((0, start), (Dynamic::new(end + 1), Const::<2>)),
241                    );
242                    t[(end, start)] = T::zero();
243
244                    if let Some(ref mut q) = q {
245                        rot.rotate_rows(&mut q.generic_slice_mut((0, start), (dim, Const::<2>)));
246                    }
247                }
248
249                // Check if we reached the beginning of the matrix.
250                if end > 2 {
251                    end -= 2;
252                } else {
253                    break;
254                }
255            }
256
257            let sub = Self::delimit_subproblem(&mut t, eps.clone(), end);
258
259            start = sub.0;
260            end = sub.1;
261
262            niter += 1;
263            if niter == max_niter {
264                return None;
265            }
266        }
267
268        t.scale_mut(amax_m);
269
270        Some((q, t))
271    }
272
273    /// Computes the eigenvalues of the decomposed matrix.
274    fn do_eigenvalues(t: &OMatrix<T, D, D>, out: &mut OVector<T, D>) -> bool {
275        let dim = t.nrows();
276        let mut m = 0;
277
278        while m < dim - 1 {
279            let n = m + 1;
280
281            if t[(n, m)].is_zero() {
282                out[m] = t[(m, m)].clone();
283                m += 1;
284            } else {
285                // Complex eigenvalue.
286                return false;
287            }
288        }
289
290        if m == dim - 1 {
291            out[m] = t[(m, m)].clone();
292        }
293
294        true
295    }
296
297    /// Computes the complex eigenvalues of the decomposed matrix.
298    fn do_complex_eigenvalues(t: &OMatrix<T, D, D>, out: &mut UninitVector<NumComplex<T>, D>)
299    where
300        T: RealField,
301        DefaultAllocator: Allocator<NumComplex<T>, D>,
302    {
303        let dim = t.nrows();
304        let mut m = 0;
305
306        while m < dim - 1 {
307            let n = m + 1;
308
309            if t[(n, m)].is_zero() {
310                out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
311                m += 1;
312            } else {
313                // Solve the 2x2 eigenvalue subproblem.
314                let hmm = t[(m, m)].clone();
315                let hnm = t[(n, m)].clone();
316                let hmn = t[(m, n)].clone();
317                let hnn = t[(n, n)].clone();
318
319                // NOTE: use the same algorithm as in compute_2x2_eigvals.
320                let val = (hmm.clone() - hnn.clone()) * crate::convert(0.5);
321                let discr = hnm * hmn + val.clone() * val;
322
323                // All 2x2 blocks have negative discriminant because we already decoupled those
324                // with positive eigenvalues.
325                let sqrt_discr = NumComplex::new(T::zero(), (-discr).sqrt());
326
327                let half_tra = (hnn + hmm) * crate::convert(0.5);
328                out[m] = MaybeUninit::new(
329                    NumComplex::new(half_tra.clone(), T::zero()) + sqrt_discr.clone(),
330                );
331                out[m + 1] =
332                    MaybeUninit::new(NumComplex::new(half_tra, T::zero()) - sqrt_discr.clone());
333
334                m += 2;
335            }
336        }
337
338        if m == dim - 1 {
339            out[m] = MaybeUninit::new(NumComplex::new(t[(m, m)].clone(), T::zero()));
340        }
341    }
342
343    fn delimit_subproblem(t: &mut OMatrix<T, D, D>, eps: T::RealField, end: usize) -> (usize, usize)
344    where
345        D: DimSub<U1>,
346        DefaultAllocator: Allocator<T, DimDiff<D, U1>>,
347    {
348        let mut n = end;
349
350        while n > 0 {
351            let m = n - 1;
352
353            if t[(n, m)].clone().norm1()
354                <= eps.clone() * (t[(n, n)].clone().norm1() + t[(m, m)].clone().norm1())
355            {
356                t[(n, m)] = T::zero();
357            } else {
358                break;
359            }
360
361            n -= 1;
362        }
363
364        if n == 0 {
365            return (0, 0);
366        }
367
368        let mut new_start = n - 1;
369        while new_start > 0 {
370            let m = new_start - 1;
371
372            let off_diag = t[(new_start, m)].clone();
373            if off_diag.is_zero()
374                || off_diag.norm1()
375                    <= eps.clone()
376                        * (t[(new_start, new_start)].clone().norm1() + t[(m, m)].clone().norm1())
377            {
378                t[(new_start, m)] = T::zero();
379                break;
380            }
381
382            new_start -= 1;
383        }
384
385        (new_start, n)
386    }
387
388    /// Retrieves the unitary matrix `Q` and the upper-quasitriangular matrix `T` such that the
389    /// decomposed matrix equals `Q * T * Q.transpose()`.
390    pub fn unpack(self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
391        (self.q, self.t)
392    }
393
394    /// Computes the real eigenvalues of the decomposed matrix.
395    ///
396    /// Return `None` if some eigenvalues are complex.
397    #[must_use]
398    pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
399        let mut out = Matrix::zeros_generic(self.t.shape_generic().0, Const::<1>);
400        if Self::do_eigenvalues(&self.t, &mut out) {
401            Some(out)
402        } else {
403            None
404        }
405    }
406
407    /// Computes the complex eigenvalues of the decomposed matrix.
408    #[must_use]
409    pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
410    where
411        T: RealField,
412        DefaultAllocator: Allocator<NumComplex<T>, D>,
413    {
414        let mut out = Matrix::uninit(self.t.shape_generic().0, Const::<1>);
415        Self::do_complex_eigenvalues(&self.t, &mut out);
416        // Safety: out has been fully initialized by do_complex_eigenvalues.
417        unsafe { out.assume_init() }
418    }
419}
420
421fn decompose_2x2<T: ComplexField, D: Dim>(
422    mut m: OMatrix<T, D, D>,
423    compute_q: bool,
424) -> Option<(Option<OMatrix<T, D, D>>, OMatrix<T, D, D>)>
425where
426    DefaultAllocator: Allocator<T, D, D>,
427{
428    let dim = m.shape_generic().0;
429    let mut q = None;
430    match compute_2x2_basis(&m.fixed_slice::<2, 2>(0, 0)) {
431        Some(rot) => {
432            let mut m = m.fixed_slice_mut::<2, 2>(0, 0);
433            let inv_rot = rot.inverse();
434            inv_rot.rotate(&mut m);
435            rot.rotate_rows(&mut m);
436            m[(1, 0)] = T::zero();
437
438            if compute_q {
439                // XXX: we have to build the matrix manually because
440                // rot.to_rotation_matrix().unwrap() causes an ICE.
441                let c = T::from_real(rot.c());
442                q = Some(OMatrix::from_column_slice_generic(
443                    dim,
444                    dim,
445                    &[c.clone(), rot.s(), -rot.s().conjugate(), c],
446                ));
447            }
448        }
449        None => {
450            if compute_q {
451                q = Some(OMatrix::identity_generic(dim, dim));
452            }
453        }
454    };
455
456    Some((q, m))
457}
458
459fn compute_2x2_eigvals<T: ComplexField, S: Storage<T, U2, U2>>(
460    m: &SquareMatrix<T, U2, S>,
461) -> Option<(T, T)> {
462    // Solve the 2x2 eigenvalue subproblem.
463    let h00 = m[(0, 0)].clone();
464    let h10 = m[(1, 0)].clone();
465    let h01 = m[(0, 1)].clone();
466    let h11 = m[(1, 1)].clone();
467
468    // NOTE: this discriminant computation is more stable than the
469    // one based on the trace and determinant: 0.25 * tra * tra - det
470    // because it ensures positiveness for symmetric matrices.
471    let val = (h00.clone() - h11.clone()) * crate::convert(0.5);
472    let discr = h10 * h01 + val.clone() * val;
473
474    discr.try_sqrt().map(|sqrt_discr| {
475        let half_tra = (h00 + h11) * crate::convert(0.5);
476        (half_tra.clone() + sqrt_discr.clone(), half_tra - sqrt_discr)
477    })
478}
479
480// Computes the 2x2 transformation that upper-triangulates a 2x2 matrix with real eigenvalues.
481/// Computes the singular vectors for a 2x2 matrix.
482///
483/// Returns `None` if the matrix has complex eigenvalues, or is upper-triangular. In both case,
484/// the basis is the identity.
485fn compute_2x2_basis<T: ComplexField, S: Storage<T, U2, U2>>(
486    m: &SquareMatrix<T, U2, S>,
487) -> Option<GivensRotation<T>> {
488    let h10 = m[(1, 0)].clone();
489
490    if h10.is_zero() {
491        return None;
492    }
493
494    if let Some((eigval1, eigval2)) = compute_2x2_eigvals(m) {
495        let x1 = eigval1 - m[(1, 1)].clone();
496        let x2 = eigval2 - m[(1, 1)].clone();
497
498        // NOTE: Choose the one that yields a larger x component.
499        // This is necessary for numerical stability of the normalization of the complex
500        // number.
501        if x1.clone().norm1() > x2.clone().norm1() {
502            Some(GivensRotation::new(x1, h10).0)
503        } else {
504            Some(GivensRotation::new(x2, h10).0)
505        }
506    } else {
507        None
508    }
509}
510
511impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S>
512where
513    D: DimSub<U1>, // For Hessenberg.
514    DefaultAllocator: Allocator<T, D, DimDiff<D, U1>>
515        + Allocator<T, DimDiff<D, U1>>
516        + Allocator<T, D, D>
517        + Allocator<T, D>,
518{
519    /// Computes the eigenvalues of this matrix.
520    #[must_use]
521    pub fn eigenvalues(&self) -> Option<OVector<T, D>> {
522        assert!(
523            self.is_square(),
524            "Unable to compute eigenvalues of a non-square matrix."
525        );
526
527        let mut work = Matrix::zeros_generic(self.shape_generic().0, Const::<1>);
528
529        // Special case for 2x2 matrices.
530        if self.nrows() == 2 {
531            // TODO: can we avoid this slicing
532            // (which is needed here just to transform D to U2)?
533            let me = self.fixed_slice::<2, 2>(0, 0);
534            return match compute_2x2_eigvals(&me) {
535                Some((a, b)) => {
536                    work[0] = a;
537                    work[1] = b;
538                    Some(work)
539                }
540                None => None,
541            };
542        }
543
544        // TODO: add balancing?
545        let schur = Schur::do_decompose(
546            self.clone_owned(),
547            &mut work,
548            T::RealField::default_epsilon(),
549            0,
550            false,
551        )
552        .unwrap();
553
554        if Schur::do_eigenvalues(&schur.1, &mut work) {
555            Some(work)
556        } else {
557            None
558        }
559    }
560
561    /// Computes the eigenvalues of this matrix.
562    #[must_use]
563    pub fn complex_eigenvalues(&self) -> OVector<NumComplex<T>, D>
564    // TODO: add balancing?
565    where
566        T: RealField,
567        DefaultAllocator: Allocator<NumComplex<T>, D>,
568    {
569        let dim = self.shape_generic().0;
570        let mut work = Matrix::zeros_generic(dim, Const::<1>);
571
572        let schur = Schur::do_decompose(
573            self.clone_owned(),
574            &mut work,
575            T::default_epsilon(),
576            0,
577            false,
578        )
579        .unwrap();
580        let mut eig = Matrix::uninit(dim, Const::<1>);
581        Schur::do_complex_eigenvalues(&schur.1, &mut eig);
582        // Safety: eig has been fully initialized by do_complex_eigenvalues.
583        unsafe { eig.assume_init() }
584    }
585}