nalgebra/base/blas_uninit.rs

/*
 * This file implements some BLAS operations in such a way that they work
 * even if the first argument (the output parameter) is an uninitialized matrix.
 *
 * Because doing this makes the code harder to read, we only implement the operations that we
 * know benefit from it performance-wise, namely GEMM (which we use for our matrix
 * multiplication code), together with the GEMV and AXCPY kernels it is built on. If we identify
 * other such operations in the future, we can add them here.
 */
10
11#[cfg(feature = "std")]
12use matrixmultiply;
13use num::{One, Zero};
14use simba::scalar::{ClosedAdd, ClosedMul};
15#[cfg(feature = "std")]
16use std::mem;
17
18use crate::base::constraint::{
19    AreMultipliable, DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
20};
21use crate::base::dimension::{Dim, Dynamic, U1};
22use crate::base::storage::{RawStorage, RawStorageMut};
23use crate::base::uninit::InitStatus;
24use crate::base::{Matrix, Scalar, Vector};
25use std::any::TypeId;
26
27// # Safety
28// The content of `y` must only contain values for which
29// `Status::assume_init_mut` is sound.
30#[allow(clippy::too_many_arguments)]
31unsafe fn array_axcpy<Status, T>(
32    _: Status,
33    y: &mut [Status::Value],
34    a: T,
35    x: &[T],
36    c: T,
37    beta: T,
38    stride1: usize,
39    stride2: usize,
40    len: usize,
41) where
42    Status: InitStatus<T>,
43    T: Scalar + Zero + ClosedAdd + ClosedMul,
44{
45    for i in 0..len {
46        let y = Status::assume_init_mut(y.get_unchecked_mut(i * stride1));
47        *y =
48            a.clone() * x.get_unchecked(i * stride2).clone() * c.clone() + beta.clone() * y.clone();
49    }
50}
51
52fn array_axc<Status, T>(
53    _: Status,
54    y: &mut [Status::Value],
55    a: T,
56    x: &[T],
57    c: T,
58    stride1: usize,
59    stride2: usize,
60    len: usize,
61) where
62    Status: InitStatus<T>,
63    T: Scalar + Zero + ClosedAdd + ClosedMul,
64{
65    for i in 0..len {
66        unsafe {
67            Status::init(
68                y.get_unchecked_mut(i * stride1),
69                a.clone() * x.get_unchecked(i * stride2).clone() * c.clone(),
70            );
71        }
72    }
73}
74
75/// Computes `y = a * x * c + b * y`.
76///
77/// If `b` is zero, `y` is never read from and may be uninitialized.
78///
79/// # Safety
80/// This is UB if b != 0 and any component of `y` is uninitialized.
81#[inline(always)]
82#[allow(clippy::many_single_char_names)]
83pub unsafe fn axcpy_uninit<Status, T, D1: Dim, D2: Dim, SA, SB>(
84    status: Status,
85    y: &mut Vector<Status::Value, D1, SA>,
86    a: T,
87    x: &Vector<T, D2, SB>,
88    c: T,
89    b: T,
90) where
91    T: Scalar + Zero + ClosedAdd + ClosedMul,
92    SA: RawStorageMut<Status::Value, D1>,
93    SB: RawStorage<T, D2>,
94    ShapeConstraint: DimEq<D1, D2>,
95    Status: InitStatus<T>,
96{
97    assert_eq!(y.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
98
99    let rstride1 = y.strides().0;
100    let rstride2 = x.strides().0;
101
102    // SAFETY: the conversion to slices is OK because we access the
103    //         elements taking the strides into account.
104    let y = y.data.as_mut_slice_unchecked();
105    let x = x.data.as_slice_unchecked();
106
107    if !b.is_zero() {
108        array_axcpy(status, y, a, x, c, b, rstride1, rstride2, x.len());
109    } else {
110        array_axc(status, y, a, x, c, rstride1, rstride2, x.len());
111    }
112}
113
114/// Computes `y = alpha * a * x + beta * y`, where `a` is a matrix, `x` a vector, and
115/// `alpha, beta` two scalars.
116///
117/// If `beta` is zero, `y` is never read from and may be uninitialized.
118///
119/// # Safety
120/// This is UB if beta != 0 and any component of `y` is uninitialized.
121#[inline(always)]
122pub unsafe fn gemv_uninit<Status, T, D1: Dim, R2: Dim, C2: Dim, D3: Dim, SA, SB, SC>(
123    status: Status,
124    y: &mut Vector<Status::Value, D1, SA>,
125    alpha: T,
126    a: &Matrix<T, R2, C2, SB>,
127    x: &Vector<T, D3, SC>,
128    beta: T,
129) where
130    Status: InitStatus<T>,
131    T: Scalar + Zero + One + ClosedAdd + ClosedMul,
132    SA: RawStorageMut<Status::Value, D1>,
133    SB: RawStorage<T, R2, C2>,
134    SC: RawStorage<T, D3>,
135    ShapeConstraint: DimEq<D1, R2> + AreMultipliable<R2, C2, D3, U1>,
136{
137    let dim1 = y.nrows();
138    let (nrows2, ncols2) = a.shape();
139    let dim3 = x.nrows();
140
141    assert!(
142        ncols2 == dim3 && dim1 == nrows2,
143        "Gemv: dimensions mismatch."
144    );
145
146    if ncols2 == 0 {
147        if beta.is_zero() {
148            y.apply(|e| Status::init(e, T::zero()));
149        } else {
150            // SAFETY: this is UB if y is uninitialized.
151            y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
152        }
153        return;
154    }
155
156    // TODO: avoid bound checks.
157    let col2 = a.column(0);
158    let val = x.vget_unchecked(0).clone();
159
160    // SAFETY: this is the call that makes this method unsafe: it is UB if Status = Uninit and beta != 0.
161    axcpy_uninit(status, y, alpha.clone(), &col2, val, beta);
162
163    for j in 1..ncols2 {
164        let col2 = a.column(j);
165        let val = x.vget_unchecked(j).clone();
166
167        // SAFETY: safe because y was initialized above.
168        axcpy_uninit(status, y, alpha.clone(), &col2, val, T::one());
169    }
170}
171
172/// Computes `y = alpha * a * b + beta * y`, where `a, b, y` are matrices.
173/// `alpha` and `beta` are scalar.
174///
175/// If `beta` is zero, `y` is never read from and may be uninitialized.
176///
177/// # Safety
178/// This is UB if beta != 0 and any component of `y` is uninitialized.
179#[inline(always)]
180pub unsafe fn gemm_uninit<
181    Status,
182    T,
183    R1: Dim,
184    C1: Dim,
185    R2: Dim,
186    C2: Dim,
187    R3: Dim,
188    C3: Dim,
189    SA,
190    SB,
191    SC,
192>(
193    status: Status,
194    y: &mut Matrix<Status::Value, R1, C1, SA>,
195    alpha: T,
196    a: &Matrix<T, R2, C2, SB>,
197    b: &Matrix<T, R3, C3, SC>,
198    beta: T,
199) where
200    Status: InitStatus<T>,
201    T: Scalar + Zero + One + ClosedAdd + ClosedMul,
202    SA: RawStorageMut<Status::Value, R1, C1>,
203    SB: RawStorage<T, R2, C2>,
204    SC: RawStorage<T, R3, C3>,
205    ShapeConstraint:
206        SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C3> + AreMultipliable<R2, C2, R3, C3>,
207{
208    let ncols1 = y.ncols();
209
210    #[cfg(feature = "std")]
211    {
212        // We assume large matrices will be Dynamic but small matrices static.
213        // We could use matrixmultiply for large statically-sized matrices but the performance
214        // threshold to activate it would be different from SMALL_DIM because our code optimizes
215        // better for statically-sized matrices.
216        if R1::is::<Dynamic>()
217            || C1::is::<Dynamic>()
218            || R2::is::<Dynamic>()
219            || C2::is::<Dynamic>()
220            || R3::is::<Dynamic>()
221            || C3::is::<Dynamic>()
222        {
223            // matrixmultiply can be used only if the std feature is available.
224            let nrows1 = y.nrows();
225            let (nrows2, ncols2) = a.shape();
226            let (nrows3, ncols3) = b.shape();
227
228            // Threshold determined empirically.
229            const SMALL_DIM: usize = 5;
230
231            if nrows1 > SMALL_DIM && ncols1 > SMALL_DIM && nrows2 > SMALL_DIM && ncols2 > SMALL_DIM
232            {
233                assert_eq!(
234                    ncols2, nrows3,
235                    "gemm: dimensions mismatch for multiplication."
236                );
237                assert_eq!(
238                    (nrows1, ncols1),
239                    (nrows2, ncols3),
240                    "gemm: dimensions mismatch for addition."
241                );
242
243                // NOTE: this case should never happen because we enter this
244                // codepath only when ncols2 > SMALL_DIM. Though we keep this
245                // here just in case if in the future we change the conditions to
246                // enter this codepath.
247                if ncols2 == 0 {
248                    // NOTE: we can't just always multiply by beta
249                    // because we documented the guaranty that `self` is
250                    // never read if `beta` is zero.
251                    if beta.is_zero() {
252                        y.apply(|e| Status::init(e, T::zero()));
253                    } else {
254                        // SAFETY: this is UB if Status = Uninit
255                        y.apply(|e| *Status::assume_init_mut(e) *= beta.clone());
256                    }
257                    return;
258                }
259
260                if TypeId::of::<T>() == TypeId::of::<f32>() {
261                    let (rsa, csa) = a.strides();
262                    let (rsb, csb) = b.strides();
263                    let (rsc, csc) = y.strides();
264
265                    matrixmultiply::sgemm(
266                        nrows2,
267                        ncols2,
268                        ncols3,
269                        mem::transmute_copy(&alpha),
270                        a.data.ptr() as *const f32,
271                        rsa as isize,
272                        csa as isize,
273                        b.data.ptr() as *const f32,
274                        rsb as isize,
275                        csb as isize,
276                        mem::transmute_copy(&beta),
277                        y.data.ptr_mut() as *mut f32,
278                        rsc as isize,
279                        csc as isize,
280                    );
281                    return;
282                } else if TypeId::of::<T>() == TypeId::of::<f64>() {
283                    let (rsa, csa) = a.strides();
284                    let (rsb, csb) = b.strides();
285                    let (rsc, csc) = y.strides();
286
287                    matrixmultiply::dgemm(
288                        nrows2,
289                        ncols2,
290                        ncols3,
291                        mem::transmute_copy(&alpha),
292                        a.data.ptr() as *const f64,
293                        rsa as isize,
294                        csa as isize,
295                        b.data.ptr() as *const f64,
296                        rsb as isize,
297                        csb as isize,
298                        mem::transmute_copy(&beta),
299                        y.data.ptr_mut() as *mut f64,
300                        rsc as isize,
301                        csc as isize,
302                    );
303                    return;
304                }
305            }
306        }
307    }
308
309    for j1 in 0..ncols1 {
310        // TODO: avoid bound checks.
311        // SAFETY: this is UB if Status = Uninit && beta != 0
312        gemv_uninit(
313            status,
314            &mut y.column_mut(j1),
315            alpha.clone(),
316            a,
317            &b.column(j1),
318            beta.clone(),
319        );
320    }
321}