nalgebra/linalg/
exp.rs

//! This module provides the matrix exponential (`exp`) function for square matrices.
//!
3use crate::{
4    base::{
5        allocator::Allocator,
6        dimension::{Const, Dim, DimMin, DimMinimum},
7        DefaultAllocator,
8    },
9    convert, try_convert, ComplexField, OMatrix, RealField,
10};
11
12use crate::num::Zero;
13
/// Precomputed factorials `n!` for every `n` in `0..=34`.
///
/// The table stops at `34!` because `35!` does not fit into a `u128`. The
/// entries are generated during constant evaluation, so an arithmetic overflow
/// while building the table would be a compile error rather than a silent
/// wraparound.
// TODO: find a better place for this array?
const FACTORIAL: [u128; 35] = {
    let mut table = [1u128; 35];
    let mut n = 1;
    while n < table.len() {
        // n! = (n - 1)! * n, starting from 0! = 1.
        table[n] = table[n - 1] * n as u128;
        n += 1;
    }
    table
};
54
55// https://github.com/scipy/scipy/blob/c1372d8aa90a73d8a52f135529293ff4edb98fc8/scipy/sparse/linalg/matfuncs.py
56struct ExpmPadeHelper<T, D>
57where
58    T: ComplexField,
59    D: DimMin<D>,
60    DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
61{
62    use_exact_norm: bool,
63    ident: OMatrix<T, D, D>,
64
65    a: OMatrix<T, D, D>,
66    a2: Option<OMatrix<T, D, D>>,
67    a4: Option<OMatrix<T, D, D>>,
68    a6: Option<OMatrix<T, D, D>>,
69    a8: Option<OMatrix<T, D, D>>,
70    a10: Option<OMatrix<T, D, D>>,
71
72    d4_exact: Option<T::RealField>,
73    d6_exact: Option<T::RealField>,
74    d8_exact: Option<T::RealField>,
75    d10_exact: Option<T::RealField>,
76
77    d4_approx: Option<T::RealField>,
78    d6_approx: Option<T::RealField>,
79    d8_approx: Option<T::RealField>,
80    d10_approx: Option<T::RealField>,
81}
82
83impl<T, D> ExpmPadeHelper<T, D>
84where
85    T: ComplexField,
86    D: DimMin<D>,
87    DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
88{
89    fn new(a: OMatrix<T, D, D>, use_exact_norm: bool) -> Self {
90        let (nrows, ncols) = a.shape_generic();
91        ExpmPadeHelper {
92            use_exact_norm,
93            ident: OMatrix::<T, D, D>::identity_generic(nrows, ncols),
94            a,
95            a2: None,
96            a4: None,
97            a6: None,
98            a8: None,
99            a10: None,
100            d4_exact: None,
101            d6_exact: None,
102            d8_exact: None,
103            d10_exact: None,
104            d4_approx: None,
105            d6_approx: None,
106            d8_approx: None,
107            d10_approx: None,
108        }
109    }
110
111    fn calc_a2(&mut self) {
112        if self.a2.is_none() {
113            self.a2 = Some(&self.a * &self.a);
114        }
115    }
116
117    fn calc_a4(&mut self) {
118        if self.a4.is_none() {
119            self.calc_a2();
120            let a2 = self.a2.as_ref().unwrap();
121            self.a4 = Some(a2 * a2);
122        }
123    }
124
125    fn calc_a6(&mut self) {
126        if self.a6.is_none() {
127            self.calc_a2();
128            self.calc_a4();
129            let a2 = self.a2.as_ref().unwrap();
130            let a4 = self.a4.as_ref().unwrap();
131            self.a6 = Some(a4 * a2);
132        }
133    }
134
135    fn calc_a8(&mut self) {
136        if self.a8.is_none() {
137            self.calc_a2();
138            self.calc_a6();
139            let a2 = self.a2.as_ref().unwrap();
140            let a6 = self.a6.as_ref().unwrap();
141            self.a8 = Some(a6 * a2);
142        }
143    }
144
145    fn calc_a10(&mut self) {
146        if self.a10.is_none() {
147            self.calc_a4();
148            self.calc_a6();
149            let a4 = self.a4.as_ref().unwrap();
150            let a6 = self.a6.as_ref().unwrap();
151            self.a10 = Some(a6 * a4);
152        }
153    }
154
155    fn d4_tight(&mut self) -> T::RealField {
156        if self.d4_exact.is_none() {
157            self.calc_a4();
158            self.d4_exact = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
159        }
160        self.d4_exact.clone().unwrap()
161    }
162
163    fn d6_tight(&mut self) -> T::RealField {
164        if self.d6_exact.is_none() {
165            self.calc_a6();
166            self.d6_exact = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
167        }
168        self.d6_exact.clone().unwrap()
169    }
170
171    fn d8_tight(&mut self) -> T::RealField {
172        if self.d8_exact.is_none() {
173            self.calc_a8();
174            self.d8_exact = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
175        }
176        self.d8_exact.clone().unwrap()
177    }
178
179    fn d10_tight(&mut self) -> T::RealField {
180        if self.d10_exact.is_none() {
181            self.calc_a10();
182            self.d10_exact = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
183        }
184        self.d10_exact.clone().unwrap()
185    }
186
187    fn d4_loose(&mut self) -> T::RealField {
188        if self.use_exact_norm {
189            return self.d4_tight();
190        }
191
192        if self.d4_exact.is_some() {
193            return self.d4_exact.clone().unwrap();
194        }
195
196        if self.d4_approx.is_none() {
197            self.calc_a4();
198            self.d4_approx = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
199        }
200
201        self.d4_approx.clone().unwrap()
202    }
203
204    fn d6_loose(&mut self) -> T::RealField {
205        if self.use_exact_norm {
206            return self.d6_tight();
207        }
208
209        if self.d6_exact.is_some() {
210            return self.d6_exact.clone().unwrap();
211        }
212
213        if self.d6_approx.is_none() {
214            self.calc_a6();
215            self.d6_approx = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
216        }
217
218        self.d6_approx.clone().unwrap()
219    }
220
221    fn d8_loose(&mut self) -> T::RealField {
222        if self.use_exact_norm {
223            return self.d8_tight();
224        }
225
226        if self.d8_exact.is_some() {
227            return self.d8_exact.clone().unwrap();
228        }
229
230        if self.d8_approx.is_none() {
231            self.calc_a8();
232            self.d8_approx = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
233        }
234
235        self.d8_approx.clone().unwrap()
236    }
237
238    fn d10_loose(&mut self) -> T::RealField {
239        if self.use_exact_norm {
240            return self.d10_tight();
241        }
242
243        if self.d10_exact.is_some() {
244            return self.d10_exact.clone().unwrap();
245        }
246
247        if self.d10_approx.is_none() {
248            self.calc_a10();
249            self.d10_approx = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
250        }
251
252        self.d10_approx.clone().unwrap()
253    }
254
255    fn pade3(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
256        let b: [T; 4] = [convert(120.0), convert(60.0), convert(12.0), convert(1.0)];
257        self.calc_a2();
258        let a2 = self.a2.as_ref().unwrap();
259        let u = &self.a * (a2 * b[3].clone() + &self.ident * b[1].clone());
260        let v = a2 * b[2].clone() + &self.ident * b[0].clone();
261        (u, v)
262    }
263
264    fn pade5(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
265        let b: [T; 6] = [
266            convert(30240.0),
267            convert(15120.0),
268            convert(3360.0),
269            convert(420.0),
270            convert(30.0),
271            convert(1.0),
272        ];
273        self.calc_a2();
274        self.calc_a6();
275        let u = &self.a
276            * (self.a4.as_ref().unwrap() * b[5].clone()
277                + self.a2.as_ref().unwrap() * b[3].clone()
278                + &self.ident * b[1].clone());
279        let v = self.a4.as_ref().unwrap() * b[4].clone()
280            + self.a2.as_ref().unwrap() * b[2].clone()
281            + &self.ident * b[0].clone();
282        (u, v)
283    }
284
285    fn pade7(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
286        let b: [T; 8] = [
287            convert(17_297_280.0),
288            convert(8_648_640.0),
289            convert(1_995_840.0),
290            convert(277_200.0),
291            convert(25_200.0),
292            convert(1_512.0),
293            convert(56.0),
294            convert(1.0),
295        ];
296        self.calc_a2();
297        self.calc_a4();
298        self.calc_a6();
299        let u = &self.a
300            * (self.a6.as_ref().unwrap() * b[7].clone()
301                + self.a4.as_ref().unwrap() * b[5].clone()
302                + self.a2.as_ref().unwrap() * b[3].clone()
303                + &self.ident * b[1].clone());
304        let v = self.a6.as_ref().unwrap() * b[6].clone()
305            + self.a4.as_ref().unwrap() * b[4].clone()
306            + self.a2.as_ref().unwrap() * b[2].clone()
307            + &self.ident * b[0].clone();
308        (u, v)
309    }
310
311    fn pade9(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
312        let b: [T; 10] = [
313            convert(17_643_225_600.0),
314            convert(8_821_612_800.0),
315            convert(2_075_673_600.0),
316            convert(302_702_400.0),
317            convert(30_270_240.0),
318            convert(2_162_160.0),
319            convert(110_880.0),
320            convert(3_960.0),
321            convert(90.0),
322            convert(1.0),
323        ];
324        self.calc_a2();
325        self.calc_a4();
326        self.calc_a6();
327        self.calc_a8();
328        let u = &self.a
329            * (self.a8.as_ref().unwrap() * b[9].clone()
330                + self.a6.as_ref().unwrap() * b[7].clone()
331                + self.a4.as_ref().unwrap() * b[5].clone()
332                + self.a2.as_ref().unwrap() * b[3].clone()
333                + &self.ident * b[1].clone());
334        let v = self.a8.as_ref().unwrap() * b[8].clone()
335            + self.a6.as_ref().unwrap() * b[6].clone()
336            + self.a4.as_ref().unwrap() * b[4].clone()
337            + self.a2.as_ref().unwrap() * b[2].clone()
338            + &self.ident * b[0].clone();
339        (u, v)
340    }
341
342    fn pade13_scaled(&mut self, s: u64) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
343        let b: [T; 14] = [
344            convert(64_764_752_532_480_000.0),
345            convert(32_382_376_266_240_000.0),
346            convert(7_771_770_303_897_600.0),
347            convert(1_187_353_796_428_800.0),
348            convert(129_060_195_264_000.0),
349            convert(10_559_470_521_600.0),
350            convert(670_442_572_800.0),
351            convert(33_522_128_640.0),
352            convert(1_323_241_920.0),
353            convert(40_840_800.0),
354            convert(960_960.0),
355            convert(16_380.0),
356            convert(182.0),
357            convert(1.0),
358        ];
359        let s = s as f64;
360
361        let mb = &self.a * convert::<f64, T>(2.0_f64.powf(-s));
362        self.calc_a2();
363        self.calc_a4();
364        self.calc_a6();
365        let mb2 = self.a2.as_ref().unwrap() * convert::<f64, T>(2.0_f64.powf(-2.0 * s));
366        let mb4 = self.a4.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-4.0 * s));
367        let mb6 = self.a6.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-6.0 * s));
368
369        let u2 = &mb6 * (&mb6 * b[13].clone() + &mb4 * b[11].clone() + &mb2 * b[9].clone());
370        let u = &mb
371            * (&u2
372                + &mb6 * b[7].clone()
373                + &mb4 * b[5].clone()
374                + &mb2 * b[3].clone()
375                + &self.ident * b[1].clone());
376        let v2 = &mb6 * (&mb6 * b[12].clone() + &mb4 * b[10].clone() + &mb2 * b[8].clone());
377        let v = v2
378            + &mb6 * b[6].clone()
379            + &mb4 * b[4].clone()
380            + &mb2 * b[2].clone()
381            + &self.ident * b[0].clone();
382        (u, v)
383    }
384}
385
386/// Compute `n!`
387#[inline(always)]
388fn factorial(n: usize) -> u128 {
389    match FACTORIAL.get(n) {
390        Some(f) => *f,
391        None => panic!("{}! is greater than u128::MAX", n),
392    }
393}
394
395/// Compute the 1-norm of a non-negative integer power of a non-negative matrix.
396fn onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: usize) -> T
397where
398    T: RealField,
399    D: Dim,
400    DefaultAllocator: Allocator<T, D, D> + Allocator<T, D>,
401{
402    let nrows = a.shape_generic().0;
403    let mut v = crate::OVector::<T, D>::repeat_generic(nrows, Const::<1>, convert(1.0));
404    let m = a.transpose();
405
406    for _ in 0..p {
407        v = &m * v;
408    }
409
410    v.max()
411}
412
413fn ell<T, D>(a: &OMatrix<T, D, D>, m: usize) -> u64
414where
415    T: ComplexField,
416    D: Dim,
417    DefaultAllocator: Allocator<T, D, D>
418        + Allocator<T, D>
419        + Allocator<T::RealField, D>
420        + Allocator<T::RealField, D, D>,
421{
422    let a_abs = a.map(|x| x.abs());
423
424    let a_abs_onenorm = onenorm_matrix_power_nonm(&a_abs, 2 * m + 1);
425
426    if a_abs_onenorm == <T as ComplexField>::RealField::zero() {
427        return 0;
428    }
429
430    // 2m choose m = (2m)!/(m! * (2m-m)!) = (2m)!/((m!)^2)
431    let m_factorial = factorial(m);
432    let choose_2m_m = factorial(2 * m) / (m_factorial * m_factorial);
433
434    let abs_c_recip = choose_2m_m * factorial(2 * m + 1);
435    let alpha = a_abs_onenorm / one_norm(a);
436    let alpha: f64 = try_convert(alpha).unwrap() / abs_c_recip as f64;
437
438    let u = 2_f64.powf(-53.0);
439    let log2_alpha_div_u = (alpha / u).log2();
440    let value = (log2_alpha_div_u / (2.0 * m as f64)).ceil();
441    if value > 0.0 {
442        value as u64
443    } else {
444        0
445    }
446}
447
448fn solve_p_q<T, D>(u: OMatrix<T, D, D>, v: OMatrix<T, D, D>) -> OMatrix<T, D, D>
449where
450    T: ComplexField,
451    D: DimMin<D, Output = D>,
452    DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
453{
454    let p = &u + &v;
455    let q = &v - &u;
456
457    q.lu().solve(&p).unwrap()
458}
459
460fn one_norm<T, D>(m: &OMatrix<T, D, D>) -> T::RealField
461where
462    T: ComplexField,
463    D: Dim,
464    DefaultAllocator: Allocator<T, D, D>,
465{
466    let mut max = <T as ComplexField>::RealField::zero();
467
468    for i in 0..m.ncols() {
469        let col = m.column(i);
470        max = max.max(
471            col.iter()
472                .fold(<T as ComplexField>::RealField::zero(), |a, b| {
473                    a + b.clone().abs()
474                }),
475        );
476    }
477
478    max
479}
480
481impl<T: ComplexField, D> OMatrix<T, D, D>
482where
483    D: DimMin<D, Output = D>,
484    DefaultAllocator: Allocator<T, D, D>
485        + Allocator<(usize, usize), DimMinimum<D, D>>
486        + Allocator<T, D>
487        + Allocator<T::RealField, D>
488        + Allocator<T::RealField, D, D>,
489{
490    /// Computes exponential of this matrix
491    #[must_use]
492    pub fn exp(&self) -> Self {
493        // Simple case
494        if self.nrows() == 1 {
495            return self.map(|v| v.exp());
496        }
497
498        let mut helper = ExpmPadeHelper::new(self.clone(), true);
499
500        let eta_1 = T::RealField::max(helper.d4_loose(), helper.d6_loose());
501        if eta_1 < convert(1.495_585_217_958_292e-2) && ell(&helper.a, 3) == 0 {
502            let (u, v) = helper.pade3();
503            return solve_p_q(u, v);
504        }
505
506        let eta_2 = T::RealField::max(helper.d4_tight(), helper.d6_loose());
507        if eta_2 < convert(2.539_398_330_063_23e-1) && ell(&helper.a, 5) == 0 {
508            let (u, v) = helper.pade5();
509            return solve_p_q(u, v);
510        }
511
512        let eta_3 = T::RealField::max(helper.d6_tight(), helper.d8_loose());
513        if eta_3 < convert(9.504_178_996_162_932e-1) && ell(&helper.a, 7) == 0 {
514            let (u, v) = helper.pade7();
515            return solve_p_q(u, v);
516        }
517        if eta_3 < convert(2.097_847_961_257_068e0) && ell(&helper.a, 9) == 0 {
518            let (u, v) = helper.pade9();
519            return solve_p_q(u, v);
520        }
521
522        let eta_4 = T::RealField::max(helper.d8_loose(), helper.d10_loose());
523        let eta_5 = T::RealField::min(eta_3, eta_4);
524        let theta_13 = convert(4.25);
525
526        let mut s = if eta_5 == T::RealField::zero() {
527            0
528        } else {
529            let l2 = try_convert((eta_5 / theta_13).log2().ceil()).unwrap();
530
531            if l2 < 0.0 {
532                0
533            } else {
534                l2 as u64
535            }
536        };
537
538        s += ell(
539            &(&helper.a * convert::<f64, T>(2.0_f64.powf(-(s as f64)))),
540            13,
541        );
542
543        let (u, v) = helper.pade13_scaled(s);
544        let mut x = solve_p_q(u, v);
545
546        for _ in 0..s {
547            x = &x * &x;
548        }
549        x
550    }
551}
552
#[cfg(test)]
mod tests {
    #[test]
    // Exact float comparison is fine: every input and partial sum is a small integer.
    #[allow(clippy::float_cmp)]
    fn one_norm() {
        use crate::Matrix3;

        // Absolute column sums: |-3|+|2|+|0| = 5, |5|+|6|+|2| = 13, |7|+|4|+|8| = 19.
        let mat = Matrix3::new(
            -3.0, 5.0, 7.0, //
            2.0, 6.0, 4.0, //
            0.0, 2.0, 8.0,
        );
        let expected = 19.0;

        assert_eq!(super::one_norm(&mat), expected);
    }
}
563}