1use simba::scalar::ComplexField;
2
3use crate::base::allocator::Allocator;
4use crate::base::dimension::Dim;
5use crate::base::storage::{Storage, StorageMut};
6use crate::base::{DefaultAllocator, OMatrix, SquareMatrix};
7
8use crate::linalg::lu;
9
10impl<T: ComplexField, D: Dim, S: Storage<T, D, D>> SquareMatrix<T, D, S> {
11 #[inline]
13 #[must_use = "Did you mean to use try_inverse_mut()?"]
14 pub fn try_inverse(self) -> Option<OMatrix<T, D, D>>
15 where
16 DefaultAllocator: Allocator<T, D, D>,
17 {
18 let mut me = self.into_owned();
19 if me.try_inverse_mut() {
20 Some(me)
21 } else {
22 None
23 }
24 }
25}
26
27impl<T: ComplexField, D: Dim, S: StorageMut<T, D, D>> SquareMatrix<T, D, S> {
28 #[inline]
31 pub fn try_inverse_mut(&mut self) -> bool
32 where
33 DefaultAllocator: Allocator<T, D, D>,
34 {
35 assert!(self.is_square(), "Unable to invert a non-square matrix.");
36
37 let dim = self.shape().0;
38
39 unsafe {
40 match dim {
41 0 => true,
42 1 => {
43 let determinant = self.get_unchecked((0, 0)).clone();
44 if determinant.is_zero() {
45 false
46 } else {
47 *self.get_unchecked_mut((0, 0)) = T::one() / determinant;
48 true
49 }
50 }
51 2 => {
52 let m11 = self.get_unchecked((0, 0)).clone();
53 let m12 = self.get_unchecked((0, 1)).clone();
54 let m21 = self.get_unchecked((1, 0)).clone();
55 let m22 = self.get_unchecked((1, 1)).clone();
56
57 let determinant = m11.clone() * m22.clone() - m21.clone() * m12.clone();
58
59 if determinant.is_zero() {
60 false
61 } else {
62 *self.get_unchecked_mut((0, 0)) = m22 / determinant.clone();
63 *self.get_unchecked_mut((0, 1)) = -m12 / determinant.clone();
64
65 *self.get_unchecked_mut((1, 0)) = -m21 / determinant.clone();
66 *self.get_unchecked_mut((1, 1)) = m11 / determinant;
67
68 true
69 }
70 }
71 3 => {
72 let m11 = self.get_unchecked((0, 0)).clone();
73 let m12 = self.get_unchecked((0, 1)).clone();
74 let m13 = self.get_unchecked((0, 2)).clone();
75
76 let m21 = self.get_unchecked((1, 0)).clone();
77 let m22 = self.get_unchecked((1, 1)).clone();
78 let m23 = self.get_unchecked((1, 2)).clone();
79
80 let m31 = self.get_unchecked((2, 0)).clone();
81 let m32 = self.get_unchecked((2, 1)).clone();
82 let m33 = self.get_unchecked((2, 2)).clone();
83
84 let minor_m12_m23 = m22.clone() * m33.clone() - m32.clone() * m23.clone();
85 let minor_m11_m23 = m21.clone() * m33.clone() - m31.clone() * m23.clone();
86 let minor_m11_m22 = m21.clone() * m32.clone() - m31.clone() * m22.clone();
87
88 let determinant = m11.clone() * minor_m12_m23.clone()
89 - m12.clone() * minor_m11_m23.clone()
90 + m13.clone() * minor_m11_m22.clone();
91
92 if determinant.is_zero() {
93 false
94 } else {
95 *self.get_unchecked_mut((0, 0)) = minor_m12_m23 / determinant.clone();
96 *self.get_unchecked_mut((0, 1)) = (m13.clone() * m32.clone()
97 - m33.clone() * m12.clone())
98 / determinant.clone();
99 *self.get_unchecked_mut((0, 2)) = (m12.clone() * m23.clone()
100 - m22.clone() * m13.clone())
101 / determinant.clone();
102
103 *self.get_unchecked_mut((1, 0)) = -minor_m11_m23 / determinant.clone();
104 *self.get_unchecked_mut((1, 1)) =
105 (m11.clone() * m33 - m31.clone() * m13.clone()) / determinant.clone();
106 *self.get_unchecked_mut((1, 2)) =
107 (m13 * m21.clone() - m23 * m11.clone()) / determinant.clone();
108
109 *self.get_unchecked_mut((2, 0)) = minor_m11_m22 / determinant.clone();
110 *self.get_unchecked_mut((2, 1)) =
111 (m12.clone() * m31 - m32 * m11.clone()) / determinant.clone();
112 *self.get_unchecked_mut((2, 2)) = (m11 * m22 - m21 * m12) / determinant;
113
114 true
115 }
116 }
117 4 => {
118 let oself = self.clone_owned();
119 do_inverse4(&oself, self)
120 }
121 _ => {
122 let oself = self.clone_owned();
123 lu::try_invert_to(oself, self)
124 }
125 }
126 }
127 }
128}
129
130fn do_inverse4<T: ComplexField, D: Dim, S: StorageMut<T, D, D>>(
132 m: &OMatrix<T, D, D>,
133 out: &mut SquareMatrix<T, D, S>,
134) -> bool
135where
136 DefaultAllocator: Allocator<T, D, D>,
137{
138 let m = m.as_slice();
139
140 out[(0, 0)] = m[5].clone() * m[10].clone() * m[15].clone()
141 - m[5].clone() * m[11].clone() * m[14].clone()
142 - m[9].clone() * m[6].clone() * m[15].clone()
143 + m[9].clone() * m[7].clone() * m[14].clone()
144 + m[13].clone() * m[6].clone() * m[11].clone()
145 - m[13].clone() * m[7].clone() * m[10].clone();
146
147 out[(1, 0)] = -m[1].clone() * m[10].clone() * m[15].clone()
148 + m[1].clone() * m[11].clone() * m[14].clone()
149 + m[9].clone() * m[2].clone() * m[15].clone()
150 - m[9].clone() * m[3].clone() * m[14].clone()
151 - m[13].clone() * m[2].clone() * m[11].clone()
152 + m[13].clone() * m[3].clone() * m[10].clone();
153
154 out[(2, 0)] = m[1].clone() * m[6].clone() * m[15].clone()
155 - m[1].clone() * m[7].clone() * m[14].clone()
156 - m[5].clone() * m[2].clone() * m[15].clone()
157 + m[5].clone() * m[3].clone() * m[14].clone()
158 + m[13].clone() * m[2].clone() * m[7].clone()
159 - m[13].clone() * m[3].clone() * m[6].clone();
160
161 out[(3, 0)] = -m[1].clone() * m[6].clone() * m[11].clone()
162 + m[1].clone() * m[7].clone() * m[10].clone()
163 + m[5].clone() * m[2].clone() * m[11].clone()
164 - m[5].clone() * m[3].clone() * m[10].clone()
165 - m[9].clone() * m[2].clone() * m[7].clone()
166 + m[9].clone() * m[3].clone() * m[6].clone();
167
168 out[(0, 1)] = -m[4].clone() * m[10].clone() * m[15].clone()
169 + m[4].clone() * m[11].clone() * m[14].clone()
170 + m[8].clone() * m[6].clone() * m[15].clone()
171 - m[8].clone() * m[7].clone() * m[14].clone()
172 - m[12].clone() * m[6].clone() * m[11].clone()
173 + m[12].clone() * m[7].clone() * m[10].clone();
174
175 out[(1, 1)] = m[0].clone() * m[10].clone() * m[15].clone()
176 - m[0].clone() * m[11].clone() * m[14].clone()
177 - m[8].clone() * m[2].clone() * m[15].clone()
178 + m[8].clone() * m[3].clone() * m[14].clone()
179 + m[12].clone() * m[2].clone() * m[11].clone()
180 - m[12].clone() * m[3].clone() * m[10].clone();
181
182 out[(2, 1)] = -m[0].clone() * m[6].clone() * m[15].clone()
183 + m[0].clone() * m[7].clone() * m[14].clone()
184 + m[4].clone() * m[2].clone() * m[15].clone()
185 - m[4].clone() * m[3].clone() * m[14].clone()
186 - m[12].clone() * m[2].clone() * m[7].clone()
187 + m[12].clone() * m[3].clone() * m[6].clone();
188
189 out[(3, 1)] = m[0].clone() * m[6].clone() * m[11].clone()
190 - m[0].clone() * m[7].clone() * m[10].clone()
191 - m[4].clone() * m[2].clone() * m[11].clone()
192 + m[4].clone() * m[3].clone() * m[10].clone()
193 + m[8].clone() * m[2].clone() * m[7].clone()
194 - m[8].clone() * m[3].clone() * m[6].clone();
195
196 out[(0, 2)] = m[4].clone() * m[9].clone() * m[15].clone()
197 - m[4].clone() * m[11].clone() * m[13].clone()
198 - m[8].clone() * m[5].clone() * m[15].clone()
199 + m[8].clone() * m[7].clone() * m[13].clone()
200 + m[12].clone() * m[5].clone() * m[11].clone()
201 - m[12].clone() * m[7].clone() * m[9].clone();
202
203 out[(1, 2)] = -m[0].clone() * m[9].clone() * m[15].clone()
204 + m[0].clone() * m[11].clone() * m[13].clone()
205 + m[8].clone() * m[1].clone() * m[15].clone()
206 - m[8].clone() * m[3].clone() * m[13].clone()
207 - m[12].clone() * m[1].clone() * m[11].clone()
208 + m[12].clone() * m[3].clone() * m[9].clone();
209
210 out[(2, 2)] = m[0].clone() * m[5].clone() * m[15].clone()
211 - m[0].clone() * m[7].clone() * m[13].clone()
212 - m[4].clone() * m[1].clone() * m[15].clone()
213 + m[4].clone() * m[3].clone() * m[13].clone()
214 + m[12].clone() * m[1].clone() * m[7].clone()
215 - m[12].clone() * m[3].clone() * m[5].clone();
216
217 out[(0, 3)] = -m[4].clone() * m[9].clone() * m[14].clone()
218 + m[4].clone() * m[10].clone() * m[13].clone()
219 + m[8].clone() * m[5].clone() * m[14].clone()
220 - m[8].clone() * m[6].clone() * m[13].clone()
221 - m[12].clone() * m[5].clone() * m[10].clone()
222 + m[12].clone() * m[6].clone() * m[9].clone();
223
224 out[(3, 2)] = -m[0].clone() * m[5].clone() * m[11].clone()
225 + m[0].clone() * m[7].clone() * m[9].clone()
226 + m[4].clone() * m[1].clone() * m[11].clone()
227 - m[4].clone() * m[3].clone() * m[9].clone()
228 - m[8].clone() * m[1].clone() * m[7].clone()
229 + m[8].clone() * m[3].clone() * m[5].clone();
230
231 out[(1, 3)] = m[0].clone() * m[9].clone() * m[14].clone()
232 - m[0].clone() * m[10].clone() * m[13].clone()
233 - m[8].clone() * m[1].clone() * m[14].clone()
234 + m[8].clone() * m[2].clone() * m[13].clone()
235 + m[12].clone() * m[1].clone() * m[10].clone()
236 - m[12].clone() * m[2].clone() * m[9].clone();
237
238 out[(2, 3)] = -m[0].clone() * m[5].clone() * m[14].clone()
239 + m[0].clone() * m[6].clone() * m[13].clone()
240 + m[4].clone() * m[1].clone() * m[14].clone()
241 - m[4].clone() * m[2].clone() * m[13].clone()
242 - m[12].clone() * m[1].clone() * m[6].clone()
243 + m[12].clone() * m[2].clone() * m[5].clone();
244
245 out[(3, 3)] = m[0].clone() * m[5].clone() * m[10].clone()
246 - m[0].clone() * m[6].clone() * m[9].clone()
247 - m[4].clone() * m[1].clone() * m[10].clone()
248 + m[4].clone() * m[2].clone() * m[9].clone()
249 + m[8].clone() * m[1].clone() * m[6].clone()
250 - m[8].clone() * m[2].clone() * m[5].clone();
251
252 let det = m[0].clone() * out[(0, 0)].clone()
253 + m[1].clone() * out[(0, 1)].clone()
254 + m[2].clone() * out[(0, 2)].clone()
255 + m[3].clone() * out[(0, 3)].clone();
256
257 if !det.is_zero() {
258 let inv_det = T::one() / det;
259
260 for j in 0..4 {
261 for i in 0..4 {
262 out[(i, j)] *= inv_det.clone();
263 }
264 }
265 true
266 } else {
267 false
268 }
269}