1use crate::{
4 base::{
5 allocator::Allocator,
6 dimension::{Const, Dim, DimMin, DimMinimum},
7 DefaultAllocator,
8 },
9 convert, try_convert, ComplexField, OMatrix, RealField,
10};
11
12use crate::num::Zero;
13
/// Lookup table of `n!` for `n` in `0..=34`.
///
/// The table stops at 34! because 35! no longer fits in a `u128`. The values are generated
/// at compile time; an overflow here would be rejected by const evaluation.
const FACTORIAL: [u128; 35] = {
    let mut table = [1u128; 35];
    let mut n = 1;
    while n < 35 {
        table[n] = table[n - 1] * n as u128;
        n += 1;
    }
    table
};
54
55struct ExpmPadeHelper<T, D>
57where
58 T: ComplexField,
59 D: DimMin<D>,
60 DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
61{
62 use_exact_norm: bool,
63 ident: OMatrix<T, D, D>,
64
65 a: OMatrix<T, D, D>,
66 a2: Option<OMatrix<T, D, D>>,
67 a4: Option<OMatrix<T, D, D>>,
68 a6: Option<OMatrix<T, D, D>>,
69 a8: Option<OMatrix<T, D, D>>,
70 a10: Option<OMatrix<T, D, D>>,
71
72 d4_exact: Option<T::RealField>,
73 d6_exact: Option<T::RealField>,
74 d8_exact: Option<T::RealField>,
75 d10_exact: Option<T::RealField>,
76
77 d4_approx: Option<T::RealField>,
78 d6_approx: Option<T::RealField>,
79 d8_approx: Option<T::RealField>,
80 d10_approx: Option<T::RealField>,
81}
82
83impl<T, D> ExpmPadeHelper<T, D>
84where
85 T: ComplexField,
86 D: DimMin<D>,
87 DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
88{
89 fn new(a: OMatrix<T, D, D>, use_exact_norm: bool) -> Self {
90 let (nrows, ncols) = a.shape_generic();
91 ExpmPadeHelper {
92 use_exact_norm,
93 ident: OMatrix::<T, D, D>::identity_generic(nrows, ncols),
94 a,
95 a2: None,
96 a4: None,
97 a6: None,
98 a8: None,
99 a10: None,
100 d4_exact: None,
101 d6_exact: None,
102 d8_exact: None,
103 d10_exact: None,
104 d4_approx: None,
105 d6_approx: None,
106 d8_approx: None,
107 d10_approx: None,
108 }
109 }
110
111 fn calc_a2(&mut self) {
112 if self.a2.is_none() {
113 self.a2 = Some(&self.a * &self.a);
114 }
115 }
116
117 fn calc_a4(&mut self) {
118 if self.a4.is_none() {
119 self.calc_a2();
120 let a2 = self.a2.as_ref().unwrap();
121 self.a4 = Some(a2 * a2);
122 }
123 }
124
125 fn calc_a6(&mut self) {
126 if self.a6.is_none() {
127 self.calc_a2();
128 self.calc_a4();
129 let a2 = self.a2.as_ref().unwrap();
130 let a4 = self.a4.as_ref().unwrap();
131 self.a6 = Some(a4 * a2);
132 }
133 }
134
135 fn calc_a8(&mut self) {
136 if self.a8.is_none() {
137 self.calc_a2();
138 self.calc_a6();
139 let a2 = self.a2.as_ref().unwrap();
140 let a6 = self.a6.as_ref().unwrap();
141 self.a8 = Some(a6 * a2);
142 }
143 }
144
145 fn calc_a10(&mut self) {
146 if self.a10.is_none() {
147 self.calc_a4();
148 self.calc_a6();
149 let a4 = self.a4.as_ref().unwrap();
150 let a6 = self.a6.as_ref().unwrap();
151 self.a10 = Some(a6 * a4);
152 }
153 }
154
155 fn d4_tight(&mut self) -> T::RealField {
156 if self.d4_exact.is_none() {
157 self.calc_a4();
158 self.d4_exact = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
159 }
160 self.d4_exact.clone().unwrap()
161 }
162
163 fn d6_tight(&mut self) -> T::RealField {
164 if self.d6_exact.is_none() {
165 self.calc_a6();
166 self.d6_exact = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
167 }
168 self.d6_exact.clone().unwrap()
169 }
170
171 fn d8_tight(&mut self) -> T::RealField {
172 if self.d8_exact.is_none() {
173 self.calc_a8();
174 self.d8_exact = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
175 }
176 self.d8_exact.clone().unwrap()
177 }
178
179 fn d10_tight(&mut self) -> T::RealField {
180 if self.d10_exact.is_none() {
181 self.calc_a10();
182 self.d10_exact = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
183 }
184 self.d10_exact.clone().unwrap()
185 }
186
187 fn d4_loose(&mut self) -> T::RealField {
188 if self.use_exact_norm {
189 return self.d4_tight();
190 }
191
192 if self.d4_exact.is_some() {
193 return self.d4_exact.clone().unwrap();
194 }
195
196 if self.d4_approx.is_none() {
197 self.calc_a4();
198 self.d4_approx = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
199 }
200
201 self.d4_approx.clone().unwrap()
202 }
203
204 fn d6_loose(&mut self) -> T::RealField {
205 if self.use_exact_norm {
206 return self.d6_tight();
207 }
208
209 if self.d6_exact.is_some() {
210 return self.d6_exact.clone().unwrap();
211 }
212
213 if self.d6_approx.is_none() {
214 self.calc_a6();
215 self.d6_approx = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
216 }
217
218 self.d6_approx.clone().unwrap()
219 }
220
221 fn d8_loose(&mut self) -> T::RealField {
222 if self.use_exact_norm {
223 return self.d8_tight();
224 }
225
226 if self.d8_exact.is_some() {
227 return self.d8_exact.clone().unwrap();
228 }
229
230 if self.d8_approx.is_none() {
231 self.calc_a8();
232 self.d8_approx = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
233 }
234
235 self.d8_approx.clone().unwrap()
236 }
237
238 fn d10_loose(&mut self) -> T::RealField {
239 if self.use_exact_norm {
240 return self.d10_tight();
241 }
242
243 if self.d10_exact.is_some() {
244 return self.d10_exact.clone().unwrap();
245 }
246
247 if self.d10_approx.is_none() {
248 self.calc_a10();
249 self.d10_approx = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
250 }
251
252 self.d10_approx.clone().unwrap()
253 }
254
255 fn pade3(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
256 let b: [T; 4] = [convert(120.0), convert(60.0), convert(12.0), convert(1.0)];
257 self.calc_a2();
258 let a2 = self.a2.as_ref().unwrap();
259 let u = &self.a * (a2 * b[3].clone() + &self.ident * b[1].clone());
260 let v = a2 * b[2].clone() + &self.ident * b[0].clone();
261 (u, v)
262 }
263
264 fn pade5(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
265 let b: [T; 6] = [
266 convert(30240.0),
267 convert(15120.0),
268 convert(3360.0),
269 convert(420.0),
270 convert(30.0),
271 convert(1.0),
272 ];
273 self.calc_a2();
274 self.calc_a6();
275 let u = &self.a
276 * (self.a4.as_ref().unwrap() * b[5].clone()
277 + self.a2.as_ref().unwrap() * b[3].clone()
278 + &self.ident * b[1].clone());
279 let v = self.a4.as_ref().unwrap() * b[4].clone()
280 + self.a2.as_ref().unwrap() * b[2].clone()
281 + &self.ident * b[0].clone();
282 (u, v)
283 }
284
285 fn pade7(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
286 let b: [T; 8] = [
287 convert(17_297_280.0),
288 convert(8_648_640.0),
289 convert(1_995_840.0),
290 convert(277_200.0),
291 convert(25_200.0),
292 convert(1_512.0),
293 convert(56.0),
294 convert(1.0),
295 ];
296 self.calc_a2();
297 self.calc_a4();
298 self.calc_a6();
299 let u = &self.a
300 * (self.a6.as_ref().unwrap() * b[7].clone()
301 + self.a4.as_ref().unwrap() * b[5].clone()
302 + self.a2.as_ref().unwrap() * b[3].clone()
303 + &self.ident * b[1].clone());
304 let v = self.a6.as_ref().unwrap() * b[6].clone()
305 + self.a4.as_ref().unwrap() * b[4].clone()
306 + self.a2.as_ref().unwrap() * b[2].clone()
307 + &self.ident * b[0].clone();
308 (u, v)
309 }
310
311 fn pade9(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
312 let b: [T; 10] = [
313 convert(17_643_225_600.0),
314 convert(8_821_612_800.0),
315 convert(2_075_673_600.0),
316 convert(302_702_400.0),
317 convert(30_270_240.0),
318 convert(2_162_160.0),
319 convert(110_880.0),
320 convert(3_960.0),
321 convert(90.0),
322 convert(1.0),
323 ];
324 self.calc_a2();
325 self.calc_a4();
326 self.calc_a6();
327 self.calc_a8();
328 let u = &self.a
329 * (self.a8.as_ref().unwrap() * b[9].clone()
330 + self.a6.as_ref().unwrap() * b[7].clone()
331 + self.a4.as_ref().unwrap() * b[5].clone()
332 + self.a2.as_ref().unwrap() * b[3].clone()
333 + &self.ident * b[1].clone());
334 let v = self.a8.as_ref().unwrap() * b[8].clone()
335 + self.a6.as_ref().unwrap() * b[6].clone()
336 + self.a4.as_ref().unwrap() * b[4].clone()
337 + self.a2.as_ref().unwrap() * b[2].clone()
338 + &self.ident * b[0].clone();
339 (u, v)
340 }
341
342 fn pade13_scaled(&mut self, s: u64) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
343 let b: [T; 14] = [
344 convert(64_764_752_532_480_000.0),
345 convert(32_382_376_266_240_000.0),
346 convert(7_771_770_303_897_600.0),
347 convert(1_187_353_796_428_800.0),
348 convert(129_060_195_264_000.0),
349 convert(10_559_470_521_600.0),
350 convert(670_442_572_800.0),
351 convert(33_522_128_640.0),
352 convert(1_323_241_920.0),
353 convert(40_840_800.0),
354 convert(960_960.0),
355 convert(16_380.0),
356 convert(182.0),
357 convert(1.0),
358 ];
359 let s = s as f64;
360
361 let mb = &self.a * convert::<f64, T>(2.0_f64.powf(-s));
362 self.calc_a2();
363 self.calc_a4();
364 self.calc_a6();
365 let mb2 = self.a2.as_ref().unwrap() * convert::<f64, T>(2.0_f64.powf(-2.0 * s));
366 let mb4 = self.a4.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-4.0 * s));
367 let mb6 = self.a6.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-6.0 * s));
368
369 let u2 = &mb6 * (&mb6 * b[13].clone() + &mb4 * b[11].clone() + &mb2 * b[9].clone());
370 let u = &mb
371 * (&u2
372 + &mb6 * b[7].clone()
373 + &mb4 * b[5].clone()
374 + &mb2 * b[3].clone()
375 + &self.ident * b[1].clone());
376 let v2 = &mb6 * (&mb6 * b[12].clone() + &mb4 * b[10].clone() + &mb2 * b[8].clone());
377 let v = v2
378 + &mb6 * b[6].clone()
379 + &mb4 * b[4].clone()
380 + &mb2 * b[2].clone()
381 + &self.ident * b[0].clone();
382 (u, v)
383 }
384}
385
386#[inline(always)]
388fn factorial(n: usize) -> u128 {
389 match FACTORIAL.get(n) {
390 Some(f) => *f,
391 None => panic!("{}! is greater than u128::MAX", n),
392 }
393}
394
395fn onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: usize) -> T
397where
398 T: RealField,
399 D: Dim,
400 DefaultAllocator: Allocator<T, D, D> + Allocator<T, D>,
401{
402 let nrows = a.shape_generic().0;
403 let mut v = crate::OVector::<T, D>::repeat_generic(nrows, Const::<1>, convert(1.0));
404 let m = a.transpose();
405
406 for _ in 0..p {
407 v = &m * v;
408 }
409
410 v.max()
411}
412
413fn ell<T, D>(a: &OMatrix<T, D, D>, m: usize) -> u64
414where
415 T: ComplexField,
416 D: Dim,
417 DefaultAllocator: Allocator<T, D, D>
418 + Allocator<T, D>
419 + Allocator<T::RealField, D>
420 + Allocator<T::RealField, D, D>,
421{
422 let a_abs = a.map(|x| x.abs());
423
424 let a_abs_onenorm = onenorm_matrix_power_nonm(&a_abs, 2 * m + 1);
425
426 if a_abs_onenorm == <T as ComplexField>::RealField::zero() {
427 return 0;
428 }
429
430 let m_factorial = factorial(m);
432 let choose_2m_m = factorial(2 * m) / (m_factorial * m_factorial);
433
434 let abs_c_recip = choose_2m_m * factorial(2 * m + 1);
435 let alpha = a_abs_onenorm / one_norm(a);
436 let alpha: f64 = try_convert(alpha).unwrap() / abs_c_recip as f64;
437
438 let u = 2_f64.powf(-53.0);
439 let log2_alpha_div_u = (alpha / u).log2();
440 let value = (log2_alpha_div_u / (2.0 * m as f64)).ceil();
441 if value > 0.0 {
442 value as u64
443 } else {
444 0
445 }
446}
447
448fn solve_p_q<T, D>(u: OMatrix<T, D, D>, v: OMatrix<T, D, D>) -> OMatrix<T, D, D>
449where
450 T: ComplexField,
451 D: DimMin<D, Output = D>,
452 DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
453{
454 let p = &u + &v;
455 let q = &v - &u;
456
457 q.lu().solve(&p).unwrap()
458}
459
460fn one_norm<T, D>(m: &OMatrix<T, D, D>) -> T::RealField
461where
462 T: ComplexField,
463 D: Dim,
464 DefaultAllocator: Allocator<T, D, D>,
465{
466 let mut max = <T as ComplexField>::RealField::zero();
467
468 for i in 0..m.ncols() {
469 let col = m.column(i);
470 max = max.max(
471 col.iter()
472 .fold(<T as ComplexField>::RealField::zero(), |a, b| {
473 a + b.clone().abs()
474 }),
475 );
476 }
477
478 max
479}
480
481impl<T: ComplexField, D> OMatrix<T, D, D>
482where
483 D: DimMin<D, Output = D>,
484 DefaultAllocator: Allocator<T, D, D>
485 + Allocator<(usize, usize), DimMinimum<D, D>>
486 + Allocator<T, D>
487 + Allocator<T::RealField, D>
488 + Allocator<T::RealField, D, D>,
489{
490 #[must_use]
492 pub fn exp(&self) -> Self {
493 if self.nrows() == 1 {
495 return self.map(|v| v.exp());
496 }
497
498 let mut helper = ExpmPadeHelper::new(self.clone(), true);
499
500 let eta_1 = T::RealField::max(helper.d4_loose(), helper.d6_loose());
501 if eta_1 < convert(1.495_585_217_958_292e-2) && ell(&helper.a, 3) == 0 {
502 let (u, v) = helper.pade3();
503 return solve_p_q(u, v);
504 }
505
506 let eta_2 = T::RealField::max(helper.d4_tight(), helper.d6_loose());
507 if eta_2 < convert(2.539_398_330_063_23e-1) && ell(&helper.a, 5) == 0 {
508 let (u, v) = helper.pade5();
509 return solve_p_q(u, v);
510 }
511
512 let eta_3 = T::RealField::max(helper.d6_tight(), helper.d8_loose());
513 if eta_3 < convert(9.504_178_996_162_932e-1) && ell(&helper.a, 7) == 0 {
514 let (u, v) = helper.pade7();
515 return solve_p_q(u, v);
516 }
517 if eta_3 < convert(2.097_847_961_257_068e0) && ell(&helper.a, 9) == 0 {
518 let (u, v) = helper.pade9();
519 return solve_p_q(u, v);
520 }
521
522 let eta_4 = T::RealField::max(helper.d8_loose(), helper.d10_loose());
523 let eta_5 = T::RealField::min(eta_3, eta_4);
524 let theta_13 = convert(4.25);
525
526 let mut s = if eta_5 == T::RealField::zero() {
527 0
528 } else {
529 let l2 = try_convert((eta_5 / theta_13).log2().ceil()).unwrap();
530
531 if l2 < 0.0 {
532 0
533 } else {
534 l2 as u64
535 }
536 };
537
538 s += ell(
539 &(&helper.a * convert::<f64, T>(2.0_f64.powf(-(s as f64)))),
540 13,
541 );
542
543 let (u, v) = helper.pade13_scaled(s);
544 let mut x = solve_p_q(u, v);
545
546 for _ in 0..s {
547 x = &x * &x;
548 }
549 x
550 }
551}
552
#[cfg(test)]
mod tests {
    #[test]
    #[allow(clippy::float_cmp)]
    fn one_norm() {
        use crate::Matrix3;

        // Column absolute sums are 5, 13 and 19; the one-norm is the largest of them.
        let m = Matrix3::new(-3.0, 5.0, 7.0, 2.0, 6.0, 4.0, 0.0, 2.0, 8.0);
        let expected = 19.0;

        assert_eq!(super::one_norm(&m), expected);
    }
}