// nalgebra/linalg/symmetric_eigen.rs

1#[cfg(feature = "serde-serialize-no-std")]
2use serde::{Deserialize, Serialize};
3
4use approx::AbsDiffEq;
5use num::Zero;
6
7use crate::allocator::Allocator;
8use crate::base::{DefaultAllocator, Matrix2, OMatrix, OVector, SquareMatrix, Vector2};
9use crate::dimension::{Dim, DimDiff, DimSub, U1};
10use crate::storage::Storage;
11use simba::scalar::ComplexField;
12
13use crate::linalg::givens::GivensRotation;
14use crate::linalg::SymmetricTridiagonal;
15
16/// Eigendecomposition of a symmetric matrix.
17#[cfg_attr(feature = "serde-serialize-no-std", derive(Serialize, Deserialize))]
18#[cfg_attr(
19    feature = "serde-serialize-no-std",
20    serde(bound(serialize = "DefaultAllocator: Allocator<T, D, D> +
21                           Allocator<T::RealField, D>,
22         OVector<T::RealField, D>: Serialize,
23         OMatrix<T, D, D>: Serialize"))
24)]
25#[cfg_attr(
26    feature = "serde-serialize-no-std",
27    serde(bound(deserialize = "DefaultAllocator: Allocator<T, D, D> +
28                           Allocator<T::RealField, D>,
29         OVector<T::RealField, D>: Deserialize<'de>,
30         OMatrix<T, D, D>: Deserialize<'de>"))
31)]
32#[derive(Clone, Debug)]
33pub struct SymmetricEigen<T: ComplexField, D: Dim>
34where
35    DefaultAllocator: Allocator<T, D, D> + Allocator<T::RealField, D>,
36{
37    /// The eigenvectors of the decomposed matrix.
38    pub eigenvectors: OMatrix<T, D, D>,
39
40    /// The unsorted eigenvalues of the decomposed matrix.
41    pub eigenvalues: OVector<T::RealField, D>,
42}
43
44impl<T: ComplexField, D: Dim> Copy for SymmetricEigen<T, D>
45where
46    DefaultAllocator: Allocator<T, D, D> + Allocator<T::RealField, D>,
47    OMatrix<T, D, D>: Copy,
48    OVector<T::RealField, D>: Copy,
49{
50}
51
52impl<T: ComplexField, D: Dim> SymmetricEigen<T, D>
53where
54    DefaultAllocator: Allocator<T, D, D> + Allocator<T::RealField, D>,
55{
56    /// Computes the eigendecomposition of the given symmetric matrix.
57    ///
58    /// Only the lower-triangular parts (including its diagonal) of `m` is read.
59    pub fn new(m: OMatrix<T, D, D>) -> Self
60    where
61        D: DimSub<U1>,
62        DefaultAllocator: Allocator<T, DimDiff<D, U1>> + Allocator<T::RealField, DimDiff<D, U1>>,
63    {
64        Self::try_new(m, T::RealField::default_epsilon(), 0).unwrap()
65    }
66
67    /// Computes the eigendecomposition of the given symmetric matrix with user-specified
68    /// convergence parameters.
69    ///
70    /// Only the lower-triangular part (including its diagonal) of `m` is read.
71    ///
72    /// # Arguments
73    ///
74    /// * `eps`       − tolerance used to determine when a value converged to 0.
75    /// * `max_niter` − maximum total number of iterations performed by the algorithm. If this
76    /// number of iteration is exceeded, `None` is returned. If `niter == 0`, then the algorithm
77    /// continues indefinitely until convergence.
78    pub fn try_new(m: OMatrix<T, D, D>, eps: T::RealField, max_niter: usize) -> Option<Self>
79    where
80        D: DimSub<U1>,
81        DefaultAllocator: Allocator<T, DimDiff<D, U1>> + Allocator<T::RealField, DimDiff<D, U1>>,
82    {
83        Self::do_decompose(m, true, eps, max_niter).map(|(vals, vecs)| SymmetricEigen {
84            eigenvectors: vecs.unwrap(),
85            eigenvalues: vals,
86        })
87    }
88
89    fn do_decompose(
90        mut matrix: OMatrix<T, D, D>,
91        eigenvectors: bool,
92        eps: T::RealField,
93        max_niter: usize,
94    ) -> Option<(OVector<T::RealField, D>, Option<OMatrix<T, D, D>>)>
95    where
96        D: DimSub<U1>,
97        DefaultAllocator: Allocator<T, DimDiff<D, U1>> + Allocator<T::RealField, DimDiff<D, U1>>,
98    {
99        assert!(
100            matrix.is_square(),
101            "Unable to compute the eigendecomposition of a non-square matrix."
102        );
103        let dim = matrix.nrows();
104        let m_amax = matrix.camax();
105
106        if !m_amax.is_zero() {
107            matrix.unscale_mut(m_amax.clone());
108        }
109
110        let (mut q_mat, mut diag, mut off_diag);
111
112        if eigenvectors {
113            let res = SymmetricTridiagonal::new(matrix).unpack();
114            q_mat = Some(res.0);
115            diag = res.1;
116            off_diag = res.2;
117        } else {
118            let res = SymmetricTridiagonal::new(matrix).unpack_tridiagonal();
119            q_mat = None;
120            diag = res.0;
121            off_diag = res.1;
122        }
123
124        if dim == 1 {
125            diag.scale_mut(m_amax);
126            return Some((diag, q_mat));
127        }
128
129        let mut niter = 0;
130        let (mut start, mut end) =
131            Self::delimit_subproblem(&diag, &mut off_diag, dim - 1, eps.clone());
132
133        while end != start {
134            let subdim = end - start + 1;
135
136            #[allow(clippy::comparison_chain)]
137            if subdim > 2 {
138                let m = end - 1;
139                let n = end;
140
141                let mut vec = Vector2::new(
142                    diag[start].clone()
143                        - wilkinson_shift(
144                            diag[m].clone().clone(),
145                            diag[n].clone(),
146                            off_diag[m].clone().clone(),
147                        ),
148                    off_diag[start].clone(),
149                );
150
151                for i in start..n {
152                    let j = i + 1;
153
154                    if let Some((rot, norm)) = GivensRotation::cancel_y(&vec) {
155                        if i > start {
156                            // Not the first iteration.
157                            off_diag[i - 1] = norm;
158                        }
159
160                        let mii = diag[i].clone();
161                        let mjj = diag[j].clone();
162                        let mij = off_diag[i].clone();
163
164                        let cc = rot.c() * rot.c();
165                        let ss = rot.s() * rot.s();
166                        let cs = rot.c() * rot.s();
167
168                        let b = cs.clone() * crate::convert(2.0) * mij.clone();
169
170                        diag[i] = (cc.clone() * mii.clone() + ss.clone() * mjj.clone()) - b.clone();
171                        diag[j] = (ss.clone() * mii.clone() + cc.clone() * mjj.clone()) + b;
172                        off_diag[i] = cs * (mii - mjj) + mij * (cc - ss);
173
174                        if i != n - 1 {
175                            vec.x = off_diag[i].clone();
176                            vec.y = -rot.s() * off_diag[i + 1].clone();
177                            off_diag[i + 1] *= rot.c();
178                        }
179
180                        if let Some(ref mut q) = q_mat {
181                            let rot = GivensRotation::new_unchecked(rot.c(), T::from_real(rot.s()));
182                            rot.inverse().rotate_rows(&mut q.fixed_columns_mut::<2>(i));
183                        }
184                    } else {
185                        break;
186                    }
187                }
188
189                if off_diag[m].clone().norm1()
190                    <= eps.clone() * (diag[m].clone().norm1() + diag[n].clone().norm1())
191                {
192                    end -= 1;
193                }
194            } else if subdim == 2 {
195                let m = Matrix2::new(
196                    diag[start].clone(),
197                    off_diag[start].clone().conjugate(),
198                    off_diag[start].clone(),
199                    diag[start + 1].clone(),
200                );
201                let eigvals = m.eigenvalues().unwrap();
202                let basis = Vector2::new(
203                    eigvals.x.clone() - diag[start + 1].clone(),
204                    off_diag[start].clone(),
205                );
206
207                diag[start] = eigvals[0].clone();
208                diag[start + 1] = eigvals[1].clone();
209
210                if let Some(ref mut q) = q_mat {
211                    if let Some((rot, _)) =
212                        GivensRotation::try_new(basis.x.clone(), basis.y.clone(), eps.clone())
213                    {
214                        let rot = GivensRotation::new_unchecked(rot.c(), T::from_real(rot.s()));
215                        rot.rotate_rows(&mut q.fixed_columns_mut::<2>(start));
216                    }
217                }
218
219                end -= 1;
220            }
221
222            // Re-delimit the subproblem in case some decoupling occurred.
223            let sub = Self::delimit_subproblem(&diag, &mut off_diag, end, eps.clone());
224
225            start = sub.0;
226            end = sub.1;
227
228            niter += 1;
229            if niter == max_niter {
230                return None;
231            }
232        }
233
234        diag.scale_mut(m_amax);
235
236        Some((diag, q_mat))
237    }
238
239    fn delimit_subproblem(
240        diag: &OVector<T::RealField, D>,
241        off_diag: &mut OVector<T::RealField, DimDiff<D, U1>>,
242        end: usize,
243        eps: T::RealField,
244    ) -> (usize, usize)
245    where
246        D: DimSub<U1>,
247        DefaultAllocator: Allocator<T::RealField, DimDiff<D, U1>>,
248    {
249        let mut n = end;
250
251        while n > 0 {
252            let m = n - 1;
253
254            if off_diag[m].clone().norm1()
255                > eps.clone() * (diag[n].clone().norm1() + diag[m].clone().norm1())
256            {
257                break;
258            }
259
260            n -= 1;
261        }
262
263        if n == 0 {
264            return (0, 0);
265        }
266
267        let mut new_start = n - 1;
268        while new_start > 0 {
269            let m = new_start - 1;
270
271            if off_diag[m].clone().is_zero()
272                || off_diag[m].clone().norm1()
273                    <= eps.clone() * (diag[new_start].clone().norm1() + diag[m].clone().norm1())
274            {
275                off_diag[m] = T::RealField::zero();
276                break;
277            }
278
279            new_start -= 1;
280        }
281
282        (new_start, n)
283    }
284
285    /// Rebuild the original matrix.
286    ///
287    /// This is useful if some of the eigenvalues have been manually modified.
288    #[must_use]
289    pub fn recompose(&self) -> OMatrix<T, D, D> {
290        let mut u_t = self.eigenvectors.clone();
291        for i in 0..self.eigenvalues.len() {
292            let val = self.eigenvalues[i].clone();
293            u_t.column_mut(i).scale_mut(val);
294        }
295        u_t.adjoint_mut();
296        &self.eigenvectors * u_t
297    }
298}
299
300/// Computes the wilkinson shift, i.e., the 2x2 symmetric matrix eigenvalue to its tailing
301/// component `tnn`.
302///
303/// The inputs are interpreted as the 2x2 matrix:
304///     tmm  tmn
305///     tmn  tnn
306pub fn wilkinson_shift<T: ComplexField>(tmm: T, tnn: T, tmn: T) -> T {
307    let sq_tmn = tmn.clone() * tmn;
308    if !sq_tmn.is_zero() {
309        // We have the guarantee that the denominator won't be zero.
310        let d = (tmm - tnn.clone()) * crate::convert(0.5);
311        tnn - sq_tmn.clone() / (d.clone() + d.clone().signum() * (d.clone() * d + sq_tmn).sqrt())
312    } else {
313        tnn
314    }
315}
316
317/*
318 *
319 * Computations of eigenvalues for symmetric matrices.
320 *
321 */
322impl<T: ComplexField, D: DimSub<U1>, S: Storage<T, D, D>> SquareMatrix<T, D, S>
323where
324    DefaultAllocator: Allocator<T, D, D>
325        + Allocator<T, DimDiff<D, U1>>
326        + Allocator<T::RealField, D>
327        + Allocator<T::RealField, DimDiff<D, U1>>,
328{
329    /// Computes the eigenvalues of this symmetric matrix.
330    ///
331    /// Only the lower-triangular part of the matrix is read.
332    #[must_use]
333    pub fn symmetric_eigenvalues(&self) -> OVector<T::RealField, D> {
334        SymmetricEigen::do_decompose(
335            self.clone_owned(),
336            false,
337            T::RealField::default_epsilon(),
338            0,
339        )
340        .unwrap()
341        .0
342    }
343}
344
#[cfg(test)]
mod test {
    use crate::base::Matrix2;

    /// Reference Wilkinson shift: the eigenvalue of `m` nearest to its bottom-right entry,
    /// obtained from the general 2x2 eigenvalue routine.
    fn expected_shift(m: Matrix2<f64>) -> f64 {
        let eigenvalues = m.eigenvalues().unwrap();
        let (first, second) = (eigenvalues.x, eigenvalues.y);

        if (first - m.m22).abs() < (second - m.m22).abs() {
            first
        } else {
            second
        }
    }

    /// Evaluates `wilkinson_shift` on the entries of `m`, using `m12` as the off-diagonal.
    fn computed_shift(m: Matrix2<f64>) -> f64 {
        super::wilkinson_shift(m.m11, m.m22, m.m12)
    }

    /// Asserts that `wilkinson_shift` agrees with the reference shift for `m`.
    fn check_shift(m: Matrix2<f64>) {
        assert!(relative_eq!(expected_shift(m), computed_shift(m)));
    }

    #[cfg(feature = "rand")]
    #[test]
    fn wilkinson_shift_random() {
        for _ in 0..1000 {
            let a = Matrix2::<f64>::new_random();
            let m = a * a.transpose();

            assert!(relative_eq!(
                expected_shift(m),
                computed_shift(m),
                epsilon = 1.0e-7
            ));
        }
    }

    #[test]
    fn wilkinson_shift_zero() {
        check_shift(Matrix2::new(0.0, 0.0, 0.0, 0.0));
    }

    #[test]
    fn wilkinson_shift_zero_diagonal() {
        check_shift(Matrix2::new(0.0, 42.0, 42.0, 0.0));
    }

    #[test]
    fn wilkinson_shift_zero_off_diagonal() {
        check_shift(Matrix2::new(42.0, 0.0, 0.0, 64.0));
    }

    #[test]
    fn wilkinson_shift_zero_trace() {
        check_shift(Matrix2::new(42.0, 20.0, 20.0, -42.0));
    }

    #[test]
    fn wilkinson_shift_zero_diag_diff_and_zero_off_diagonal() {
        check_shift(Matrix2::new(42.0, 0.0, 0.0, 42.0));
    }

    #[test]
    fn wilkinson_shift_zero_det() {
        check_shift(Matrix2::new(2.0, 4.0, 4.0, 8.0));
    }
}