moxcms/
nd_array.rs

1/*
2 * // Copyright (c) Radzivon Bartoshyk 2/2025. All rights reserved.
3 * //
4 * // Redistribution and use in source and binary forms, with or without modification,
5 * // are permitted provided that the following conditions are met:
6 * //
7 * // 1.  Redistributions of source code must retain the above copyright notice, this
8 * // list of conditions and the following disclaimer.
9 * //
10 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
11 * // this list of conditions and the following disclaimer in the documentation
12 * // and/or other materials provided with the distribution.
13 * //
14 * // 3.  Neither the name of the copyright holder nor the names of its
15 * // contributors may be used to endorse or promote products derived from
16 * // this software without specific prior written permission.
17 * //
18 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29use crate::math::{FusedMultiplyAdd, FusedMultiplyNegAdd};
30use crate::mlaf::{mlaf, neg_mlaf};
31use crate::safe_math::{SafeAdd, SafeMul};
32use crate::{CmsError, MalformedSize, Vector3f, Vector4f};
33use std::ops::{Add, Mul, Sub};
34
// Scalar `f32` implementation of the fused multiply-add trait.
// Delegates to the crate-wide `mlaf` helper; from its use in `lerp` below
// this evaluates `self + b * c` — confirm against `crate::mlaf::mlaf`.
impl FusedMultiplyAdd<f32> for f32 {
    #[inline(always)]
    fn mla(&self, b: f32, c: f32) -> f32 {
        mlaf(*self, b, c)
    }
}
41
// Scalar `f32` implementation of the fused negated multiply-add trait.
// Delegates to `neg_mlaf`; `lerp` below uses `a.neg_mla(a, t)` to form
// `a * (1 - t)`, which implies `self - b * c` — confirm against
// `crate::mlaf::neg_mlaf`.
impl FusedMultiplyNegAdd<f32> for f32 {
    #[inline(always)]
    fn neg_mla(&self, b: f32, c: f32) -> f32 {
        neg_mlaf(*self, b, c)
    }
}
48
49#[inline(always)]
50pub(crate) fn lerp<
51    T: Mul<Output = T>
52        + Sub<Output = T>
53        + Add<Output = T>
54        + From<f32>
55        + Copy
56        + FusedMultiplyAdd<T>
57        + FusedMultiplyNegAdd<T>,
58>(
59    a: T,
60    b: T,
61    t: T,
62) -> T {
63    a.neg_mla(a, t).mla(b, t)
64}
65
/// 4D CLUT helper.
///
/// Represents hypercube.
pub struct Hypercube<'a> {
    // Flattened lookup table; the `Fetcher4` impls index it with `w` as
    // the fastest-varying axis.
    array: &'a [f32],
    // Texel distance between consecutive `x` lattice indices.
    x_stride: u32,
    // Texel distance between consecutive `y` lattice indices.
    y_stride: u32,
    // Texel distance between consecutive `z` lattice indices
    // (`w` has implicit stride 1).
    z_stride: u32,
    // Number of grid points per axis, ordered `[x, y, z, w]`.
    grid_size: [u8; 4],
}
76
/// Reads a single texel of type `T` from a 4-D lattice at integer
/// grid coordinates `(x, y, z, w)`.
trait Fetcher4<T> {
    fn fetch(&self, x: i32, y: i32, z: i32, w: i32) -> T;
}
80
81impl Hypercube<'_> {
82    pub fn new(array: &[f32], grid_size: usize) -> Hypercube<'_> {
83        let z_stride = grid_size as u32;
84        let y_stride = z_stride * z_stride;
85        let x_stride = z_stride * z_stride * z_stride;
86        Hypercube {
87            array,
88            x_stride,
89            y_stride,
90            z_stride,
91            grid_size: [
92                grid_size as u8,
93                grid_size as u8,
94                grid_size as u8,
95                grid_size as u8,
96            ],
97        }
98    }
99
100    pub(crate) fn new_checked(
101        array: &[f32],
102        grid_size: usize,
103        channels: usize,
104    ) -> Result<Hypercube<'_>, CmsError> {
105        if array.is_empty() || grid_size == 0 {
106            return Ok(Hypercube {
107                array,
108                x_stride: 0,
109                y_stride: 0,
110                z_stride: 0,
111                grid_size: [0, 0, 0, 0],
112            });
113        }
114        let z_stride = grid_size as u32;
115        let y_stride = z_stride * z_stride;
116        let x_stride = z_stride * z_stride * z_stride;
117
118        let last_index = (grid_size - 1)
119            .safe_mul(x_stride as usize)?
120            .safe_add((grid_size - 1).safe_mul(y_stride as usize)?)?
121            .safe_add((grid_size - 1).safe_mul(z_stride as usize)?)?
122            .safe_add(grid_size - 1)?
123            .safe_mul(channels)?;
124
125        if last_index >= array.len() {
126            return Err(CmsError::MalformedClut(MalformedSize {
127                size: array.len(),
128                expected: last_index,
129            }));
130        }
131
132        Ok(Hypercube {
133            array,
134            x_stride,
135            y_stride,
136            z_stride,
137            grid_size: [
138                grid_size as u8,
139                grid_size as u8,
140                grid_size as u8,
141                grid_size as u8,
142            ],
143        })
144    }
145
146    pub(crate) fn new_checked_hypercube(
147        array: &[f32],
148        grid_size: [u8; 4],
149        channels: usize,
150    ) -> Result<Hypercube<'_>, CmsError> {
151        if array.is_empty()
152            || grid_size[0] == 0
153            || grid_size[1] == 0
154            || grid_size[2] == 0
155            || grid_size[3] == 0
156        {
157            return Ok(Hypercube {
158                array,
159                x_stride: 0,
160                y_stride: 0,
161                z_stride: 0,
162                grid_size,
163            });
164        }
165        let z_stride = grid_size[2] as u32;
166        let y_stride = z_stride * grid_size[1] as u32;
167        let x_stride = y_stride * grid_size[0] as u32;
168        let last_index = (grid_size[0] as usize - 1)
169            .safe_mul(x_stride as usize)?
170            .safe_add((grid_size[1] as usize - 1).safe_mul(y_stride as usize)?)?
171            .safe_add((grid_size[2] as usize - 1).safe_mul(z_stride as usize)?)?
172            .safe_add(grid_size[3] as usize - 1)?
173            .safe_mul(channels)?;
174
175        if last_index >= array.len() {
176            return Err(CmsError::MalformedClut(MalformedSize {
177                size: array.len(),
178                expected: last_index,
179            }));
180        }
181
182        Ok(Hypercube {
183            array,
184            x_stride,
185            y_stride,
186            z_stride,
187            grid_size,
188        })
189    }
190
191    pub fn new_hypercube(array: &[f32], grid_size: [u8; 4]) -> Hypercube<'_> {
192        let z_stride = grid_size[2] as u32;
193        let y_stride = z_stride * grid_size[1] as u32;
194        let x_stride = y_stride * grid_size[0] as u32;
195        Hypercube {
196            array,
197            x_stride,
198            y_stride,
199            z_stride,
200            grid_size,
201        }
202    }
203}
204
/// `Fetcher4` backend reading 3-component texels out of a 4-D CLUT.
/// Strides are copied from the owning [`Hypercube`].
struct Fetch4Vec3<'a> {
    array: &'a [f32],
    x_stride: u32,
    y_stride: u32,
    z_stride: u32,
}
211
/// `Fetcher4` backend reading 4-component texels out of a 4-D CLUT.
/// Strides are copied from the owning [`Hypercube`].
struct Fetch4Vec4<'a> {
    array: &'a [f32],
    x_stride: u32,
    y_stride: u32,
    z_stride: u32,
}
218
219impl Fetcher4<Vector3f> for Fetch4Vec3<'_> {
220    #[inline(always)]
221    fn fetch(&self, x: i32, y: i32, z: i32, w: i32) -> Vector3f {
222        let start = (x as u32 * self.x_stride
223            + y as u32 * self.y_stride
224            + z as u32 * self.z_stride
225            + w as u32) as usize
226            * 3;
227        let k = &self.array[start..start + 3];
228        Vector3f {
229            v: [k[0], k[1], k[2]],
230        }
231    }
232}
233
234impl Fetcher4<Vector4f> for Fetch4Vec4<'_> {
235    #[inline(always)]
236    fn fetch(&self, x: i32, y: i32, z: i32, w: i32) -> Vector4f {
237        let start = (x as u32 * self.x_stride
238            + y as u32 * self.y_stride
239            + z as u32 * self.z_stride
240            + w as u32) as usize
241            * 4;
242        let k = &self.array[start..start + 4];
243        Vector4f {
244            v: [k[0], k[1], k[2], k[3]],
245        }
246    }
247}
248
impl Hypercube<'_> {
    /// Quadrilinear (4-D multilinear) interpolation at the normalized
    /// coordinate `(lin_x, lin_y, lin_z, lin_w)`; each input is clamped
    /// to `[0, 1]` before sampling.
    #[inline(always)]
    fn quadlinear<
        T: From<f32>
            + Add<T, Output = T>
            + Mul<T, Output = T>
            + FusedMultiplyAdd<T>
            + Sub<T, Output = T>
            + Copy
            + FusedMultiplyNegAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        lin_w: f32,
        r: impl Fetcher4<T>,
    ) -> T {
        // Clamp inputs into the unit hypercube so all lattice indices
        // below stay in range.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);
        let lin_w = lin_w.max(0.0).min(1.0);

        // Scale from [0, 1] to lattice coordinates [0, G - 1].
        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;
        let scale_w = (self.grid_size[3] as i32 - 1) as f32;

        // Lower lattice corner of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;
        let w = (lin_w * scale_w).floor() as i32;

        // Upper lattice corner (ceil stays in range because the inputs
        // were clamped above).
        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;
        let w_n = (lin_w * scale_w).ceil() as i32;

        // Fractional position inside the cell along each axis.
        let x_d = T::from(lin_x * scale_x - x as f32);
        let y_d = T::from(lin_y * scale_y - y as f32);
        let z_d = T::from(lin_z * scale_z - z as f32);
        let w_d = T::from(lin_w * scale_w - w as f32);

        // Trilinear interpolation on the lower `w` slice.
        let r_x1 = lerp(r.fetch(x, y, z, w), r.fetch(x_n, y, z, w), x_d);
        let r_x2 = lerp(r.fetch(x, y_n, z, w), r.fetch(x_n, y_n, z, w), x_d);
        let r_y1 = lerp(r_x1, r_x2, y_d);
        let r_x3 = lerp(r.fetch(x, y, z_n, w), r.fetch(x_n, y, z_n, w), x_d);
        let r_x4 = lerp(r.fetch(x, y_n, z_n, w), r.fetch(x_n, y_n, z_n, w), x_d);
        let r_y2 = lerp(r_x3, r_x4, y_d);
        let r_z1 = lerp(r_y1, r_y2, z_d);

        // Trilinear interpolation on the upper `w` slice.
        let r_x1 = lerp(r.fetch(x, y, z, w_n), r.fetch(x_n, y, z, w_n), x_d);
        let r_x2 = lerp(r.fetch(x, y_n, z, w_n), r.fetch(x_n, y_n, z, w_n), x_d);
        let r_y1 = lerp(r_x1, r_x2, y_d);
        let r_x3 = lerp(r.fetch(x, y, z_n, w_n), r.fetch(x_n, y, z_n, w_n), x_d);
        let r_x4 = lerp(r.fetch(x, y_n, z_n, w_n), r.fetch(x_n, y_n, z_n, w_n), x_d);
        let r_y2 = lerp(r_x3, r_x4, y_d);
        let r_z2 = lerp(r_y1, r_y2, z_d);
        // Blend the two slice results along `w`.
        lerp(r_z1, r_z2, w_d)
    }

    /// Quadrilinear interpolation returning a 3-component texel.
    /// Inputs are normalized coordinates in `[0, 1]` (clamped).
    #[inline]
    pub fn quadlinear_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32, lin_w: f32) -> Vector3f {
        self.quadlinear(
            lin_x,
            lin_y,
            lin_z,
            lin_w,
            Fetch4Vec3 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
                z_stride: self.z_stride,
            },
        )
    }

    /// Quadrilinear interpolation returning a 4-component texel.
    /// Inputs are normalized coordinates in `[0, 1]` (clamped).
    #[inline]
    pub fn quadlinear_vec4(&self, lin_x: f32, lin_y: f32, lin_z: f32, lin_w: f32) -> Vector4f {
        self.quadlinear(
            lin_x,
            lin_y,
            lin_z,
            lin_w,
            Fetch4Vec4 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
                z_stride: self.z_stride,
            },
        )
    }

    /// Pyramidal interpolation: each 3-D cell of the two bounding `w`
    /// slices is split into pyramids chosen by comparing the fractional
    /// offsets (`dr`/`dg`/`db`), and the two slice results are blended
    /// linearly along `w`.
    #[cfg(feature = "options")]
    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
    #[inline(always)]
    fn pyramid<
        T: From<f32>
            + Add<T, Output = T>
            + Mul<T, Output = T>
            + FusedMultiplyAdd<T>
            + Sub<T, Output = T>
            + Copy
            + FusedMultiplyNegAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        lin_w: f32,
        r: impl Fetcher4<T>,
    ) -> T {
        // Clamp inputs into the unit hypercube.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);
        let lin_w = lin_w.max(0.0).min(1.0);

        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;
        let scale_w = (self.grid_size[3] as i32 - 1) as f32;

        // Lower and upper lattice corners of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;
        let w = (lin_w * scale_w).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;
        let w_n = (lin_w * scale_w).ceil() as i32;

        // Fractional offsets inside the cell (r/g/b naming follows the
        // usual CLUT channel convention; `dw` is the 4th axis).
        let dr = lin_x * scale_x - x as f32;
        let dg = lin_y * scale_y - y as f32;
        let db = lin_z * scale_z - z as f32;
        let dw = lin_w * scale_w - w as f32;

        // 3-D pyramid interpolation on the lower `w` slice.
        let c0 = r.fetch(x, y, z, w);

        let w0 = if dr > db && dg > db {
            let x0 = r.fetch(x_n, y_n, z_n, w);
            let x1 = r.fetch(x_n, y_n, z, w);
            let x2 = r.fetch(x_n, y, z, w);
            let x3 = r.fetch(x, y_n, z, w);

            let c1 = x0 - x1;
            let c2 = x2 - c0;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x2 + x1;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dr * dg))
        } else if db > dr && dg > dr {
            let x0 = r.fetch(x, y, z_n, w);
            let x1 = r.fetch(x_n, y_n, z_n, w);
            let x2 = r.fetch(x, y_n, z_n, w);
            let x3 = r.fetch(x, y_n, z, w);

            let c1 = x0 - c0;
            let c2 = x1 - x2;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x0 + x2;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dg * db))
        } else {
            let x0 = r.fetch(x, y, z_n, w);
            let x1 = r.fetch(x_n, y, z, w);
            let x2 = r.fetch(x_n, y, z_n, w);
            let x3 = r.fetch(x_n, y_n, z_n, w);

            let c1 = x0 - c0;
            let c2 = x1 - c0;
            let c3 = x3 - x2;
            let c4 = c0 - x1 - x0 + x2;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(db * dr))
        };

        // Same pyramid split, evaluated on the upper `w` slice.
        let c0 = r.fetch(x, y, z, w_n);

        let w1 = if dr > db && dg > db {
            let x0 = r.fetch(x_n, y_n, z_n, w_n);
            let x1 = r.fetch(x_n, y_n, z, w_n);
            let x2 = r.fetch(x_n, y, z, w_n);
            let x3 = r.fetch(x, y_n, z, w_n);

            let c1 = x0 - x1;
            let c2 = x2 - c0;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x2 + x1;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dr * dg))
        } else if db > dr && dg > dr {
            let x0 = r.fetch(x, y, z_n, w_n);
            let x1 = r.fetch(x_n, y_n, z_n, w_n);
            let x2 = r.fetch(x, y_n, z_n, w_n);
            let x3 = r.fetch(x, y_n, z, w_n);

            let c1 = x0 - c0;
            let c2 = x1 - x2;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x0 + x2;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dg * db))
        } else {
            let x0 = r.fetch(x, y, z_n, w_n);
            let x1 = r.fetch(x_n, y, z, w_n);
            let x2 = r.fetch(x_n, y, z_n, w_n);
            let x3 = r.fetch(x_n, y_n, z_n, w_n);

            let c1 = x0 - c0;
            let c2 = x1 - c0;
            let c3 = x3 - x2;
            let c4 = c0 - x1 - x0 + x2;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(db * dr))
        };
        // Inline lerp(w0, w1, dw): w0 * (1 - dw) + w1 * dw.
        w0.neg_mla(w0, T::from(dw)).mla(w1, T::from(dw))
    }

    /// Pyramidal interpolation returning a 3-component texel.
    /// Inputs are normalized coordinates in `[0, 1]` (clamped).
    #[cfg(feature = "options")]
    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
    #[inline]
    pub fn pyramid_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32, lin_w: f32) -> Vector3f {
        self.pyramid(
            lin_x,
            lin_y,
            lin_z,
            lin_w,
            Fetch4Vec3 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
                z_stride: self.z_stride,
            },
        )
    }

    /// Pyramidal interpolation returning a 4-component texel.
    /// Inputs are normalized coordinates in `[0, 1]` (clamped).
    #[cfg(feature = "options")]
    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
    #[inline]
    pub fn pyramid_vec4(&self, lin_x: f32, lin_y: f32, lin_z: f32, lin_w: f32) -> Vector4f {
        self.pyramid(
            lin_x,
            lin_y,
            lin_z,
            lin_w,
            Fetch4Vec4 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
                z_stride: self.z_stride,
            },
        )
    }

    /// Prismatic interpolation: each 3-D cell of the two bounding `w`
    /// slices is split into two prisms by the `db >= dr` diagonal, and the
    /// two slice results are blended linearly along `w`.
    #[cfg(feature = "options")]
    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
    #[inline(always)]
    fn prism<
        T: From<f32>
            + Add<T, Output = T>
            + Mul<T, Output = T>
            + FusedMultiplyAdd<T>
            + Sub<T, Output = T>
            + Copy
            + FusedMultiplyNegAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        lin_w: f32,
        r: impl Fetcher4<T>,
    ) -> T {
        // Clamp inputs into the unit hypercube.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);
        let lin_w = lin_w.max(0.0).min(1.0);

        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;
        let scale_w = (self.grid_size[3] as i32 - 1) as f32;

        // Lower and upper lattice corners of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;
        let w = (lin_w * scale_w).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;
        let w_n = (lin_w * scale_w).ceil() as i32;

        // Fractional offsets inside the cell.
        let dr = lin_x * scale_x - x as f32;
        let dg = lin_y * scale_y - y as f32;
        let db = lin_z * scale_z - z as f32;
        let dw = lin_w * scale_w - w as f32;

        // 3-D prism interpolation on the lower `w` slice.
        let c0 = r.fetch(x, y, z, w);

        let w0 = if db >= dr {
            let x0 = r.fetch(x, y, z_n, w);
            let x1 = r.fetch(x_n, y, z_n, w);
            let x2 = r.fetch(x, y_n, z, w);
            let x3 = r.fetch(x, y_n, z_n, w);
            let x4 = r.fetch(x_n, y_n, z_n, w);

            let c1 = x0 - c0;
            let c2 = x1 - x0;
            let c3 = x2 - c0;
            let c4 = c0 - x2 - x0 + x3;
            let c5 = x0 - x3 - x1 + x4;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        } else {
            let x0 = r.fetch(x_n, y, z, w);
            let x1 = r.fetch(x_n, y, z_n, w);
            let x2 = r.fetch(x, y_n, z, w);
            let x3 = r.fetch(x_n, y_n, z, w);
            let x4 = r.fetch(x_n, y_n, z_n, w);

            let c1 = x1 - x0;
            let c2 = x0 - c0;
            let c3 = x2 - c0;
            let c4 = x0 - x3 - x1 + x4;
            let c5 = c0 - x2 - x0 + x3;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        };

        // Same prism split, evaluated on the upper `w` slice.
        let c0 = r.fetch(x, y, z, w_n);

        let w1 = if db >= dr {
            let x0 = r.fetch(x, y, z_n, w_n);
            let x1 = r.fetch(x_n, y, z_n, w_n);
            let x2 = r.fetch(x, y_n, z, w_n);
            let x3 = r.fetch(x, y_n, z_n, w_n);
            let x4 = r.fetch(x_n, y_n, z_n, w_n);

            let c1 = x0 - c0;
            let c2 = x1 - x0;
            let c3 = x2 - c0;
            let c4 = c0 - x2 - x0 + x3;
            let c5 = x0 - x3 - x1 + x4;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        } else {
            let x0 = r.fetch(x_n, y, z, w_n);
            let x1 = r.fetch(x_n, y, z_n, w_n);
            let x2 = r.fetch(x, y_n, z, w_n);
            let x3 = r.fetch(x_n, y_n, z, w_n);
            let x4 = r.fetch(x_n, y_n, z_n, w_n);

            let c1 = x1 - x0;
            let c2 = x0 - c0;
            let c3 = x2 - c0;
            let c4 = x0 - x3 - x1 + x4;
            let c5 = c0 - x2 - x0 + x3;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        };
        // Inline lerp(w0, w1, dw): w0 * (1 - dw) + w1 * dw.
        w0.neg_mla(w0, T::from(dw)).mla(w1, T::from(dw))
    }

    /// Prismatic interpolation returning a 3-component texel.
    /// Inputs are normalized coordinates in `[0, 1]` (clamped).
    #[cfg(feature = "options")]
    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
    #[inline]
    pub fn prism_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32, lin_w: f32) -> Vector3f {
        self.prism(
            lin_x,
            lin_y,
            lin_z,
            lin_w,
            Fetch4Vec3 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
                z_stride: self.z_stride,
            },
        )
    }

    /// Prismatic interpolation returning a 4-component texel.
    /// Inputs are normalized coordinates in `[0, 1]` (clamped).
    #[cfg(feature = "options")]
    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
    #[inline]
    pub fn prism_vec4(&self, lin_x: f32, lin_y: f32, lin_z: f32, lin_w: f32) -> Vector4f {
        self.prism(
            lin_x,
            lin_y,
            lin_z,
            lin_w,
            Fetch4Vec4 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
                z_stride: self.z_stride,
            },
        )
    }

    /// Tetrahedral interpolation: each 3-D cell of the two bounding `w`
    /// slices is split into six tetrahedra by ordering the fractional
    /// offsets (`rx`/`ry`/`rz`), and the two slice results are blended
    /// linearly along `w`.
    #[cfg(feature = "options")]
    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
    #[inline(always)]
    fn tetra<
        T: From<f32>
            + Add<T, Output = T>
            + Mul<T, Output = T>
            + FusedMultiplyAdd<T>
            + Sub<T, Output = T>
            + Copy
            + FusedMultiplyNegAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        lin_w: f32,
        r: impl Fetcher4<T>,
    ) -> T {
        // Clamp inputs into the unit hypercube.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);
        let lin_w = lin_w.max(0.0).min(1.0);

        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;
        let scale_w = (self.grid_size[3] as i32 - 1) as f32;

        // Lower and upper lattice corners of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;
        let w = (lin_w * scale_w).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;
        let w_n = (lin_w * scale_w).ceil() as i32;

        // Fractional offsets inside the cell.
        let rx = lin_x * scale_x - x as f32;
        let ry = lin_y * scale_y - y as f32;
        let rz = lin_z * scale_z - z as f32;
        let rw = lin_w * scale_w - w as f32;

        // Tetrahedral interpolation on the lower `w` slice: pick the
        // tetrahedron from the ordering of rx/ry/rz, then accumulate the
        // three edge deltas weighted by the offsets.
        let c0 = r.fetch(x, y, z, w);
        let c2;
        let c1;
        let c3;
        if rx >= ry {
            if ry >= rz {
                //rx >= ry && ry >= rz
                c1 = r.fetch(x_n, y, z, w) - c0;
                c2 = r.fetch(x_n, y_n, z, w) - r.fetch(x_n, y, z, w);
                c3 = r.fetch(x_n, y_n, z_n, w) - r.fetch(x_n, y_n, z, w);
            } else if rx >= rz {
                //rx >= rz && rz >= ry
                c1 = r.fetch(x_n, y, z, w) - c0;
                c2 = r.fetch(x_n, y_n, z_n, w) - r.fetch(x_n, y, z_n, w);
                c3 = r.fetch(x_n, y, z_n, w) - r.fetch(x_n, y, z, w);
            } else {
                //rz > rx && rx >= ry
                c1 = r.fetch(x_n, y, z_n, w) - r.fetch(x, y, z_n, w);
                c2 = r.fetch(x_n, y_n, z_n, w) - r.fetch(x_n, y, z_n, w);
                c3 = r.fetch(x, y, z_n, w) - c0;
            }
        } else if rx >= rz {
            //ry > rx && rx >= rz
            c1 = r.fetch(x_n, y_n, z, w) - r.fetch(x, y_n, z, w);
            c2 = r.fetch(x, y_n, z, w) - c0;
            c3 = r.fetch(x_n, y_n, z_n, w) - r.fetch(x_n, y_n, z, w);
        } else if ry >= rz {
            //ry >= rz && rz > rx
            c1 = r.fetch(x_n, y_n, z_n, w) - r.fetch(x, y_n, z_n, w);
            c2 = r.fetch(x, y_n, z, w) - c0;
            c3 = r.fetch(x, y_n, z_n, w) - r.fetch(x, y_n, z, w);
        } else {
            //rz > ry && ry > rx
            c1 = r.fetch(x_n, y_n, z_n, w) - r.fetch(x, y_n, z_n, w);
            c2 = r.fetch(x, y_n, z_n, w) - r.fetch(x, y, z_n, w);
            c3 = r.fetch(x, y, z_n, w) - c0;
        }
        let s0 = c0.mla(c1, T::from(rx));
        let s1 = s0.mla(c2, T::from(ry));
        let w0 = s1.mla(c3, T::from(rz));

        // Same tetrahedral split, evaluated on the upper `w` slice.
        let c0 = r.fetch(x, y, z, w_n);
        let c2;
        let c1;
        let c3;
        if rx >= ry {
            if ry >= rz {
                //rx >= ry && ry >= rz
                c1 = r.fetch(x_n, y, z, w_n) - c0;
                c2 = r.fetch(x_n, y_n, z, w_n) - r.fetch(x_n, y, z, w_n);
                c3 = r.fetch(x_n, y_n, z_n, w_n) - r.fetch(x_n, y_n, z, w_n);
            } else if rx >= rz {
                //rx >= rz && rz >= ry
                c1 = r.fetch(x_n, y, z, w_n) - c0;
                c2 = r.fetch(x_n, y_n, z_n, w_n) - r.fetch(x_n, y, z_n, w_n);
                c3 = r.fetch(x_n, y, z_n, w_n) - r.fetch(x_n, y, z, w_n);
            } else {
                //rz > rx && rx >= ry
                c1 = r.fetch(x_n, y, z_n, w_n) - r.fetch(x, y, z_n, w_n);
                c2 = r.fetch(x_n, y_n, z_n, w_n) - r.fetch(x_n, y, z_n, w_n);
                c3 = r.fetch(x, y, z_n, w_n) - c0;
            }
        } else if rx >= rz {
            //ry > rx && rx >= rz
            c1 = r.fetch(x_n, y_n, z, w_n) - r.fetch(x, y_n, z, w_n);
            c2 = r.fetch(x, y_n, z, w_n) - c0;
            c3 = r.fetch(x_n, y_n, z_n, w_n) - r.fetch(x_n, y_n, z, w_n);
        } else if ry >= rz {
            //ry >= rz && rz > rx
            c1 = r.fetch(x_n, y_n, z_n, w_n) - r.fetch(x, y_n, z_n, w_n);
            c2 = r.fetch(x, y_n, z, w_n) - c0;
            c3 = r.fetch(x, y_n, z_n, w_n) - r.fetch(x, y_n, z, w_n);
        } else {
            //rz > ry && ry > rx
            c1 = r.fetch(x_n, y_n, z_n, w_n) - r.fetch(x, y_n, z_n, w_n);
            c2 = r.fetch(x, y_n, z_n, w_n) - r.fetch(x, y, z_n, w_n);
            c3 = r.fetch(x, y, z_n, w_n) - c0;
        }
        let s0 = c0.mla(c1, T::from(rx));
        let s1 = s0.mla(c2, T::from(ry));
        let w1 = s1.mla(c3, T::from(rz));
        // Inline lerp(w0, w1, rw): w0 * (1 - rw) + w1 * rw.
        w0.neg_mla(w0, T::from(rw)).mla(w1, T::from(rw))
    }

    /// Tetrahedral interpolation returning a 3-component texel.
    /// Inputs are normalized coordinates in `[0, 1]` (clamped).
    #[cfg(feature = "options")]
    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
    #[inline]
    pub fn tetra_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32, lin_w: f32) -> Vector3f {
        self.tetra(
            lin_x,
            lin_y,
            lin_z,
            lin_w,
            Fetch4Vec3 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
                z_stride: self.z_stride,
            },
        )
    }

    /// Tetrahedral interpolation returning a 4-component texel.
    /// Inputs are normalized coordinates in `[0, 1]` (clamped).
    #[cfg(feature = "options")]
    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
    #[inline]
    pub fn tetra_vec4(&self, lin_x: f32, lin_y: f32, lin_z: f32, lin_w: f32) -> Vector4f {
        self.tetra(
            lin_x,
            lin_y,
            lin_z,
            lin_w,
            Fetch4Vec4 {
                array: self.array,
                x_stride: self.x_stride,
                y_stride: self.y_stride,
                z_stride: self.z_stride,
            },
        )
    }
}
848
/// 3D CLUT helper
///
/// Represents hexahedron.
pub struct Cube<'a> {
    // Flattened lookup table; the `ArrayFetch` impls index it with `z` as
    // the fastest-varying axis.
    array: &'a [f32],
    // Texel distance between consecutive `x` lattice indices.
    x_stride: u32,
    // Texel distance between consecutive `y` lattice indices
    // (`z` has implicit stride 1).
    y_stride: u32,
    // Number of grid points per axis, ordered `[x, y, z]`.
    grid_size: [u8; 3],
}
858
/// Reads a single texel of type `T` from a 3-D lattice at integer
/// grid coordinates `(x, y, z)`.
pub(crate) trait ArrayFetch<T> {
    fn fetch(&self, x: i32, y: i32, z: i32) -> T;
}
862
/// `ArrayFetch` backend reading 3-component texels out of a 3-D CLUT.
/// Strides are copied from the owning [`Cube`].
struct ArrayFetchVector3f<'a> {
    array: &'a [f32],
    x_stride: u32,
    y_stride: u32,
}
868
869impl ArrayFetch<Vector3f> for ArrayFetchVector3f<'_> {
870    #[inline(always)]
871    fn fetch(&self, x: i32, y: i32, z: i32) -> Vector3f {
872        let start = (x as u32 * self.x_stride + y as u32 * self.y_stride + z as u32) as usize * 3;
873        let k = &self.array[start..start + 3];
874        Vector3f {
875            v: [k[0], k[1], k[2]],
876        }
877    }
878}
879
/// `ArrayFetch` backend reading 4-component texels out of a 3-D CLUT.
/// Strides are copied from the owning [`Cube`].
struct ArrayFetchVector4f<'a> {
    array: &'a [f32],
    x_stride: u32,
    y_stride: u32,
}
885
886impl ArrayFetch<Vector4f> for ArrayFetchVector4f<'_> {
887    #[inline(always)]
888    fn fetch(&self, x: i32, y: i32, z: i32) -> Vector4f {
889        let start = (x as u32 * self.x_stride + y as u32 * self.y_stride + z as u32) as usize * 4;
890        let k = &self.array[start..start + 4];
891        Vector4f {
892            v: [k[0], k[1], k[2], k[3]],
893        }
894    }
895}
896
897impl Cube<'_> {
898    pub fn new(array: &[f32], grid_size: usize) -> Cube<'_> {
899        let y_stride = grid_size;
900        let x_stride = y_stride * y_stride;
901        Cube {
902            array,
903            x_stride: x_stride as u32,
904            y_stride: y_stride as u32,
905            grid_size: [grid_size as u8, grid_size as u8, grid_size as u8],
906        }
907    }
908
909    pub(crate) fn new_checked(
910        array: &[f32],
911        grid_size: usize,
912        channels: usize,
913    ) -> Result<Cube<'_>, CmsError> {
914        if array.is_empty() || grid_size == 0 {
915            return Ok(Cube {
916                array,
917                x_stride: 0,
918                y_stride: 0,
919                grid_size: [0, 0, 0],
920            });
921        }
922        let y_stride = grid_size;
923        let x_stride = y_stride * y_stride;
924
925        let last_index = (grid_size - 1)
926            .safe_mul(x_stride)?
927            .safe_add((grid_size - 1).safe_mul(y_stride)?)?
928            .safe_add(grid_size - 1)?
929            .safe_mul(channels)?;
930
931        if last_index >= array.len() {
932            return Err(CmsError::MalformedClut(MalformedSize {
933                size: array.len(),
934                expected: last_index,
935            }));
936        }
937
938        Ok(Cube {
939            array,
940            x_stride: x_stride as u32,
941            y_stride: y_stride as u32,
942            grid_size: [grid_size as u8, grid_size as u8, grid_size as u8],
943        })
944    }
945
946    pub fn new_cube(array: &[f32], grid_size: [u8; 3]) -> Cube<'_> {
947        let y_stride = grid_size[2] as u32;
948        let x_stride = y_stride * grid_size[1] as u32;
949        Cube {
950            array,
951            x_stride,
952            y_stride,
953            grid_size,
954        }
955    }
956
957    pub(crate) fn new_checked_cube(
958        array: &[f32],
959        grid_size: [u8; 3],
960        channels: usize,
961    ) -> Result<Cube<'_>, CmsError> {
962        if array.is_empty() || grid_size[0] == 0 || grid_size[1] == 0 || grid_size[2] == 0 {
963            return Ok(Cube {
964                array,
965                x_stride: 0,
966                y_stride: 0,
967                grid_size,
968            });
969        }
970        let y_stride = grid_size[2] as u32;
971        let x_stride = y_stride * grid_size[1] as u32;
972        let last_index = (grid_size[0] as usize - 1)
973            .safe_mul(x_stride as usize)?
974            .safe_add((grid_size[1] as usize - 1).safe_mul(y_stride as usize)?)?
975            .safe_add(grid_size[2] as usize - 1)?
976            .safe_mul(channels)?;
977
978        if last_index >= array.len() {
979            return Err(CmsError::MalformedClut(MalformedSize {
980                size: array.len(),
981                expected: last_index,
982            }));
983        }
984
985        Ok(Cube {
986            array,
987            x_stride,
988            y_stride,
989            grid_size,
990        })
991    }
992
    /// Trilinear interpolation of the CLUT at normalized coordinates.
    ///
    /// Inputs are clamped to [0, 1], scaled onto the grid, and the eight
    /// corners of the enclosing cell are blended along X, then Y, then Z
    /// using fused multiply-add lerps.
    #[inline(always)]
    fn trilinear<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyNegAdd<T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl ArrayFetch<T>,
    ) -> T {
        // Clamp so the scaled coordinates cannot leave the grid.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        // Largest valid index along each axis.
        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        // Lower corner of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        // Upper corner; coincides with the lower one exactly on a grid node,
        // and `ceil` of a clamped coordinate never exceeds the last index.
        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        // Fractional position inside the cell along each axis.
        let x_d = T::from(lin_x * scale_x - x as f32);
        let y_d = T::from(lin_y * scale_y - y as f32);
        let z_d = T::from(lin_z * scale_z - z as f32);

        let c000 = fetch.fetch(x, y, z);
        let c100 = fetch.fetch(x_n, y, z);
        let c010 = fetch.fetch(x, y_n, z);
        let c110 = fetch.fetch(x_n, y_n, z);
        let c001 = fetch.fetch(x, y, z_n);
        let c101 = fetch.fetch(x_n, y, z_n);
        let c011 = fetch.fetch(x, y_n, z_n);
        let c111 = fetch.fetch(x_n, y_n, z_n);

        // `a.neg_mla(a, t).mla(b, t)` evaluates a*(1 - t) + b*t as fused ops
        // (assuming neg_mla(a, b, c) = a - b*c, matching `neg_mlaf`).
        // First collapse the X axis...
        let c00 = c000.neg_mla(c000, x_d).mla(c100, x_d);
        let c10 = c010.neg_mla(c010, x_d).mla(c110, x_d);
        let c01 = c001.neg_mla(c001, x_d).mla(c101, x_d);
        let c11 = c011.neg_mla(c011, x_d).mla(c111, x_d);

        // ...then Y...
        let c0 = c00.neg_mla(c00, y_d).mla(c10, y_d);
        let c1 = c01.neg_mla(c01, y_d).mla(c11, y_d);

        // ...then Z.
        c0.neg_mla(c0, z_d).mla(c1, z_d)
    }
1048
    /// Pyramidal interpolation of the CLUT at normalized coordinates.
    ///
    /// The enclosing cell is split into three pyramids; the one containing
    /// the sample point is selected by comparing the fractional components
    /// (`dr`, `dg`, `db`), then its corner differences are blended with fused
    /// multiply-adds. Inputs are clamped to [0, 1].
    #[cfg(feature = "options")]
    #[inline]
    fn pyramid<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl ArrayFetch<T>,
    ) -> T {
        // Clamp so the scaled coordinates cannot leave the grid.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        // Largest valid index along each axis.
        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        // Lower and upper corners of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        // Fractional position inside the cell along each axis.
        let dr = lin_x * scale_x - x as f32;
        let dg = lin_y * scale_y - y as f32;
        let db = lin_z * scale_z - z as f32;

        let c0 = fetch.fetch(x, y, z);

        if dr > db && dg > db {
            // db is the smallest component: pyramid with apex along Z.
            let x0 = fetch.fetch(x_n, y_n, z_n);
            let x1 = fetch.fetch(x_n, y_n, z);
            let x2 = fetch.fetch(x_n, y, z);
            let x3 = fetch.fetch(x, y_n, z);

            let c1 = x0 - x1;
            let c2 = x2 - c0;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x2 + x1;

            // Accumulate: c0 + c1*db + c2*dr + c3*dg + c4*(dr*dg).
            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dr * dg))
        } else if db > dr && dg > dr {
            // dr is the smallest component: pyramid with apex along X.
            let x0 = fetch.fetch(x, y, z_n);
            let x1 = fetch.fetch(x_n, y_n, z_n);
            let x2 = fetch.fetch(x, y_n, z_n);
            let x3 = fetch.fetch(x, y_n, z);

            let c1 = x0 - c0;
            let c2 = x1 - x2;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x0 + x2;

            // Accumulate: c0 + c1*db + c2*dr + c3*dg + c4*(dg*db).
            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dg * db))
        } else {
            // Remaining region: pyramid with apex along Y.
            let x0 = fetch.fetch(x, y, z_n);
            let x1 = fetch.fetch(x_n, y, z);
            let x2 = fetch.fetch(x_n, y, z_n);
            let x3 = fetch.fetch(x_n, y_n, z_n);

            let c1 = x0 - c0;
            let c2 = x1 - c0;
            let c3 = x3 - x2;
            let c4 = c0 - x1 - x0 + x2;

            // Accumulate: c0 + c1*db + c2*dr + c3*dg + c4*(db*dr).
            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(db * dr))
        }
    }
1134
    /// Tetrahedral interpolation of the CLUT at normalized coordinates.
    ///
    /// The enclosing cell is split into six tetrahedra; the one containing
    /// the sample point is selected by the ordering of the fractional
    /// components (`rx`, `ry`, `rz`). The result is
    /// `c0 + c1*rx + c2*ry + c3*rz` evaluated with fused multiply-adds.
    /// Inputs are clamped to [0, 1].
    #[cfg(feature = "options")]
    #[inline]
    fn tetra<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl ArrayFetch<T>,
    ) -> T {
        // Clamp so the scaled coordinates cannot leave the grid.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        // Largest valid index along each axis.
        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        // Lower and upper corners of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        // Fractional position inside the cell along each axis.
        let rx = lin_x * scale_x - x as f32;
        let ry = lin_y * scale_y - y as f32;
        let rz = lin_z * scale_z - z as f32;

        let c0 = fetch.fetch(x, y, z);
        let c2;
        let c1;
        let c3;
        // Pick the tetrahedron from the ordering of (rx, ry, rz) and compute
        // the per-axis gradient vectors along its edge path.
        if rx >= ry {
            if ry >= rz {
                //rx >= ry && ry >= rz
                c1 = fetch.fetch(x_n, y, z) - c0;
                c2 = fetch.fetch(x_n, y_n, z) - fetch.fetch(x_n, y, z);
                c3 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y_n, z);
            } else if rx >= rz {
                //rx >= rz && rz >= ry
                c1 = fetch.fetch(x_n, y, z) - c0;
                c2 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y, z_n);
                c3 = fetch.fetch(x_n, y, z_n) - fetch.fetch(x_n, y, z);
            } else {
                //rz > rx && rx >= ry
                c1 = fetch.fetch(x_n, y, z_n) - fetch.fetch(x, y, z_n);
                c2 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y, z_n);
                c3 = fetch.fetch(x, y, z_n) - c0;
            }
        } else if rx >= rz {
            //ry > rx && rx >= rz
            c1 = fetch.fetch(x_n, y_n, z) - fetch.fetch(x, y_n, z);
            c2 = fetch.fetch(x, y_n, z) - c0;
            c3 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y_n, z);
        } else if ry >= rz {
            //ry >= rz && rz > rx
            c1 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x, y_n, z_n);
            c2 = fetch.fetch(x, y_n, z) - c0;
            c3 = fetch.fetch(x, y_n, z_n) - fetch.fetch(x, y_n, z);
        } else {
            //rz > ry && ry > rx
            c1 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x, y_n, z_n);
            c2 = fetch.fetch(x, y_n, z_n) - fetch.fetch(x, y, z_n);
            c3 = fetch.fetch(x, y, z_n) - c0;
        }
        // Accumulate c0 + c1*rx + c2*ry + c3*rz with fused multiply-adds.
        let s0 = c0.mla(c1, T::from(rx));
        let s1 = s0.mla(c2, T::from(ry));
        s1.mla(c3, T::from(rz))
    }
1212
    /// Prism interpolation of the CLUT at normalized coordinates.
    ///
    /// The enclosing cell is split into two triangular prisms by the diagonal
    /// plane `db == dr`; the prism containing the sample point is blended
    /// using fused multiply-adds. Inputs are clamped to [0, 1].
    #[cfg(feature = "options")]
    #[inline]
    fn prism<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl ArrayFetch<T>,
    ) -> T {
        // Clamp so the scaled coordinates cannot leave the grid.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        // Largest valid index along each axis.
        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        // Lower and upper corners of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        // Fractional position inside the cell along each axis.
        let dr = lin_x * scale_x - x as f32;
        let dg = lin_y * scale_y - y as f32;
        let db = lin_z * scale_z - z as f32;

        let c0 = fetch.fetch(x, y, z);

        if db >= dr {
            // Prism on the db >= dr side of the diagonal plane.
            let x0 = fetch.fetch(x, y, z_n);
            let x1 = fetch.fetch(x_n, y, z_n);
            let x2 = fetch.fetch(x, y_n, z);
            let x3 = fetch.fetch(x, y_n, z_n);
            let x4 = fetch.fetch(x_n, y_n, z_n);

            let c1 = x0 - c0;
            let c2 = x1 - x0;
            let c3 = x2 - c0;
            let c4 = c0 - x2 - x0 + x3;
            let c5 = x0 - x3 - x1 + x4;

            // Accumulate: c0 + c1*db + c2*dr + c3*dg + c4*(dg*db) + c5*(dr*dg).
            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        } else {
            // Prism on the dr > db side of the diagonal plane.
            let x0 = fetch.fetch(x_n, y, z);
            let x1 = fetch.fetch(x_n, y, z_n);
            let x2 = fetch.fetch(x, y_n, z);
            let x3 = fetch.fetch(x_n, y_n, z);
            let x4 = fetch.fetch(x_n, y_n, z_n);

            let c1 = x1 - x0;
            let c2 = x0 - c0;
            let c3 = x2 - c0;
            let c4 = x0 - x3 - x1 + x4;
            let c5 = c0 - x2 - x0 + x3;

            // Accumulate: c0 + c1*db + c2*dr + c3*dg + c4*(dg*db) + c5*(dr*dg).
            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        }
    }
1289
1290    pub fn trilinear_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> Vector3f {
1291        self.trilinear(
1292            lin_x,
1293            lin_y,
1294            lin_z,
1295            ArrayFetchVector3f {
1296                array: self.array,
1297                x_stride: self.x_stride,
1298                y_stride: self.y_stride,
1299            },
1300        )
1301    }
1302
1303    #[cfg(feature = "options")]
1304    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
1305    pub fn prism_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> Vector3f {
1306        self.prism(
1307            lin_x,
1308            lin_y,
1309            lin_z,
1310            ArrayFetchVector3f {
1311                array: self.array,
1312                x_stride: self.x_stride,
1313                y_stride: self.y_stride,
1314            },
1315        )
1316    }
1317
1318    #[cfg(feature = "options")]
1319    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
1320    pub fn pyramid_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> Vector3f {
1321        self.pyramid(
1322            lin_x,
1323            lin_y,
1324            lin_z,
1325            ArrayFetchVector3f {
1326                array: self.array,
1327                x_stride: self.x_stride,
1328                y_stride: self.y_stride,
1329            },
1330        )
1331    }
1332
1333    #[cfg(feature = "options")]
1334    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
1335    pub fn tetra_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> Vector3f {
1336        self.tetra(
1337            lin_x,
1338            lin_y,
1339            lin_z,
1340            ArrayFetchVector3f {
1341                array: self.array,
1342                x_stride: self.x_stride,
1343                y_stride: self.y_stride,
1344            },
1345        )
1346    }
1347
1348    pub fn trilinear_vec4(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> Vector4f {
1349        self.trilinear(
1350            lin_x,
1351            lin_y,
1352            lin_z,
1353            ArrayFetchVector4f {
1354                array: self.array,
1355                x_stride: self.x_stride,
1356                y_stride: self.y_stride,
1357            },
1358        )
1359    }
1360
1361    #[cfg(feature = "options")]
1362    pub fn tetra_vec4(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> Vector4f {
1363        self.tetra(
1364            lin_x,
1365            lin_y,
1366            lin_z,
1367            ArrayFetchVector4f {
1368                array: self.array,
1369                x_stride: self.x_stride,
1370                y_stride: self.y_stride,
1371            },
1372        )
1373    }
1374
1375    #[cfg(feature = "options")]
1376    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
1377    pub fn pyramid_vec4(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> Vector4f {
1378        self.pyramid(
1379            lin_x,
1380            lin_y,
1381            lin_z,
1382            ArrayFetchVector4f {
1383                array: self.array,
1384                x_stride: self.x_stride,
1385                y_stride: self.y_stride,
1386            },
1387        )
1388    }
1389
1390    #[cfg(feature = "options")]
1391    #[cfg_attr(docsrs, doc(cfg(feature = "options")))]
1392    pub fn prism_vec4(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> Vector4f {
1393        self.prism(
1394            lin_x,
1395            lin_y,
1396            lin_z,
1397            ArrayFetchVector4f {
1398                array: self.array,
1399                x_stride: self.x_stride,
1400                y_stride: self.y_stride,
1401            },
1402        )
1403    }
1404}