matrixmultiply/
aligned_alloc.rs

1#[cfg(not(feature = "std"))]
2use ::alloc::alloc::{self, handle_alloc_error, Layout};
3use core::{cmp, mem};
4#[cfg(feature = "std")]
5use std::alloc::{self, handle_alloc_error, Layout};
6
7#[cfg(test)]
8use core::ops::{Deref, DerefMut};
9#[cfg(test)]
10use core::slice;
11
/// Owning handle to a raw, manually aligned heap buffer of `T` elements.
///
/// The buffer is obtained from the global allocator in `Alloc::new` and handed
/// back to it when the handle is dropped. The handle only tracks the memory;
/// it does not track whether the elements have been initialized.
pub(crate) struct Alloc<T> {
    /// Address of the first element of the buffer.
    ptr: *mut T,
    /// Capacity of the buffer, counted in elements of `T` (not bytes).
    len: usize,
    /// Alignment in bytes that the buffer was allocated with; needed again to
    /// rebuild the layout on deallocation.
    align: usize,
}
17
18impl<T> Alloc<T> {
19    #[inline]
20    pub unsafe fn new(nelem: usize, align: usize) -> Self {
21        let align = cmp::max(align, mem::align_of::<T>());
22        #[cfg(debug_assertions)]
23        let layout = Layout::from_size_align(mem::size_of::<T>() * nelem, align).unwrap();
24        #[cfg(not(debug_assertions))]
25        let layout = Layout::from_size_align_unchecked(mem::size_of::<T>() * nelem, align);
26        dprint!("Allocating nelem={}, layout={:?}", nelem, layout);
27        let ptr = alloc::alloc(layout);
28        if ptr.is_null() {
29            handle_alloc_error(layout);
30        }
31        Alloc {
32            ptr: ptr as *mut T,
33            len: nelem,
34            align,
35        }
36    }
37
38    #[cfg(test)]
39    pub fn init_with(mut self, elt: T) -> Alloc<T>
40    where
41        T: Copy,
42    {
43        for elt1 in &mut self[..] {
44            *elt1 = elt;
45        }
46        self
47    }
48
49    #[inline]
50    pub fn ptr_mut(&mut self) -> *mut T {
51        self.ptr
52    }
53}
54
55impl<T> Drop for Alloc<T> {
56    fn drop(&mut self) {
57        unsafe {
58            let layout =
59                Layout::from_size_align_unchecked(mem::size_of::<T>() * self.len, self.align);
60            alloc::dealloc(self.ptr as _, layout);
61        }
62    }
63}
64
#[cfg(test)]
impl<T> Deref for Alloc<T> {
    type Target = [T];

    /// Views the whole buffer as a shared slice of `len` elements.
    ///
    /// `Alloc::new` leaves the memory uninitialized, so the elements should
    /// have been written (for example with `init_with`) before they are read
    /// through this slice.
    fn deref(&self) -> &[T] {
        let (start, count) = (self.ptr as *const T, self.len);
        // SAFETY: `ptr` and `len` describe the buffer allocated in
        // `Alloc::new`, which stays alive for as long as `self` is borrowed.
        unsafe { slice::from_raw_parts(start, count) }
    }
}
72
#[cfg(test)]
impl<T> DerefMut for Alloc<T> {
    /// Views the whole buffer as a mutable slice of `len` elements.
    ///
    /// `Alloc::new` leaves the memory uninitialized; this slice is the way
    /// test code fills or modifies the elements in place.
    fn deref_mut(&mut self) -> &mut [T] {
        let (start, count) = (self.ptr, self.len);
        // SAFETY: `ptr` and `len` describe the buffer allocated in
        // `Alloc::new`, and the exclusive borrow of `self` guarantees that no
        // other slice over the buffer is handed out at the same time.
        unsafe { slice::from_raw_parts_mut(start, count) }
    }
}
78}