1use std::fmt::{self, Debug, Formatter};
2#[cfg(feature = "abomonation-serialize")]
4use std::io::{Result as IOResult, Write};
5use std::ops::Mul;
6
7#[cfg(feature = "serde-serialize-no-std")]
8use serde::de::{Error, SeqAccess, Visitor};
9#[cfg(feature = "serde-serialize-no-std")]
10use serde::ser::SerializeSeq;
11#[cfg(feature = "serde-serialize-no-std")]
12use serde::{Deserialize, Deserializer, Serialize, Serializer};
13#[cfg(feature = "serde-serialize-no-std")]
14use std::marker::PhantomData;
15
16#[cfg(feature = "abomonation-serialize")]
17use abomonation::Abomonation;
18
19use crate::base::allocator::Allocator;
20use crate::base::default_allocator::DefaultAllocator;
21use crate::base::dimension::{Const, ToTypenum};
22use crate::base::storage::{IsContiguous, Owned, RawStorage, RawStorageMut, ReshapableStorage};
23use crate::base::Scalar;
24use crate::Storage;
25use std::mem;
26
/// Dense matrix storage backed by a fixed-size array of `C` columns of `R` elements each.
///
/// The elements live inline in the struct (no separate allocation) in column-major
/// order: `self.0[j][i]` is the entry at row `i`, column `j`.
#[repr(transparent)]
#[derive(Clone, Copy)]
#[derive(Eq, Hash, PartialEq)]
#[cfg_attr(
    all(not(target_os = "cuda"), feature = "cuda"),
    derive(cust::DeviceCopy)
)]
pub struct ArrayStorage<T, const R: usize, const C: usize>(pub [[T; R]; C]);
40
41impl<T, const R: usize, const C: usize> ArrayStorage<T, R, C> {
42 #[inline]
44 pub fn as_slice(&self) -> &[T] {
45 unsafe { self.as_slice_unchecked() }
47 }
48
49 #[inline]
51 pub fn as_mut_slice(&mut self) -> &mut [T] {
52 unsafe { self.as_mut_slice_unchecked() }
54 }
55}
56
57impl<T: Default, const R: usize, const C: usize> Default for ArrayStorage<T, R, C>
59where
60 [[T; R]; C]: Default,
61{
62 #[inline]
63 fn default() -> Self {
64 Self(Default::default())
65 }
66}
67
68impl<T: Debug, const R: usize, const C: usize> Debug for ArrayStorage<T, R, C> {
69 #[inline]
70 fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
71 self.0.fmt(fmt)
72 }
73}
74
75unsafe impl<T, const R: usize, const C: usize> RawStorage<T, Const<R>, Const<C>>
76 for ArrayStorage<T, R, C>
77{
78 type RStride = Const<1>;
79 type CStride = Const<R>;
80
81 #[inline]
82 fn ptr(&self) -> *const T {
83 self.0.as_ptr() as *const T
84 }
85
86 #[inline]
87 fn shape(&self) -> (Const<R>, Const<C>) {
88 (Const, Const)
89 }
90
91 #[inline]
92 fn strides(&self) -> (Self::RStride, Self::CStride) {
93 (Const, Const)
94 }
95
96 #[inline]
97 fn is_contiguous(&self) -> bool {
98 true
99 }
100
101 #[inline]
102 unsafe fn as_slice_unchecked(&self) -> &[T] {
103 std::slice::from_raw_parts(self.ptr(), R * C)
104 }
105}
106
107unsafe impl<T: Scalar, const R: usize, const C: usize> Storage<T, Const<R>, Const<C>>
108 for ArrayStorage<T, R, C>
109where
110 DefaultAllocator: Allocator<T, Const<R>, Const<C>, Buffer = Self>,
111{
112 #[inline]
113 fn into_owned(self) -> Owned<T, Const<R>, Const<C>>
114 where
115 DefaultAllocator: Allocator<T, Const<R>, Const<C>>,
116 {
117 self
118 }
119
120 #[inline]
121 fn clone_owned(&self) -> Owned<T, Const<R>, Const<C>>
122 where
123 DefaultAllocator: Allocator<T, Const<R>, Const<C>>,
124 {
125 self.clone()
126 }
127}
128
129unsafe impl<T, const R: usize, const C: usize> RawStorageMut<T, Const<R>, Const<C>>
130 for ArrayStorage<T, R, C>
131{
132 #[inline]
133 fn ptr_mut(&mut self) -> *mut T {
134 self.0.as_mut_ptr() as *mut T
135 }
136
137 #[inline]
138 unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [T] {
139 std::slice::from_raw_parts_mut(self.ptr_mut(), R * C)
140 }
141}
142
143unsafe impl<T, const R: usize, const C: usize> IsContiguous for ArrayStorage<T, R, C> {}
144
145impl<T, const R1: usize, const C1: usize, const R2: usize, const C2: usize>
146 ReshapableStorage<T, Const<R1>, Const<C1>, Const<R2>, Const<C2>> for ArrayStorage<T, R1, C1>
147where
148 T: Scalar,
149 Const<R1>: ToTypenum,
150 Const<C1>: ToTypenum,
151 Const<R2>: ToTypenum,
152 Const<C2>: ToTypenum,
153 <Const<R1> as ToTypenum>::Typenum: Mul<<Const<C1> as ToTypenum>::Typenum>,
154 <Const<R2> as ToTypenum>::Typenum: Mul<
155 <Const<C2> as ToTypenum>::Typenum,
156 Output = typenum::Prod<
157 <Const<R1> as ToTypenum>::Typenum,
158 <Const<C1> as ToTypenum>::Typenum,
159 >,
160 >,
161{
162 type Output = ArrayStorage<T, R2, C2>;
163
164 fn reshape_generic(self, _: Const<R2>, _: Const<C2>) -> Self::Output {
165 unsafe {
166 let data: [[T; R2]; C2] = mem::transmute_copy(&self.0);
167 mem::forget(self.0);
168 ArrayStorage(data)
169 }
170 }
171}
172
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> Serialize for ArrayStorage<T, R, C>
where
    T: Scalar + Serialize,
{
    /// Serializes the matrix as a flat sequence of its `R * C` elements, in the
    /// column-major order given by `as_slice`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut seq = serializer.serialize_seq(Some(R * C))?;
        self.as_slice()
            .iter()
            .try_for_each(|element| seq.serialize_element(element))?;
        seq.end()
    }
}
197
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Deserialize<'a> for ArrayStorage<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    /// Reads the flat element sequence written by `Serialize` back into an `ArrayStorage`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        let visitor: ArrayStorageVisitor<T, R, C> = ArrayStorageVisitor::new();
        deserializer.deserialize_seq(visitor)
    }
}
210
/// serde `Visitor` that rebuilds an `ArrayStorage<T, R, C>` from a sequence of elements.
///
/// It carries no runtime data: `marker` only ties the element type `T` to the visitor,
/// and `R`/`C` are used purely at the type level by its `Visitor` impl.
#[cfg(feature = "serde-serialize-no-std")]
struct ArrayStorageVisitor<T, const R: usize, const C: usize> {
    marker: PhantomData<T>,
}
216
#[cfg(feature = "serde-serialize-no-std")]
impl<T, const R: usize, const C: usize> ArrayStorageVisitor<T, R, C>
where
    T: Scalar,
{
    /// Creates a visitor; it has no state beyond its type parameters.
    pub fn new() -> Self {
        Self { marker: PhantomData::<T> }
    }
}
229
#[cfg(feature = "serde-serialize-no-std")]
impl<'a, T, const R: usize, const C: usize> Visitor<'a> for ArrayStorageVisitor<T, R, C>
where
    T: Scalar + Deserialize<'a>,
{
    type Value = ArrayStorage<T, R, C>;

    fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        formatter.write_str("a matrix array")
    }

    /// Fills a new `ArrayStorage` with exactly `R * C` elements, in sequence order
    /// (the same column-major order that `Serialize` emits).
    ///
    /// # Errors
    ///
    /// Returns `invalid_length` when the sequence holds fewer or more than `R * C`
    /// elements, and forwards any error raised while deserializing an element. On
    /// every error path the elements already written are dropped, so none of them leak.
    #[inline]
    fn visit_seq<V>(self, mut visitor: V) -> Result<ArrayStorage<T, R, C>, V::Error>
    where
        V: SeqAccess<'a>,
    {
        let mut out: ArrayStorage<core::mem::MaybeUninit<T>, R, C> =
            DefaultAllocator::allocate_uninit(Const::<R>, Const::<C>);
        // Number of leading slots of `out` that currently hold an initialized `T`.
        let mut curr = 0;

        // Every outcome (success or any error) leaves this loop through `break`, so the
        // cleanup below runs on all error paths, not just the "too short" one.
        let filled: Result<(), V::Error> = loop {
            match visitor.next_element() {
                Ok(Some(value)) => match out.as_mut_slice().get_mut(curr) {
                    Some(slot) => {
                        *slot = core::mem::MaybeUninit::new(value);
                        curr += 1;
                    }
                    // More elements than the matrix can hold; `value` is an owned `T`
                    // and is dropped normally when this arm ends.
                    None => break Err(V::Error::invalid_length(curr, &self)),
                },
                Ok(None) if curr == R * C => break Ok(()),
                Ok(None) => break Err(V::Error::invalid_length(curr, &self)),
                Err(err) => break Err(err),
            }
        };

        match filled {
            // SAFETY: the loop only succeeds after writing all `R * C` slots.
            Ok(()) => unsafe {
                Ok(<DefaultAllocator as Allocator<T, Const<R>, Const<C>>>::assume_init(out))
            },
            Err(err) => {
                for slot in &mut out.as_mut_slice()[..curr] {
                    // SAFETY: slots `0..curr` were initialized above and are dropped
                    // exactly once here; `MaybeUninit` never drops its contents itself.
                    unsafe { std::ptr::drop_in_place(slot.as_mut_ptr()) };
                }
                Err(err)
            }
        }
    }
}
272
#[cfg(feature = "bytemuck")]
// SAFETY: `ArrayStorage` is `#[repr(transparent)]` over `[[E; NROWS]; NCOLS]`, and an
// array of zeroable elements is itself valid when all of its bytes are zero.
unsafe impl<E, const NROWS: usize, const NCOLS: usize> bytemuck::Zeroable
    for ArrayStorage<E, NROWS, NCOLS>
where
    E: Scalar + Copy + bytemuck::Zeroable,
{
}
278
#[cfg(feature = "bytemuck")]
// SAFETY: `ArrayStorage` is `#[repr(transparent)]` over `[[E; NROWS]; NCOLS]`; an array
// of `Pod` elements has no padding and accepts any bit pattern, so the wrapper does too.
unsafe impl<E, const NROWS: usize, const NCOLS: usize> bytemuck::Pod
    for ArrayStorage<E, NROWS, NCOLS>
where
    E: Scalar + Copy + bytemuck::Pod,
{
}
284
#[cfg(feature = "abomonation-serialize")]
impl<T, const R: usize, const C: usize> Abomonation for ArrayStorage<T, R, C>
where
    T: Scalar + Abomonation,
{
    /// Entombs each element in turn, in the column-major order of `as_slice`,
    /// stopping at the first write error.
    unsafe fn entomb<W: Write>(&self, writer: &mut W) -> IOResult<()> {
        for item in self.as_slice().iter() {
            item.entomb(writer)?;
        }
        Ok(())
    }

    /// Exhumes each element in the same order as `entomb`, threading the remaining
    /// byte buffer through every call; returns `None` as soon as one element fails.
    unsafe fn exhume<'a, 'b>(&'a mut self, mut bytes: &'b mut [u8]) -> Option<&'b mut [u8]> {
        for item in self.as_mut_slice().iter_mut() {
            // `mem::take` moves the remaining buffer out (leaving an empty slice behind)
            // so it can be handed to `exhume` and replaced by the tail it returns.
            bytes = item.exhume(mem::take(&mut bytes))?;
        }
        Some(bytes)
    }

    /// Total out-of-line size: the sum of every element's extent.
    fn extent(&self) -> usize {
        self.as_slice().iter().map(|item| item.extent()).sum()
    }
}
310
/// rkyv support: `ArrayStorage<T, R, C>` archives to `ArrayStorage<T::Archived, R, C>`
/// by delegating every step to the wrapped `[[T; R]; C]` array.
///
/// NOTE(review): `offset_of!`, `project_struct!` and the `&mut MaybeUninit` out-parameter
/// of `resolve` appear to follow an older rkyv API — confirm against the rkyv version
/// this crate pins before upgrading.
#[cfg(feature = "rkyv-serialize-no-std")]
mod rkyv_impl {
    use super::ArrayStorage;
    use rkyv::{offset_of, project_struct, Archive, Deserialize, Fallible, Serialize};

    // The archived form keeps the same `R x C` shape, with each element archived.
    impl<T: Archive, const R: usize, const C: usize> Archive for ArrayStorage<T, R, C> {
        type Archived = ArrayStorage<T::Archived, R, C>;
        // Reuse the wrapped array's resolver; the wrapper has no other field.
        type Resolver = <[[T; R]; C] as Archive>::Resolver;

        /// Resolves the wrapped array: its archived position is `pos` shifted by the
        /// offset of field `0` inside `Self::Archived`, and it is written into that
        /// field of `out`.
        fn resolve(
            &self,
            pos: usize,
            resolver: Self::Resolver,
            out: &mut core::mem::MaybeUninit<Self::Archived>,
        ) {
            self.0.resolve(
                pos + offset_of!(Self::Archived, 0),
                resolver,
                project_struct!(out: Self::Archived => 0),
            );
        }
    }

    // Serializing only requires serializing the wrapped array; its resolver is ours.
    impl<T: Serialize<S>, S: Fallible + ?Sized, const R: usize, const C: usize> Serialize<S>
        for ArrayStorage<T, R, C>
    {
        fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
            self.0.serialize(serializer)
        }
    }

    // Deserializes the archived array back into an owned `ArrayStorage<T, R, C>`.
    impl<T: Archive, D: Fallible + ?Sized, const R: usize, const C: usize>
        Deserialize<ArrayStorage<T, R, C>, D> for ArrayStorage<T::Archived, R, C>
    where
        T::Archived: Deserialize<T, D>,
    {
        fn deserialize(&self, deserializer: &mut D) -> Result<ArrayStorage<T, R, C>, D::Error> {
            Ok(ArrayStorage(self.0.deserialize(deserializer)?))
        }
    }
}