nalgebra/base/
array_storage.rs

1use std::fmt::{self, Debug, Formatter};
2// use std::hash::{Hash, Hasher};
3#[cfg(feature = "abomonation-serialize")]
4use std::io::{Result as IOResult, Write};
5use std::ops::Mul;
6
7#[cfg(feature = "serde-serialize-no-std")]
8use serde::de::{Error, SeqAccess, Visitor};
9#[cfg(feature = "serde-serialize-no-std")]
10use serde::ser::SerializeSeq;
11#[cfg(feature = "serde-serialize-no-std")]
12use serde::{Deserialize, Deserializer, Serialize, Serializer};
13#[cfg(feature = "serde-serialize-no-std")]
14use std::marker::PhantomData;
15
16#[cfg(feature = "abomonation-serialize")]
17use abomonation::Abomonation;
18
19use crate::base::allocator::Allocator;
20use crate::base::default_allocator::DefaultAllocator;
21use crate::base::dimension::{Const, ToTypenum};
22use crate::base::storage::{IsContiguous, Owned, RawStorage, RawStorageMut, ReshapableStorage};
23use crate::base::Scalar;
24use crate::Storage;
25use std::mem;
26
27/*
28 *
29 * Static RawStorage.
30 *
31 */
/// An array-based, statically sized matrix data storage.
///
/// The wrapped value is a nested array made of `C` inner arrays of `R`
/// elements each. Because the type is `#[repr(transparent)]`, it has the same
/// layout as the `[[T; R]; C]` it wraps.
#[repr(transparent)]
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
#[cfg_attr(
    all(not(target_os = "cuda"), feature = "cuda"),
    derive(cust::DeviceCopy)
)]
pub struct ArrayStorage<T, const R: usize, const C: usize>(pub [[T; R]; C]);
40
41impl<T, const R: usize, const C: usize> ArrayStorage<T, R, C> {
42    /// Converts this array storage to a slice.
43    #[inline]
44    pub fn as_slice(&self) -> &[T] {
45        // SAFETY: this is OK because ArrayStorage is contiguous.
46        unsafe { self.as_slice_unchecked() }
47    }
48
49    /// Converts this array storage to a mutable slice.
50    #[inline]
51    pub fn as_mut_slice(&mut self) -> &mut [T] {
52        // SAFETY: this is OK because ArrayStorage is contiguous.
53        unsafe { self.as_mut_slice_unchecked() }
54    }
55}
56
57// TODO: remove this once the stdlib implements Default for arrays.
58impl<T: Default, const R: usize, const C: usize> Default for ArrayStorage<T, R, C>
59where
60    [[T; R]; C]: Default,
61{
62    #[inline]
63    fn default() -> Self {
64        Self(Default::default())
65    }
66}
67
68impl<T: Debug, const R: usize, const C: usize> Debug for ArrayStorage<T, R, C> {
69    #[inline]
70    fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
71        self.0.fmt(fmt)
72    }
73}
74
75unsafe impl<T, const R: usize, const C: usize> RawStorage<T, Const<R>, Const<C>>
76    for ArrayStorage<T, R, C>
77{
78    type RStride = Const<1>;
79    type CStride = Const<R>;
80
81    #[inline]
82    fn ptr(&self) -> *const T {
83        self.0.as_ptr() as *const T
84    }
85
86    #[inline]
87    fn shape(&self) -> (Const<R>, Const<C>) {
88        (Const, Const)
89    }
90
91    #[inline]
92    fn strides(&self) -> (Self::RStride, Self::CStride) {
93        (Const, Const)
94    }
95
96    #[inline]
97    fn is_contiguous(&self) -> bool {
98        true
99    }
100
101    #[inline]
102    unsafe fn as_slice_unchecked(&self) -> &[T] {
103        std::slice::from_raw_parts(self.ptr(), R * C)
104    }
105}
106
107unsafe impl<T: Scalar, const R: usize, const C: usize> Storage<T, Const<R>, Const<C>>
108    for ArrayStorage<T, R, C>
109where
110    DefaultAllocator: Allocator<T, Const<R>, Const<C>, Buffer = Self>,
111{
112    #[inline]
113    fn into_owned(self) -> Owned<T, Const<R>, Const<C>>
114    where
115        DefaultAllocator: Allocator<T, Const<R>, Const<C>>,
116    {
117        self
118    }
119
120    #[inline]
121    fn clone_owned(&self) -> Owned<T, Const<R>, Const<C>>
122    where
123        DefaultAllocator: Allocator<T, Const<R>, Const<C>>,
124    {
125        self.clone()
126    }
127}
128
129unsafe impl<T, const R: usize, const C: usize> RawStorageMut<T, Const<R>, Const<C>>
130    for ArrayStorage<T, R, C>
131{
132    #[inline]
133    fn ptr_mut(&mut self) -> *mut T {
134        self.0.as_mut_ptr() as *mut T
135    }
136
137    #[inline]
138    unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] {
139        std::slice::from_raw_parts_mut(self.ptr_mut(), R * C)
140    }
141}
142
143unsafe impl<T, const R: usize, const C: usize> IsContiguous for ArrayStorage<T, R, C> {}
144
145impl<T, const R1: usize, const C1: usize, const R2: usize, const C2: usize>
146    ReshapableStorage<T, Const<R1>, Const<C1>, Const<R2>, Const<C2>> for ArrayStorage<T, R1, C1>
147where
148    T: Scalar,
149    Const<R1>: ToTypenum,
150    Const<C1>: ToTypenum,
151    Const<R2>: ToTypenum,
152    Const<C2>: ToTypenum,
153    <Const<R1> as ToTypenum>::Typenum: Mul<<Const<C1> as ToTypenum>::Typenum>,
154    <Const<R2> as ToTypenum>::Typenum: Mul<
155        <Const<C2> as ToTypenum>::Typenum,
156        Output = typenum::Prod<
157            <Const<R1> as ToTypenum>::Typenum,
158            <Const<C1> as ToTypenum>::Typenum,
159        >,
160    >,
161{
162    type Output = ArrayStorage<T, R2, C2>;
163
164    fn reshape_generic(self, _: Const<R2>, _: Const<C2>) -> Self::Output {
165        unsafe {
166            let data: [[T; R2]; C2] = mem::transmute_copy(&self.0);
167            mem::forget(self.0);
168            ArrayStorage(data)
169        }
170    }
171}
172
173/*
174 *
175 * Serialization.
176 *
177 */
178// XXX: open an issue for serde so that it allows the serialization/deserialization of all arrays?
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> Serialize for ArrayStorage<T, R, C>
where
    T: Scalar + Serialize,
{
    /// Serializes the storage as one flat sequence of `R * C` elements, in the
    /// order given by [`ArrayStorage::as_slice`].
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut seq = serializer.serialize_seq(Some(R * C))?;

        self.as_slice()
            .iter()
            .try_for_each(|element| seq.serialize_element(element))?;

        seq.end()
    }
}
197
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Deserialize<'a> for ArrayStorage<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    /// Deserializes the storage from a sequence, delegating the element
    /// handling to `ArrayStorageVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        let visitor = ArrayStorageVisitor::<T, R, C>::new();
        deserializer.deserialize_seq(visitor)
    }
}
210
/// A serde visitor that produces a matrix array from a sequence of elements.
#[cfg(feature = "serde-serialize-no-std")]
struct ArrayStorageVisitor<T, const R: usize, const C: usize> {
    /// Records the element type `T` without storing a value of it.
    marker: PhantomData<T>,
}
216
#[cfg(feature = "serde-serialize-no-std")]
impl<T: Scalar, const R: usize, const C: usize> ArrayStorageVisitor<T, R, C> {
    /// Creates a new sequence visitor.
    pub fn new() -> Self {
        Self {
            marker: PhantomData,
        }
    }
}
229
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Visitor<'a> for ArrayStorageVisitor<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    type Value = ArrayStorage<T, R, C>;

    /// Describes the expected input for serde error messages.
    fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        formatter.write_str("a matrix array")
    }

    /// Builds an `ArrayStorage` from a sequence of exactly `R * C` elements.
    ///
    /// # Errors
    ///
    /// - Any error reported by the underlying `SeqAccess` is forwarded as is.
    /// - `invalid_length` is returned when the sequence holds fewer or more
    ///   than `R * C` elements.
    ///
    /// On every error path the elements that were already written into the
    /// partially initialized storage are dropped before returning, so nothing
    /// is leaked.
    #[inline]
    fn visit_seq<V>(self, mut visitor: V) -> Result<ArrayStorage<T, R, C>, V::Error>
    where
        V: SeqAccess<'a>,
    {
        let mut out: ArrayStorage<core::mem::MaybeUninit<T>, R, C> =
            DefaultAllocator::allocate_uninit(Const::<R>, Const::<C>);
        // Number of leading slots of `out` that hold an initialized value.
        let mut curr = 0;

        // Funnel every way of stopping through a single `Result`, so that the
        // cleanup below runs no matter which failure occurred. Returning early
        // with `?` here would skip it and leak the initialized elements,
        // because `MaybeUninit` never drops its content on its own.
        let outcome = loop {
            match visitor.next_element::<T>() {
                Ok(Some(value)) => match out.as_mut_slice().get_mut(curr) {
                    Some(slot) => {
                        *slot = core::mem::MaybeUninit::new(value);
                        curr += 1;
                    }
                    // The sequence holds more elements than the storage can take.
                    None => break Err(V::Error::invalid_length(curr, &self)),
                },
                Ok(None) if curr == R * C => break Ok(()),
                // The sequence ended before the storage was filled.
                Ok(None) => break Err(V::Error::invalid_length(curr, &self)),
                Err(error) => break Err(error),
            }
        };

        match outcome {
            Ok(()) => {
                // SAFETY: the loop only yields `Ok` once `curr == R * C`, i.e.
                // once every slot has been initialized.
                unsafe {
                    Ok(<DefaultAllocator as Allocator<T, Const<R>, Const<C>>>::assume_init(out))
                }
            }
            Err(error) => {
                for slot in &mut out.as_mut_slice()[..curr] {
                    // SAFETY: exactly the first `curr` slots were initialized
                    // by the loop above, and each one is dropped only once.
                    unsafe { std::ptr::drop_in_place(slot.as_mut_ptr()) };
                }

                Err(error)
            }
        }
    }
}
272
#[cfg(feature = "bytemuck")]
// `ArrayStorage` is a `#[repr(transparent)]` wrapper around `[[T; R]; C]`, so
// its bytes are exactly the bytes of `R * C` values of `T`.
unsafe impl<T, const R: usize, const C: usize> bytemuck::Zeroable for ArrayStorage<T, R, C> where
    T: Scalar + Copy + bytemuck::Zeroable
{
}
278
#[cfg(feature = "bytemuck")]
// `ArrayStorage` is a `#[repr(transparent)]` wrapper around `[[T; R]; C]`, so
// its bytes are exactly the bytes of `R * C` values of `T`.
unsafe impl<T, const R: usize, const C: usize> bytemuck::Pod for ArrayStorage<T, R, C> where
    T: Scalar + Copy + bytemuck::Pod
{
}
284
#[cfg(feature = "abomonation-serialize")]
impl<T, const R: usize, const C: usize> Abomonation for ArrayStorage<T, R, C>
where
    T: Scalar + Abomonation,
{
    /// Entombs every element in slice order, stopping at the first I/O error.
    unsafe fn entomb<W: Write>(&self, writer: &mut W) -> IOResult<()> {
        for item in self.as_slice().iter() {
            item.entomb(writer)?;
        }

        Ok(())
    }

    /// Exhumes every element in slice order, threading the remaining bytes
    /// from one element to the next.
    ///
    /// Returns `None` as soon as one element fails to exhume, otherwise the
    /// bytes left over after the last element.
    unsafe fn exhume<'a, 'b>(&'a mut self, mut bytes: &'b mut [u8]) -> Option<&'b mut [u8]> {
        for item in self.as_mut_slice().iter_mut() {
            // Move the buffer out of `bytes` (rather than reborrowing it) so
            // that the remainder returned by `exhume` can be stored back.
            let remaining = bytes;
            bytes = item.exhume(remaining)?;
        }

        Some(bytes)
    }

    /// Sum of the extents of all the elements.
    fn extent(&self) -> usize {
        self.as_slice().iter().map(|item| item.extent()).sum()
    }
}
310
#[cfg(feature = "rkyv-serialize-no-std")]
mod rkyv_impl {
    use super::ArrayStorage;
    use rkyv::{offset_of, project_struct, Archive, Deserialize, Fallible, Serialize};

    // Every impl below forwards to the corresponding impl of the wrapped
    // `[[T; R]; C]` array (field `0`).

    impl<T, const R: usize, const C: usize> Archive for ArrayStorage<T, R, C>
    where
        T: Archive,
    {
        type Archived = ArrayStorage<T::Archived, R, C>;
        type Resolver = <[[T; R]; C] as Archive>::Resolver;

        /// Resolves the wrapped array into field `0` of the archived storage.
        fn resolve(
            &self,
            pos: usize,
            resolver: Self::Resolver,
            out: &mut core::mem::MaybeUninit<Self::Archived>,
        ) {
            self.0.resolve(
                pos + offset_of!(Self::Archived, 0),
                resolver,
                project_struct!(out: Self::Archived => 0),
            );
        }
    }

    impl<T, S, const R: usize, const C: usize> Serialize<S> for ArrayStorage<T, R, C>
    where
        T: Serialize<S>,
        S: Fallible + ?Sized,
    {
        /// Serializes the wrapped array and returns its resolver.
        fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
            self.0.serialize(serializer)
        }
    }

    impl<T, D, const R: usize, const C: usize> Deserialize<ArrayStorage<T, R, C>, D>
        for ArrayStorage<T::Archived, R, C>
    where
        T: Archive,
        T::Archived: Deserialize<T, D>,
        D: Fallible + ?Sized,
    {
        /// Deserializes the wrapped array and wraps it back into a storage.
        fn deserialize(&self, deserializer: &mut D) -> Result<ArrayStorage<T, R, C>, D::Error> {
            let inner = self.0.deserialize(deserializer)?;
            Ok(ArrayStorage(inner))
        }
    }
}