nalgebra/linalg/
inverse.rs

1use simba::scalar::ComplexField;
2
3use crate::base::allocator::Allocator;
4use crate::base::dimension::Dim;
5use crate::base::storage::{Storage, StorageMut};
6use crate::base::{DefaultAllocator, OMatrix, SquareMatrix};
7
8use crate::linalg::lu;
9
10impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S> {
11    /// Attempts to invert this matrix.
12    #[inline]
13    #[must_use = "Did you mean to use try_inverse_mut()?"]
14    pub fn try_inverse(self) -> Option<OMatrix<T, D, D>>
15    where
16        DefaultAllocator: Allocator<T, D, D>,
17    {
18        let mut me = self.into_owned();
19        if me.try_inverse_mut() {
20            Some(me)
21        } else {
22            None
23        }
24    }
25}
26
27impl<T: ComplexField, D: Dim, S: StorageMut<T, D, D>> SquareMatrix<T, D, S> {
28    /// Attempts to invert this matrix in-place. Returns `false` and leaves `self` untouched if
29    /// inversion fails.
30    #[inline]
31    pub fn try_inverse_mut(&mut self) -> bool
32    where
33        DefaultAllocator: Allocator<T, D, D>,
34    {
35        assert!(self.is_square(), "Unable to invert a non-square matrix.");
36
37        let dim = self.shape().0;
38
39        unsafe {
40            match dim {
41                0 => true,
42                1 => {
43                    let determinant = self.get_unchecked((0, 0)).clone();
44                    if determinant.is_zero() {
45                        false
46                    } else {
47                        *self.get_unchecked_mut((0, 0)) = T::one() / determinant;
48                        true
49                    }
50                }
51                2 => {
52                    let m11 = self.get_unchecked((0, 0)).clone();
53                    let m12 = self.get_unchecked((0, 1)).clone();
54                    let m21 = self.get_unchecked((1, 0)).clone();
55                    let m22 = self.get_unchecked((1, 1)).clone();
56
57                    let determinant = m11.clone() * m22.clone() - m21.clone() * m12.clone();
58
59                    if determinant.is_zero() {
60                        false
61                    } else {
62                        *self.get_unchecked_mut((0, 0)) = m22 / determinant.clone();
63                        *self.get_unchecked_mut((0, 1)) = -m12 / determinant.clone();
64
65                        *self.get_unchecked_mut((1, 0)) = -m21 / determinant.clone();
66                        *self.get_unchecked_mut((1, 1)) = m11 / determinant;
67
68                        true
69                    }
70                }
71                3 => {
72                    let m11 = self.get_unchecked((0, 0)).clone();
73                    let m12 = self.get_unchecked((0, 1)).clone();
74                    let m13 = self.get_unchecked((0, 2)).clone();
75
76                    let m21 = self.get_unchecked((1, 0)).clone();
77                    let m22 = self.get_unchecked((1, 1)).clone();
78                    let m23 = self.get_unchecked((1, 2)).clone();
79
80                    let m31 = self.get_unchecked((2, 0)).clone();
81                    let m32 = self.get_unchecked((2, 1)).clone();
82                    let m33 = self.get_unchecked((2, 2)).clone();
83
84                    let minor_m12_m23 = m22.clone() * m33.clone() - m32.clone() * m23.clone();
85                    let minor_m11_m23 = m21.clone() * m33.clone() - m31.clone() * m23.clone();
86                    let minor_m11_m22 = m21.clone() * m32.clone() - m31.clone() * m22.clone();
87
88                    let determinant = m11.clone() * minor_m12_m23.clone()
89                        - m12.clone() * minor_m11_m23.clone()
90                        + m13.clone() * minor_m11_m22.clone();
91
92                    if determinant.is_zero() {
93                        false
94                    } else {
95                        *self.get_unchecked_mut((0, 0)) = minor_m12_m23 / determinant.clone();
96                        *self.get_unchecked_mut((0, 1)) = (m13.clone() * m32.clone()
97                            - m33.clone() * m12.clone())
98                            / determinant.clone();
99                        *self.get_unchecked_mut((0, 2)) = (m12.clone() * m23.clone()
100                            - m22.clone() * m13.clone())
101                            / determinant.clone();
102
103                        *self.get_unchecked_mut((1, 0)) = -minor_m11_m23 / determinant.clone();
104                        *self.get_unchecked_mut((1, 1)) =
105                            (m11.clone() * m33 - m31.clone() * m13.clone()) / determinant.clone();
106                        *self.get_unchecked_mut((1, 2)) =
107                            (m13 * m21.clone() - m23 * m11.clone()) / determinant.clone();
108
109                        *self.get_unchecked_mut((2, 0)) = minor_m11_m22 / determinant.clone();
110                        *self.get_unchecked_mut((2, 1)) =
111                            (m12.clone() * m31 - m32 * m11.clone()) / determinant.clone();
112                        *self.get_unchecked_mut((2, 2)) = (m11 * m22 - m21 * m12) / determinant;
113
114                        true
115                    }
116                }
117                4 => {
118                    let oself = self.clone_owned();
119                    do_inverse4(&oself, self)
120                }
121                _ => {
122                    let oself = self.clone_owned();
123                    lu::try_invert_to(oself, self)
124                }
125            }
126        }
127    }
128}
129
130// NOTE: this is an extremely efficient, loop-unrolled matrix inverse from MESA (MIT licensed).
131fn do_inverse4<T: ComplexField, D: Dim, S: StorageMut<T, D, D>>(
132    m: &OMatrix<T, D, D>,
133    out: &mut SquareMatrix<T, D, S>,
134) -> bool
135where
136    DefaultAllocator: Allocator<T, D, D>,
137{
138    let m = m.as_slice();
139
140    out[(0, 0)] = m[5].clone() * m[10].clone() * m[15].clone()
141        - m[5].clone() * m[11].clone() * m[14].clone()
142        - m[9].clone() * m[6].clone() * m[15].clone()
143        + m[9].clone() * m[7].clone() * m[14].clone()
144        + m[13].clone() * m[6].clone() * m[11].clone()
145        - m[13].clone() * m[7].clone() * m[10].clone();
146
147    out[(1, 0)] = -m[1].clone() * m[10].clone() * m[15].clone()
148        + m[1].clone() * m[11].clone() * m[14].clone()
149        + m[9].clone() * m[2].clone() * m[15].clone()
150        - m[9].clone() * m[3].clone() * m[14].clone()
151        - m[13].clone() * m[2].clone() * m[11].clone()
152        + m[13].clone() * m[3].clone() * m[10].clone();
153
154    out[(2, 0)] = m[1].clone() * m[6].clone() * m[15].clone()
155        - m[1].clone() * m[7].clone() * m[14].clone()
156        - m[5].clone() * m[2].clone() * m[15].clone()
157        + m[5].clone() * m[3].clone() * m[14].clone()
158        + m[13].clone() * m[2].clone() * m[7].clone()
159        - m[13].clone() * m[3].clone() * m[6].clone();
160
161    out[(3, 0)] = -m[1].clone() * m[6].clone() * m[11].clone()
162        + m[1].clone() * m[7].clone() * m[10].clone()
163        + m[5].clone() * m[2].clone() * m[11].clone()
164        - m[5].clone() * m[3].clone() * m[10].clone()
165        - m[9].clone() * m[2].clone() * m[7].clone()
166        + m[9].clone() * m[3].clone() * m[6].clone();
167
168    out[(0, 1)] = -m[4].clone() * m[10].clone() * m[15].clone()
169        + m[4].clone() * m[11].clone() * m[14].clone()
170        + m[8].clone() * m[6].clone() * m[15].clone()
171        - m[8].clone() * m[7].clone() * m[14].clone()
172        - m[12].clone() * m[6].clone() * m[11].clone()
173        + m[12].clone() * m[7].clone() * m[10].clone();
174
175    out[(1, 1)] = m[0].clone() * m[10].clone() * m[15].clone()
176        - m[0].clone() * m[11].clone() * m[14].clone()
177        - m[8].clone() * m[2].clone() * m[15].clone()
178        + m[8].clone() * m[3].clone() * m[14].clone()
179        + m[12].clone() * m[2].clone() * m[11].clone()
180        - m[12].clone() * m[3].clone() * m[10].clone();
181
182    out[(2, 1)] = -m[0].clone() * m[6].clone() * m[15].clone()
183        + m[0].clone() * m[7].clone() * m[14].clone()
184        + m[4].clone() * m[2].clone() * m[15].clone()
185        - m[4].clone() * m[3].clone() * m[14].clone()
186        - m[12].clone() * m[2].clone() * m[7].clone()
187        + m[12].clone() * m[3].clone() * m[6].clone();
188
189    out[(3, 1)] = m[0].clone() * m[6].clone() * m[11].clone()
190        - m[0].clone() * m[7].clone() * m[10].clone()
191        - m[4].clone() * m[2].clone() * m[11].clone()
192        + m[4].clone() * m[3].clone() * m[10].clone()
193        + m[8].clone() * m[2].clone() * m[7].clone()
194        - m[8].clone() * m[3].clone() * m[6].clone();
195
196    out[(0, 2)] = m[4].clone() * m[9].clone() * m[15].clone()
197        - m[4].clone() * m[11].clone() * m[13].clone()
198        - m[8].clone() * m[5].clone() * m[15].clone()
199        + m[8].clone() * m[7].clone() * m[13].clone()
200        + m[12].clone() * m[5].clone() * m[11].clone()
201        - m[12].clone() * m[7].clone() * m[9].clone();
202
203    out[(1, 2)] = -m[0].clone() * m[9].clone() * m[15].clone()
204        + m[0].clone() * m[11].clone() * m[13].clone()
205        + m[8].clone() * m[1].clone() * m[15].clone()
206        - m[8].clone() * m[3].clone() * m[13].clone()
207        - m[12].clone() * m[1].clone() * m[11].clone()
208        + m[12].clone() * m[3].clone() * m[9].clone();
209
210    out[(2, 2)] = m[0].clone() * m[5].clone() * m[15].clone()
211        - m[0].clone() * m[7].clone() * m[13].clone()
212        - m[4].clone() * m[1].clone() * m[15].clone()
213        + m[4].clone() * m[3].clone() * m[13].clone()
214        + m[12].clone() * m[1].clone() * m[7].clone()
215        - m[12].clone() * m[3].clone() * m[5].clone();
216
217    out[(0, 3)] = -m[4].clone() * m[9].clone() * m[14].clone()
218        + m[4].clone() * m[10].clone() * m[13].clone()
219        + m[8].clone() * m[5].clone() * m[14].clone()
220        - m[8].clone() * m[6].clone() * m[13].clone()
221        - m[12].clone() * m[5].clone() * m[10].clone()
222        + m[12].clone() * m[6].clone() * m[9].clone();
223
224    out[(3, 2)] = -m[0].clone() * m[5].clone() * m[11].clone()
225        + m[0].clone() * m[7].clone() * m[9].clone()
226        + m[4].clone() * m[1].clone() * m[11].clone()
227        - m[4].clone() * m[3].clone() * m[9].clone()
228        - m[8].clone() * m[1].clone() * m[7].clone()
229        + m[8].clone() * m[3].clone() * m[5].clone();
230
231    out[(1, 3)] = m[0].clone() * m[9].clone() * m[14].clone()
232        - m[0].clone() * m[10].clone() * m[13].clone()
233        - m[8].clone() * m[1].clone() * m[14].clone()
234        + m[8].clone() * m[2].clone() * m[13].clone()
235        + m[12].clone() * m[1].clone() * m[10].clone()
236        - m[12].clone() * m[2].clone() * m[9].clone();
237
238    out[(2, 3)] = -m[0].clone() * m[5].clone() * m[14].clone()
239        + m[0].clone() * m[6].clone() * m[13].clone()
240        + m[4].clone() * m[1].clone() * m[14].clone()
241        - m[4].clone() * m[2].clone() * m[13].clone()
242        - m[12].clone() * m[1].clone() * m[6].clone()
243        + m[12].clone() * m[2].clone() * m[5].clone();
244
245    out[(3, 3)] = m[0].clone() * m[5].clone() * m[10].clone()
246        - m[0].clone() * m[6].clone() * m[9].clone()
247        - m[4].clone() * m[1].clone() * m[10].clone()
248        + m[4].clone() * m[2].clone() * m[9].clone()
249        + m[8].clone() * m[1].clone() * m[6].clone()
250        - m[8].clone() * m[2].clone() * m[5].clone();
251
252    let det = m[0].clone() * out[(0, 0)].clone()
253        + m[1].clone() * out[(0, 1)].clone()
254        + m[2].clone() * out[(0, 2)].clone()
255        + m[3].clone() * out[(0, 3)].clone();
256
257    if !det.is_zero() {
258        let inv_det = T::one() / det;
259
260        for j in 0..4 {
261            for i in 0..4 {
262                out[(i, j)] *= inv_det.clone();
263            }
264        }
265        true
266    } else {
267        false
268    }
269}