moxcms/conversions/avx/cube.rs

/*
 * // Copyright (c) Radzivon Bartoshyk 3/2025. All rights reserved.
 * //
 * // Redistribution and use in source and binary forms, with or without modification,
 * // are permitted provided that the following conditions are met:
 * //
 * // 1.  Redistributions of source code must retain the above copyright notice, this
 * // list of conditions and the following disclaimer.
 * //
 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
 * // this list of conditions and the following disclaimer in the documentation
 * // and/or other materials provided with the distribution.
 * //
 * // 3.  Neither the name of the copyright holder nor the names of its
 * // contributors may be used to endorse or promote products derived from
 * // this software without specific prior written permission.
 * //
 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
use crate::conversions::avx::interpolator::AvxVectorSse;
use crate::math::{FusedMultiplyAdd, FusedMultiplyNegAdd};
use std::arch::x86_64::*;
use std::ops::{Add, Mul, Sub};

/// 3D CLUT AVX helper
///
/// Represents a hexahedron (one interpolation cell of the lattice).
pub(crate) struct CubeAvxFma<'a> {
    array: &'a [f32],
    x_stride: u32,
    y_stride: u32,
    grid_size: [u8; 3],
}

/// Fetches 3-component (e.g. RGB) entries from the lattice.
struct HexahedronFetch3<'a> {
    array: &'a [f32],
    x_stride: u32,
    y_stride: u32,
}

/// Reads one lattice point at integer grid coordinates.
trait CubeFetch<T> {
    fn fetch(&self, x: i32, y: i32, z: i32) -> T;
}

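// Loads one 3-component entry as an SSE vector: a 64-bit load grabs the
// first two floats (upper lanes zeroed), then the third float's bits are
// inserted into lane 2, yielding [v0, v1, v2, 0.0] without reading past
// the entry. In-bounds access relies on the length assert in
// `CubeAvxFma::new`.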
impl CubeFetch<AvxVectorSse> for HexahedronFetch3<'_> {
    #[inline(always)]
    fn fetch(&self, x: i32, y: i32, z: i32) -> AvxVectorSse {
        let start = (x as u32 * self.x_stride + y as u32 * self.y_stride + z as u32) as usize * 3;
        unsafe {
            let k = self.array.get_unchecked(start..);
            let lo = _mm_loadu_si64(k.as_ptr() as *const _);
            let hi = _mm_insert_epi32::<2>(
                lo,
                k.get_unchecked(2..).as_ptr().read_unaligned().to_bits() as i32,
            );
            AvxVectorSse {
                v: _mm_castsi128_ps(hi),
            }
        }
    }
}

impl<'a> CubeAvxFma<'a> {
    pub(crate) fn new(arr: &'a [f32], grid: [u8; 3], components: usize) -> Self {
        // Safety precondition: the array length must match the full grid
        // volume times the component count. The unchecked indexing in
        // `fetch` relies on this, so it must not be bypassed anywhere else.
        assert_eq!(
            grid[0] as usize * grid[1] as usize * grid[2] as usize * components,
            arr.len()
        );
        let y_stride = grid[1] as u32;
        let x_stride = y_stride * grid[0] as u32;
        CubeAvxFma {
            array: arr,
            x_stride,
            y_stride,
            grid_size: grid,
        }
    }

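    /// Trilinear interpolation over the cell containing the point.
    ///
    /// Each stage is a lerp `a + (b - a) * t`, evaluated FMA-friendly as
    /// `a.neg_mla(a, t).mla(b, t)`, i.e. `(a - a*t) + b*t`, reducing the
    /// eight cell corners along x, then y, then z.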
    #[inline(always)]
    fn trilinear<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyNegAdd<T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl CubeFetch<T>,
    ) -> T {
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        let x_d = T::from(lin_x * scale_x - x as f32);
        let y_d = T::from(lin_y * scale_y - y as f32);
        let z_d = T::from(lin_z * scale_z - z as f32);

        let c000 = fetch.fetch(x, y, z);
        let c100 = fetch.fetch(x_n, y, z);
        let c010 = fetch.fetch(x, y_n, z);
        let c110 = fetch.fetch(x_n, y_n, z);
        let c001 = fetch.fetch(x, y, z_n);
        let c101 = fetch.fetch(x_n, y, z_n);
        let c011 = fetch.fetch(x, y_n, z_n);
        let c111 = fetch.fetch(x_n, y_n, z_n);

        // Lerp along x: 8 corners -> 4 edge values.
        let c00 = c000.neg_mla(c000, x_d).mla(c100, x_d);
        let c10 = c010.neg_mla(c010, x_d).mla(c110, x_d);
        let c01 = c001.neg_mla(c001, x_d).mla(c101, x_d);
        let c11 = c011.neg_mla(c011, x_d).mla(c111, x_d);

        // Lerp along y: 4 -> 2.
        let c0 = c00.neg_mla(c00, y_d).mla(c10, y_d);
        let c1 = c01.neg_mla(c01, y_d).mla(c11, y_d);

        // Lerp along z: 2 -> 1.
        c0.neg_mla(c0, z_d).mla(c1, z_d)
    }

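    /// Pyramidal interpolation.
    ///
    /// Picks one of three pyramids in the cell by comparing the fractional
    /// offsets, then blends five lattice points plus one mixed product term.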
    #[cfg(feature = "options")]
    #[inline]
    fn pyramid<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl CubeFetch<T>,
    ) -> T {
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        let dr = lin_x * scale_x - x as f32;
        let dg = lin_y * scale_y - y as f32;
        let db = lin_z * scale_z - z as f32;

        let c0 = fetch.fetch(x, y, z);

        if dr > db && dg > db {
            let x0 = fetch.fetch(x_n, y_n, z_n);
            let x1 = fetch.fetch(x_n, y_n, z);
            let x2 = fetch.fetch(x_n, y, z);
            let x3 = fetch.fetch(x, y_n, z);

            let c1 = x0 - x1;
            let c2 = x2 - c0;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x2 + x1;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dr * dg))
        } else if db > dr && dg > dr {
            let x0 = fetch.fetch(x, y, z_n);
            let x1 = fetch.fetch(x_n, y_n, z_n);
            let x2 = fetch.fetch(x, y_n, z_n);
            let x3 = fetch.fetch(x, y_n, z);

            let c1 = x0 - c0;
            let c2 = x1 - x2;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x0 + x2;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dg * db))
        } else {
            let x0 = fetch.fetch(x, y, z_n);
            let x1 = fetch.fetch(x_n, y, z);
            let x2 = fetch.fetch(x_n, y, z_n);
            let x3 = fetch.fetch(x_n, y_n, z_n);

            let c1 = x0 - c0;
            let c2 = x1 - c0;
            let c3 = x3 - x2;
            let c4 = c0 - x1 - x0 + x2;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(db * dr))
        }
    }

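    /// Tetrahedral interpolation.
    ///
    /// Orders the fractional offsets to select one of six tetrahedra in the
    /// cell, then blends four lattice points as `c0 + c1*rx + c2*ry + c3*rz`.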
    #[cfg(feature = "options")]
    #[inline]
    fn tetra<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl CubeFetch<T>,
    ) -> T {
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        let rx = lin_x * scale_x - x as f32;
        let ry = lin_y * scale_y - y as f32;
        let rz = lin_z * scale_z - z as f32;

        let c0 = fetch.fetch(x, y, z);
        let c1;
        let c2;
        let c3;
        if rx >= ry {
            if ry >= rz {
                // rx >= ry && ry >= rz
                c1 = fetch.fetch(x_n, y, z) - c0;
                c2 = fetch.fetch(x_n, y_n, z) - fetch.fetch(x_n, y, z);
                c3 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y_n, z);
            } else if rx >= rz {
                // rx >= rz && rz > ry
                c1 = fetch.fetch(x_n, y, z) - c0;
                c2 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y, z_n);
                c3 = fetch.fetch(x_n, y, z_n) - fetch.fetch(x_n, y, z);
            } else {
                // rz > rx && rx >= ry
                c1 = fetch.fetch(x_n, y, z_n) - fetch.fetch(x, y, z_n);
                c2 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y, z_n);
                c3 = fetch.fetch(x, y, z_n) - c0;
            }
        } else if rx >= rz {
            // ry > rx && rx >= rz
            c1 = fetch.fetch(x_n, y_n, z) - fetch.fetch(x, y_n, z);
            c2 = fetch.fetch(x, y_n, z) - c0;
            c3 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y_n, z);
        } else if ry >= rz {
            // ry >= rz && rz > rx
            c1 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x, y_n, z_n);
            c2 = fetch.fetch(x, y_n, z) - c0;
            c3 = fetch.fetch(x, y_n, z_n) - fetch.fetch(x, y_n, z);
        } else {
            // rz > ry && ry > rx
            c1 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x, y_n, z_n);
            c2 = fetch.fetch(x, y_n, z_n) - fetch.fetch(x, y, z_n);
            c3 = fetch.fetch(x, y, z_n) - c0;
        }
        let s0 = c0.mla(c1, T::from(rx));
        let s1 = s0.mla(c2, T::from(ry));
        s1.mla(c3, T::from(rz))
    }

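    /// Prismatic interpolation.
    ///
    /// Splits the cell into two triangular prisms along the `db >= dr`
    /// diagonal, then blends six lattice points plus two mixed product terms.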
    #[cfg(feature = "options")]
    #[inline]
    fn prism<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl CubeFetch<T>,
    ) -> T {
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        let dr = lin_x * scale_x - x as f32;
        let dg = lin_y * scale_y - y as f32;
        let db = lin_z * scale_z - z as f32;

        let c0 = fetch.fetch(x, y, z);

        if db >= dr {
            let x0 = fetch.fetch(x, y, z_n);
            let x1 = fetch.fetch(x_n, y, z_n);
            let x2 = fetch.fetch(x, y_n, z);
            let x3 = fetch.fetch(x, y_n, z_n);
            let x4 = fetch.fetch(x_n, y_n, z_n);

            let c1 = x0 - c0;
            let c2 = x1 - x0;
            let c3 = x2 - c0;
            let c4 = c0 - x2 - x0 + x3;
            let c5 = x0 - x3 - x1 + x4;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        } else {
            let x0 = fetch.fetch(x_n, y, z);
            let x1 = fetch.fetch(x_n, y, z_n);
            let x2 = fetch.fetch(x, y_n, z);
            let x3 = fetch.fetch(x_n, y_n, z);
            let x4 = fetch.fetch(x_n, y_n, z_n);

            let c1 = x1 - x0;
            let c2 = x0 - c0;
            let c3 = x2 - c0;
            let c4 = x0 - x3 - x1 + x4;
            let c5 = c0 - x2 - x0 + x3;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        }
    }

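    /// Trilinear CLUT lookup of a 3-component entry; coordinates are
    /// normalized and clamped to `[0, 1]`.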
    #[inline]
    pub(crate) fn trilinear_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> AvxVectorSse {
        self.trilinear(
            lin_x,
            lin_y,
            lin_z,
            HexahedronFetch3 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
            },
        )
    }

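    /// Prismatic CLUT lookup of a 3-component entry; coordinates are
    /// normalized and clamped to `[0, 1]`.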
    #[cfg(feature = "options")]
    #[inline]
    pub(crate) fn prism_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> AvxVectorSse {
        self.prism(
            lin_x,
            lin_y,
            lin_z,
            HexahedronFetch3 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
            },
        )
    }

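    /// Pyramidal CLUT lookup of a 3-component entry; coordinates are
    /// normalized and clamped to `[0, 1]`.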
    #[cfg(feature = "options")]
    #[inline]
    pub(crate) fn pyramid_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> AvxVectorSse {
        self.pyramid(
            lin_x,
            lin_y,
            lin_z,
            HexahedronFetch3 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
            },
        )
    }

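    /// Tetrahedral CLUT lookup of a 3-component entry; coordinates are
    /// normalized and clamped to `[0, 1]`.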
    #[cfg(feature = "options")]
    #[inline]
    pub(crate) fn tetra_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> AvxVectorSse {
        self.tetra(
            lin_x,
            lin_y,
            lin_z,
            HexahedronFetch3 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
            },
        )
    }
}