simba/simd/
simd_value.rs

1use crate::simd::SimdBool;
2
3/// Base trait for every SIMD types.
4pub trait SimdValue: Sized {
5    /// The type of the elements of each lane of this SIMD value.
6    type Element: SimdValue<Element = Self::Element, SimdBool = bool>;
7    /// Type of the result of comparing two SIMD values like `self`.
8    type SimdBool: SimdBool;
9
10    /// The number of lanes of this SIMD value.
11    fn lanes() -> usize;
12    /// Initializes an SIMD value with each lanes set to `val`.
13    fn splat(val: Self::Element) -> Self;
14    /// Extracts the i-th lane of `self`.
15    ///
16    /// Panics if `i >= Self::lanes()`.
17    fn extract(&self, i: usize) -> Self::Element;
18    /// Extracts the i-th lane of `self` without bound-checking.
19    unsafe fn extract_unchecked(&self, i: usize) -> Self::Element;
20    /// Replaces the i-th lane of `self` by `val`.
21    ///
22    /// Panics if `i >= Self::lanes()`.
23    fn replace(&mut self, i: usize, val: Self::Element);
24    /// Replaces the i-th lane of `self` by `val` without bound-checking.
25    unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element);
26
27    /// Merges `self` and `other` depending on the lanes of `cond`.
28    ///
29    /// For each lane of `cond` with bits set to 1, the result's will contain the value of the lane of `self`.
30    /// For each lane of `cond` with bits set to 0, the result's will contain the value of the lane of `other`.
31    fn select(self, cond: Self::SimdBool, other: Self) -> Self;
32
33    /// Applies a function to each lane of `self`.
34    ///
35    /// Note that, while convenient, this method can be extremely slow as this
36    /// requires to extract each lane of `self` and then combine them again into
37    /// a new SIMD value.
38    #[inline(always)]
39    fn map_lanes(self, f: impl Fn(Self::Element) -> Self::Element) -> Self
40    where
41        Self: Clone,
42    {
43        let mut result = self.clone();
44
45        for i in 0..Self::lanes() {
46            unsafe { result.replace_unchecked(i, f(self.extract_unchecked(i))) }
47        }
48
49        result
50    }
51
52    /// Applies a function to each lane of `self` paired with the corresponding lane of `b`.
53    ///
54    /// Note that, while convenient, this method can be extremely slow as this
55    /// requires to extract each lane of `self` and then combine them again into
56    /// a new SIMD value.
57    #[inline(always)]
58    fn zip_map_lanes(
59        self,
60        b: Self,
61        f: impl Fn(Self::Element, Self::Element) -> Self::Element,
62    ) -> Self
63    where
64        Self: Clone,
65    {
66        let mut result = self.clone();
67
68        for i in 0..Self::lanes() {
69            unsafe {
70                let a = self.extract_unchecked(i);
71                let b = b.extract_unchecked(i);
72                result.replace_unchecked(i, f(a, b))
73            }
74        }
75
76        result
77    }
78}
79
80/// Marker trait implemented by SIMD and non-SIMD primitive numeric values.
81///
82/// This trait is useful for some disambiguations when writing blanked impls.
83/// This is implemented by all unsigned integer, integer, float, and complex types, as
84/// with only one lane, i.e., `f32`, `f64`, `u32`, `i64`, etc. as well as SIMD types like
85/// `f32x4, i32x8`, etc..
86pub trait PrimitiveSimdValue: Copy + SimdValue {}
87
88impl<N: SimdValue> SimdValue for num_complex::Complex<N> {
89    type Element = num_complex::Complex<N::Element>;
90    type SimdBool = N::SimdBool;
91
92    #[inline(always)]
93    fn lanes() -> usize {
94        N::lanes()
95    }
96
97    #[inline(always)]
98    fn splat(val: Self::Element) -> Self {
99        num_complex::Complex {
100            re: N::splat(val.re),
101            im: N::splat(val.im),
102        }
103    }
104
105    #[inline(always)]
106    fn extract(&self, i: usize) -> Self::Element {
107        num_complex::Complex {
108            re: self.re.extract(i),
109            im: self.im.extract(i),
110        }
111    }
112
113    #[inline(always)]
114    unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
115        num_complex::Complex {
116            re: self.re.extract_unchecked(i),
117            im: self.im.extract_unchecked(i),
118        }
119    }
120
121    #[inline(always)]
122    fn replace(&mut self, i: usize, val: Self::Element) {
123        self.re.replace(i, val.re);
124        self.im.replace(i, val.im);
125    }
126
127    #[inline(always)]
128    unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
129        self.re.replace_unchecked(i, val.re);
130        self.im.replace_unchecked(i, val.im);
131    }
132
133    #[inline(always)]
134    fn select(self, cond: Self::SimdBool, other: Self) -> Self {
135        num_complex::Complex {
136            re: self.re.select(cond, other.re),
137            im: self.im.select(cond, other.im),
138        }
139    }
140}
141
142impl<N: PrimitiveSimdValue> PrimitiveSimdValue for num_complex::Complex<N> {}
143
144macro_rules! impl_primitive_simd_value_for_scalar(
145    ($($t: ty),*) => {$(
146        impl PrimitiveSimdValue for $t {}
147        impl SimdValue for $t {
148            type Element = $t;
149            type SimdBool = bool;
150
151            #[inline(always)]
152            fn lanes() -> usize {
153                1
154            }
155
156            #[inline(always)]
157            fn splat(val: Self::Element) -> Self {
158                val
159            }
160
161            #[inline(always)]
162            fn extract(&self, _: usize) -> Self::Element {
163                *self
164            }
165
166            #[inline(always)]
167            unsafe fn extract_unchecked(&self, _: usize) -> Self::Element {
168                *self
169            }
170
171            #[inline(always)]
172            fn replace(&mut self, _: usize, val: Self::Element) {
173                *self = val
174            }
175
176            #[inline(always)]
177            unsafe fn replace_unchecked(&mut self, _: usize, val: Self::Element) {
178                *self = val
179            }
180
181            #[inline(always)]
182            fn select(self, cond: Self::SimdBool, other: Self) -> Self {
183                if cond {
184                    self
185                } else {
186                    other
187                }
188            }
189        }
190    )*}
191);
192
193impl_primitive_simd_value_for_scalar!(
194    bool, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64
195);
196#[cfg(feature = "decimal")]
197impl_primitive_simd_value_for_scalar!(decimal::d128);