1use std::fmt;
2#[cfg(feature = "abomonation-serialize")]
3use std::io::{Result as IOResult, Write};
4use std::ops::Deref;
5
6#[cfg(feature = "serde-serialize-no-std")]
7use serde::{Deserialize, Deserializer, Serialize, Serializer};
8
9#[cfg(feature = "abomonation-serialize")]
10use abomonation::Abomonation;
11
12use crate::allocator::Allocator;
13use crate::base::DefaultAllocator;
14use crate::storage::RawStorage;
15use crate::{Dim, Matrix, OMatrix, RealField, Scalar, SimdComplexField, SimdRealField};
16
/// A wrapper around a value of type `T` that is meant to have a unit norm.
///
/// The normalizing constructors (`new_normalize`, `try_new`, ...) divide the value by its
/// norm before wrapping it; the `*_unchecked` functions wrap the value as-is.
///
/// `#[repr(transparent)]` gives `Unit<T>` the same layout as `T`; the reference cast in
/// `from_ref_unchecked` relies on this.
#[repr(transparent)]
#[derive(Clone, Hash, Copy)]
pub struct Unit<T> {
    /// The wrapped value.
    pub(crate) value: T,
}
36
37impl<T: fmt::Debug> fmt::Debug for Unit<T> {
38 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
39 self.value.fmt(formatter)
40 }
41}
42
// SAFETY: `Unit<T>` is `#[repr(transparent)]` over its single field of type `T`, so an
// all-zero bit pattern is a valid `Unit<T>` whenever it is a valid `T`.
// NOTE(review): a zeroed value is memory-valid but does not have a unit norm.
#[cfg(feature = "bytemuck")]
unsafe impl<T: bytemuck::Zeroable> bytemuck::Zeroable for Unit<T> {}
45
// SAFETY: `Unit<T>` is `#[repr(transparent)]` over its single field of type `T`, so it has
// the same layout (and therefore the same plain-old-data properties) as `T`.
// NOTE(review): arbitrary bit patterns are memory-valid but need not have a unit norm.
#[cfg(feature = "bytemuck")]
unsafe impl<T: bytemuck::Pod> bytemuck::Pod for Unit<T> {}
48
/// A `Unit<T>` is serialized exactly like its inner `T`; no wrapper is emitted.
#[cfg(feature = "serde-serialize-no-std")]
impl<T: Serialize> Serialize for Unit<T> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        T::serialize(&self.value, serializer)
    }
}
58
/// A `Unit<T>` is deserialized by reading a plain `T` and wrapping it.
///
/// The value is wrapped as-is: it is neither normalized nor checked here.
#[cfg(feature = "serde-serialize-no-std")]
impl<'de, T: Deserialize<'de>> Deserialize<'de> for Unit<T> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let value = T::deserialize(deserializer)?;
        Ok(Unit { value })
    }
}
68
/// Every `Abomonation` operation is forwarded unchanged to the wrapped value.
#[cfg(feature = "abomonation-serialize")]
impl<T: Abomonation> Abomonation for Unit<T> {
    unsafe fn entomb<W: Write>(&self, writer: &mut W) -> IOResult<()> {
        T::entomb(&self.value, writer)
    }

    fn extent(&self) -> usize {
        T::extent(&self.value)
    }

    unsafe fn exhume<'a, 'b>(&'a mut self, bytes: &'b mut [u8]) -> Option<&'b mut [u8]> {
        T::exhume(&mut self.value, bytes)
    }
}
83
// rkyv support: a `Unit<T>` is archived as a `Unit` wrapping the archived form of `T`,
// and every operation is forwarded to the wrapped value.
#[cfg(feature = "rkyv-serialize-no-std")]
mod rkyv_impl {
    use super::Unit;
    use rkyv::{offset_of, project_struct, Archive, Deserialize, Fallible, Serialize};

    impl<T: Archive> Archive for Unit<T> {
        // The archived representation keeps the `Unit` wrapper around `T`'s archived type.
        type Archived = Unit<T::Archived>;
        // No resolver state of our own: the inner value's resolver is used directly.
        type Resolver = T::Resolver;

        fn resolve(
            &self,
            pos: usize,
            resolver: Self::Resolver,
            out: &mut ::core::mem::MaybeUninit<Self::Archived>,
        ) {
            // Resolve the inner value into the `value` field of `out`, shifting `pos` by
            // that field's offset inside the archived struct.
            self.value.resolve(
                pos + offset_of!(Self::Archived, value),
                resolver,
                project_struct!(out: Self::Archived => value),
            );
        }
    }

    impl<T: Serialize<S>, S: Fallible + ?Sized> Serialize<S> for Unit<T> {
        // Serializing the wrapper is serializing the wrapped value; its resolver is returned.
        fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
            self.value.serialize(serializer)
        }
    }

    impl<T: Archive, D: Fallible + ?Sized> Deserialize<Unit<T>, D> for Unit<T::Archived>
    where
        T::Archived: Deserialize<T, D>,
    {
        // Deserialize the archived inner value and wrap it again, without re-normalizing
        // or checking it.
        fn deserialize(&self, deserializer: &mut D) -> Result<Unit<T>, D::Error> {
            Ok(Unit {
                value: self.value.deserialize(deserializer)?,
            })
        }
    }
}
124
// SAFETY: NOTE(review) - presumably sound because `Unit` is `#[repr(transparent)]` over the
// matrix, whose elements are `DeviceCopy` and whose storage is `Copy`; confirm against the
// `cust::memory::DeviceCopy` contract.
#[cfg(all(not(target_os = "cuda"), feature = "cuda"))]
unsafe impl<T, R, C, S> cust::memory::DeviceCopy for Unit<Matrix<T, R, C, S>>
where
    T: Scalar + cust::memory::DeviceCopy,
    R: Dim,
    C: Dim,
    S: RawStorage<T, R, C> + Copy,
{
}
135
136impl<T, R, C, S> PartialEq for Unit<Matrix<T, R, C, S>>
137where
138 T: Scalar + PartialEq,
139 R: Dim,
140 C: Dim,
141 S: RawStorage<T, R, C>,
142{
143 #[inline]
144 fn eq(&self, rhs: &Self) -> bool {
145 self.value.eq(&rhs.value)
146 }
147}
148
149impl<T, R, C, S> Eq for Unit<Matrix<T, R, C, S>>
150where
151 T: Scalar + Eq,
152 R: Dim,
153 C: Dim,
154 S: RawStorage<T, R, C>,
155{
156}
157
/// Trait for values that have a norm and can be rescaled by a value of that norm's type.
///
/// This is the bound required by the normalizing constructors of `Unit`.
pub trait Normed {
    /// The type of the norm.
    type Norm: SimdRealField;
    /// Returns the norm of `self`.
    fn norm(&self) -> Self::Norm;
    /// Returns the squared norm of `self`.
    fn norm_squared(&self) -> Self::Norm;
    /// Scales `self` in place by `n`; `Unit::renormalize_fast` uses it to apply a
    /// multiplicative correction factor.
    fn scale_mut(&mut self, n: Self::Norm);
    /// Un-scales `self` in place by `n`; `Unit` calls it with the norm of `self` to
    /// normalize the value, so implementations are expected to divide by `n`.
    fn unscale_mut(&mut self, n: Self::Norm);
}
171
172impl<T: Normed> Unit<T> {
174 #[inline]
176 pub fn new_normalize(value: T) -> Self {
177 Self::new_and_get(value).0
178 }
179
180 #[inline]
184 pub fn try_new(value: T, min_norm: T::Norm) -> Option<Self>
185 where
186 T::Norm: RealField,
187 {
188 Self::try_new_and_get(value, min_norm).map(|res| res.0)
189 }
190
191 #[inline]
193 pub fn new_and_get(mut value: T) -> (Self, T::Norm) {
194 let n = value.norm();
195 value.unscale_mut(n.clone());
196 (Unit { value }, n)
197 }
198
199 #[inline]
203 pub fn try_new_and_get(mut value: T, min_norm: T::Norm) -> Option<(Self, T::Norm)>
204 where
205 T::Norm: RealField,
206 {
207 let sq_norm = value.norm_squared();
208
209 if sq_norm > min_norm.clone() * min_norm {
210 let n = sq_norm.simd_sqrt();
211 value.unscale_mut(n.clone());
212 Some((Unit { value }, n))
213 } else {
214 None
215 }
216 }
217
218 #[inline]
224 pub fn renormalize(&mut self) -> T::Norm {
225 let n = self.norm();
226 self.value.unscale_mut(n.clone());
227 n
228 }
229
230 #[inline]
234 pub fn renormalize_fast(&mut self) {
235 let sq_norm = self.value.norm_squared();
236 let three: T::Norm = crate::convert(3.0);
237 let half: T::Norm = crate::convert(0.5);
238 self.value.scale_mut(half * (three - sq_norm));
239 }
240}
241
242impl<T> Unit<T> {
244 #[inline]
246 pub const fn new_unchecked(value: T) -> Self {
247 Unit { value }
248 }
249
250 #[inline]
252 pub fn from_ref_unchecked(value: &T) -> &Self {
253 unsafe { &*(value as *const T as *const Self) }
254 }
255
256 #[inline]
258 pub fn into_inner(self) -> T {
259 self.value
260 }
261
262 #[deprecated(note = "use `.into_inner()` instead")]
265 #[inline]
266 pub fn unwrap(self) -> T {
267 self.value
268 }
269
270 #[inline]
274 pub fn as_mut_unchecked(&mut self) -> &mut T {
275 &mut self.value
276 }
277}
278
279impl<T> AsRef<T> for Unit<T> {
280 #[inline]
281 fn as_ref(&self) -> &T {
282 &self.value
283 }
284}
285
286impl<T> Deref for Unit<T> {
352 type Target = T;
353
354 #[inline]
355 fn deref(&self) -> &T {
356 unsafe { &*(self as *const Self as *const T) }
357 }
358}
359
360impl<T: Scalar + simba::simd::PrimitiveSimdValue, R: Dim, C: Dim>
364 From<[Unit<OMatrix<T::Element, R, C>>; 2]> for Unit<OMatrix<T, R, C>>
365where
366 T: From<[<T as simba::simd::SimdValue>::Element; 2]>,
367 T::Element: Scalar,
368 DefaultAllocator: Allocator<T, R, C> + Allocator<T::Element, R, C>,
369{
370 #[inline]
371 fn from(arr: [Unit<OMatrix<T::Element, R, C>>; 2]) -> Self {
372 Self::new_unchecked(OMatrix::from([
373 arr[0].clone().into_inner(),
374 arr[1].clone().into_inner(),
375 ]))
376 }
377}
378
379impl<T: Scalar + simba::simd::PrimitiveSimdValue, R: Dim, C: Dim>
380 From<[Unit<OMatrix<T::Element, R, C>>; 4]> for Unit<OMatrix<T, R, C>>
381where
382 T: From<[<T as simba::simd::SimdValue>::Element; 4]>,
383 T::Element: Scalar,
384 DefaultAllocator: Allocator<T, R, C> + Allocator<T::Element, R, C>,
385{
386 #[inline]
387 fn from(arr: [Unit<OMatrix<T::Element, R, C>>; 4]) -> Self {
388 Self::new_unchecked(OMatrix::from([
389 arr[0].clone().into_inner(),
390 arr[1].clone().into_inner(),
391 arr[2].clone().into_inner(),
392 arr[3].clone().into_inner(),
393 ]))
394 }
395}
396
397impl<T: Scalar + simba::simd::PrimitiveSimdValue, R: Dim, C: Dim>
398 From<[Unit<OMatrix<T::Element, R, C>>; 8]> for Unit<OMatrix<T, R, C>>
399where
400 T: From<[<T as simba::simd::SimdValue>::Element; 8]>,
401 T::Element: Scalar,
402 DefaultAllocator: Allocator<T, R, C> + Allocator<T::Element, R, C>,
403{
404 #[inline]
405 fn from(arr: [Unit<OMatrix<T::Element, R, C>>; 8]) -> Self {
406 Self::new_unchecked(OMatrix::from([
407 arr[0].clone().into_inner(),
408 arr[1].clone().into_inner(),
409 arr[2].clone().into_inner(),
410 arr[3].clone().into_inner(),
411 arr[4].clone().into_inner(),
412 arr[5].clone().into_inner(),
413 arr[6].clone().into_inner(),
414 arr[7].clone().into_inner(),
415 ]))
416 }
417}
418
419impl<T: Scalar + simba::simd::PrimitiveSimdValue, R: Dim, C: Dim>
420 From<[Unit<OMatrix<T::Element, R, C>>; 16]> for Unit<OMatrix<T, R, C>>
421where
422 T: From<[<T as simba::simd::SimdValue>::Element; 16]>,
423 T::Element: Scalar,
424 DefaultAllocator: Allocator<T, R, C> + Allocator<T::Element, R, C>,
425{
426 #[inline]
427 fn from(arr: [Unit<OMatrix<T::Element, R, C>>; 16]) -> Self {
428 Self::new_unchecked(OMatrix::from([
429 arr[0].clone().into_inner(),
430 arr[1].clone().into_inner(),
431 arr[2].clone().into_inner(),
432 arr[3].clone().into_inner(),
433 arr[4].clone().into_inner(),
434 arr[5].clone().into_inner(),
435 arr[6].clone().into_inner(),
436 arr[7].clone().into_inner(),
437 arr[8].clone().into_inner(),
438 arr[9].clone().into_inner(),
439 arr[10].clone().into_inner(),
440 arr[11].clone().into_inner(),
441 arr[12].clone().into_inner(),
442 arr[13].clone().into_inner(),
443 arr[14].clone().into_inner(),
444 arr[15].clone().into_inner(),
445 ]))
446 }
447}