1use crate::conversions::interpolator::BarycentricWeight;
30use crate::math::FusedMultiplyAdd;
31use num_traits::AsPrimitive;
32#[cfg(target_arch = "x86")]
33use std::arch::x86::*;
34#[cfg(target_arch = "x86_64")]
35use std::arch::x86_64::*;
36use std::ops::{Add, Mul, Sub};
37
/// Four packed `f32` values stored with C layout and 16-byte alignment.
///
/// Slices of this type back the interpolation cubes in this module; the node
/// fetcher reads elements with `_mm_load_ps`, which is an aligned load.
#[repr(C, align(16))]
pub(crate) struct SseAlignedF32(pub(crate) [f32; 4]);
40
/// Tetrahedral interpolator over a borrowed lattice of [`SseAlignedF32`] nodes.
///
/// `GRID_SIZE` is handed to the node fetcher, which uses it as the stride when
/// flattening `(x, y, z)` lattice coordinates into an index of `cube`.
#[cfg(feature = "options")]
pub(crate) struct TetrahedralSse<'a, const GRID_SIZE: usize> {
    /// Flattened lattice nodes, one four-lane vector per node.
    pub(crate) cube: &'a [SseAlignedF32],
}
45
/// Pyramidal interpolator over a borrowed lattice of [`SseAlignedF32`] nodes.
///
/// `GRID_SIZE` is handed to the node fetcher, which uses it as the stride when
/// flattening `(x, y, z)` lattice coordinates into an index of `cube`.
#[cfg(feature = "options")]
pub(crate) struct PyramidalSse<'a, const GRID_SIZE: usize> {
    /// Flattened lattice nodes, one four-lane vector per node.
    pub(crate) cube: &'a [SseAlignedF32],
}
50
/// Prismatic interpolator over a borrowed lattice of [`SseAlignedF32`] nodes.
///
/// `GRID_SIZE` is handed to the node fetcher, which uses it as the stride when
/// flattening `(x, y, z)` lattice coordinates into an index of `cube`.
#[cfg(feature = "options")]
pub(crate) struct PrismaticSse<'a, const GRID_SIZE: usize> {
    /// Flattened lattice nodes, one four-lane vector per node.
    pub(crate) cube: &'a [SseAlignedF32],
}
55
56pub(crate) struct TrilinearSse<'a, const GRID_SIZE: usize> {
57 pub(crate) cube: &'a [SseAlignedF32],
58}
59
/// Source of lattice nodes addressed by integer grid coordinates.
///
/// The interpolators in this file are written against this trait so the way a
/// node is loaded stays separate from the blending arithmetic.
trait Fetcher<T> {
    /// Returns the node stored at lattice position `(x, y, z)`.
    fn fetch(&self, x: i32, y: i32, z: i32) -> T;
}
63
/// Thin wrapper around an SSE `__m128` register holding four `f32` lanes.
///
/// `#[repr(transparent)]` keeps its layout identical to `__m128`; the arithmetic
/// operator impls in this file act lane-wise on the wrapped register.
#[repr(transparent)]
#[derive(Clone, Copy)]
pub(crate) struct SseVector {
    /// The wrapped register.
    pub(crate) v: __m128,
}
69
70impl From<f32> for SseVector {
71 #[inline(always)]
72 fn from(v: f32) -> Self {
73 SseVector {
74 v: unsafe { _mm_set1_ps(v) },
75 }
76 }
77}
78
79impl Sub<SseVector> for SseVector {
80 type Output = Self;
81 #[inline(always)]
82 fn sub(self, rhs: SseVector) -> Self::Output {
83 SseVector {
84 v: unsafe { _mm_sub_ps(self.v, rhs.v) },
85 }
86 }
87}
88
89impl Add<SseVector> for SseVector {
90 type Output = Self;
91 #[inline(always)]
92 fn add(self, rhs: SseVector) -> Self::Output {
93 SseVector {
94 v: unsafe { _mm_add_ps(self.v, rhs.v) },
95 }
96 }
97}
98
99impl Mul<SseVector> for SseVector {
100 type Output = Self;
101 #[inline(always)]
102 fn mul(self, rhs: SseVector) -> Self::Output {
103 SseVector {
104 v: unsafe { _mm_mul_ps(self.v, rhs.v) },
105 }
106 }
107}
108
109impl FusedMultiplyAdd<SseVector> for SseVector {
110 #[inline(always)]
111 fn mla(&self, b: SseVector, c: SseVector) -> SseVector {
112 SseVector {
113 v: unsafe { _mm_add_ps(self.v, _mm_mul_ps(b.v, c.v)) },
114 }
115 }
116}
117
118struct TetrahedralSseFetchVector<'a, const GRID_SIZE: usize> {
119 cube: &'a [SseAlignedF32],
120}
121
122impl<const GRID_SIZE: usize> Fetcher<SseVector> for TetrahedralSseFetchVector<'_, GRID_SIZE> {
123 #[inline(always)]
124 fn fetch(&self, x: i32, y: i32, z: i32) -> SseVector {
125 let offset = (x as u32 * (GRID_SIZE as u32 * GRID_SIZE as u32)
126 + y as u32 * GRID_SIZE as u32
127 + z as u32) as usize;
128 let jx = unsafe { self.cube.get_unchecked(offset..) };
129 SseVector {
130 v: unsafe { _mm_load_ps(jx.as_ptr() as *const _) },
131 }
132 }
133}
134
135pub(crate) trait SseMdInterpolation<'a, const GRID_SIZE: usize> {
136 fn new(table: &'a [SseAlignedF32]) -> Self;
137 fn inter3_sse<U: AsPrimitive<usize>, const BINS: usize>(
138 &self,
139 in_r: U,
140 in_g: U,
141 in_b: U,
142 lut: &[BarycentricWeight<f32>; BINS],
143 ) -> SseVector;
144}
145
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> TetrahedralSse<'_, GRID_SIZE> {
    /// Tetrahedral interpolation of one sample.
    ///
    /// Each input is converted to `usize` and used to index `lut`, yielding the
    /// lower/upper lattice coordinate and fractional weight for that axis. The
    /// ordering of the three weights selects one of six tetrahedra; the result
    /// is `c0 + c1 * rx + c2 * ry + c3 * rz`, where `c0` is the base node and
    /// `c1..c3` are node differences along the selected tetrahedron's edges.
    ///
    /// Indexing `lut` is bounds-checked and panics when an index is `>= BINS`.
    #[inline(always)]
    fn interpolate<U: AsPrimitive<usize>, const BINS: usize>(
        &self,
        in_r: U,
        in_g: U,
        in_b: U,
        lut: &[BarycentricWeight<f32>; BINS],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        let (bin_r, bin_g, bin_b) = (lut[in_r.as_()], lut[in_g.as_()], lut[in_b.as_()]);

        let (x, y, z): (i32, i32, i32) = (bin_r.x, bin_g.x, bin_b.x);
        let (x_n, y_n, z_n): (i32, i32, i32) = (bin_r.x_n, bin_g.x_n, bin_b.x_n);
        let (rx, ry, rz) = (bin_r.w, bin_g.w, bin_b.w);

        let origin = r.fetch(x, y, z);

        let (c1, c2, c3) = if rx >= ry {
            if ry >= rz {
                // rx >= ry >= rz
                (
                    r.fetch(x_n, y, z) - origin,
                    r.fetch(x_n, y_n, z) - r.fetch(x_n, y, z),
                    r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y_n, z),
                )
            } else if rx >= rz {
                // rx >= rz > ry
                (
                    r.fetch(x_n, y, z) - origin,
                    r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y, z_n),
                    r.fetch(x_n, y, z_n) - r.fetch(x_n, y, z),
                )
            } else {
                // rz > rx >= ry
                (
                    r.fetch(x_n, y, z_n) - r.fetch(x, y, z_n),
                    r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y, z_n),
                    r.fetch(x, y, z_n) - origin,
                )
            }
        } else if rx >= rz {
            // ry > rx >= rz
            (
                r.fetch(x_n, y_n, z) - r.fetch(x, y_n, z),
                r.fetch(x, y_n, z) - origin,
                r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y_n, z),
            )
        } else if ry >= rz {
            // ry >= rz > rx
            (
                r.fetch(x_n, y_n, z_n) - r.fetch(x, y_n, z_n),
                r.fetch(x, y_n, z) - origin,
                r.fetch(x, y_n, z_n) - r.fetch(x, y_n, z),
            )
        } else {
            // rz > ry > rx
            (
                r.fetch(x_n, y_n, z_n) - r.fetch(x, y_n, z_n),
                r.fetch(x, y_n, z_n) - r.fetch(x, y, z_n),
                r.fetch(x, y, z_n) - origin,
            )
        };

        origin
            .mla(c1, SseVector::from(rx))
            .mla(c2, SseVector::from(ry))
            .mla(c3, SseVector::from(rz))
    }
}
216
217macro_rules! define_inter_sse {
218 ($interpolator: ident) => {
219 impl<'a, const GRID_SIZE: usize> SseMdInterpolation<'a, GRID_SIZE>
220 for $interpolator<'a, GRID_SIZE>
221 {
222 #[inline]
223 fn new(table: &'a [SseAlignedF32]) -> Self {
224 Self { cube: table }
225 }
226
227 #[inline(always)]
228 fn inter3_sse<U: AsPrimitive<usize>, const BINS: usize>(
229 &self,
230 in_r: U,
231 in_g: U,
232 in_b: U,
233 lut: &[BarycentricWeight<f32>; BINS],
234 ) -> SseVector {
235 self.interpolate(
236 in_r,
237 in_g,
238 in_b,
239 lut,
240 TetrahedralSseFetchVector::<GRID_SIZE> { cube: self.cube },
241 )
242 }
243 }
244 };
245}
246
247#[cfg(feature = "options")]
248define_inter_sse!(TetrahedralSse);
249#[cfg(feature = "options")]
250define_inter_sse!(PyramidalSse);
251#[cfg(feature = "options")]
252define_inter_sse!(PrismaticSse);
253define_inter_sse!(TrilinearSse);
254
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PyramidalSse<'_, GRID_SIZE> {
    /// Pyramidal interpolation of one sample.
    ///
    /// Each input is converted to `usize` and used to index `lut`, yielding the
    /// lower/upper lattice coordinate and fractional weight for that axis. The
    /// smallest weight selects one of three pyramids; the result is
    /// `c0 + c1 * db + c2 * dr + c3 * dg + c4 * cross`, where `cross` is the
    /// product of the two weights belonging to the chosen pyramid's base.
    ///
    /// Indexing `lut` is bounds-checked and panics when an index is `>= BINS`.
    #[inline(always)]
    fn interpolate<U: AsPrimitive<usize>, const BINS: usize>(
        &self,
        in_r: U,
        in_g: U,
        in_b: U,
        lut: &[BarycentricWeight<f32>; BINS],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        let (bin_r, bin_g, bin_b) = (lut[in_r.as_()], lut[in_g.as_()], lut[in_b.as_()]);

        let (x, y, z): (i32, i32, i32) = (bin_r.x, bin_g.x, bin_b.x);
        let (x_n, y_n, z_n): (i32, i32, i32) = (bin_r.x_n, bin_g.x_n, bin_b.x_n);
        let (dr, dg, db) = (bin_r.w, bin_g.w, bin_b.w);

        let c000 = r.fetch(x, y, z);

        let (c1, c2, c3, c4, cross) = if dr > db && dg > db {
            // Blue weight is the smallest.
            let c111 = r.fetch(x_n, y_n, z_n);
            let c110 = r.fetch(x_n, y_n, z);
            let c100 = r.fetch(x_n, y, z);
            let c010 = r.fetch(x, y_n, z);
            (
                c111 - c110,
                c100 - c000,
                c010 - c000,
                c000 - c010 - c100 + c110,
                dr * dg,
            )
        } else if db > dr && dg > dr {
            // Red weight is the smallest.
            let c001 = r.fetch(x, y, z_n);
            let c111 = r.fetch(x_n, y_n, z_n);
            let c011 = r.fetch(x, y_n, z_n);
            let c010 = r.fetch(x, y_n, z);
            (
                c001 - c000,
                c111 - c011,
                c010 - c000,
                c000 - c010 - c001 + c011,
                dg * db,
            )
        } else {
            // Remaining case, including ties.
            let c001 = r.fetch(x, y, z_n);
            let c100 = r.fetch(x_n, y, z);
            let c101 = r.fetch(x_n, y, z_n);
            let c111 = r.fetch(x_n, y_n, z_n);
            (
                c001 - c000,
                c100 - c000,
                c111 - c101,
                c000 - c100 - c001 + c101,
                db * dr,
            )
        };

        c000.mla(c1, SseVector::from(db))
            .mla(c2, SseVector::from(dr))
            .mla(c3, SseVector::from(dg))
            .mla(c4, SseVector::from(cross))
    }
}
332
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PrismaticSse<'_, GRID_SIZE> {
    /// Prismatic interpolation of one sample.
    ///
    /// Each input is converted to `usize` and used to index `lut`, yielding the
    /// lower/upper lattice coordinate and fractional weight for that axis.
    /// Comparing the blue and red weights selects one of two prisms; the result
    /// is `c0 + c1 * db + c2 * dr + c3 * dg + c4 * (dg * db) + c5 * (dr * dg)`.
    ///
    /// Indexing `lut` is bounds-checked and panics when an index is `>= BINS`.
    #[inline(always)]
    fn interpolate<U: AsPrimitive<usize>, const BINS: usize>(
        &self,
        in_r: U,
        in_g: U,
        in_b: U,
        lut: &[BarycentricWeight<f32>; BINS],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        let (bin_r, bin_g, bin_b) = (lut[in_r.as_()], lut[in_g.as_()], lut[in_b.as_()]);

        let (x, y, z): (i32, i32, i32) = (bin_r.x, bin_g.x, bin_b.x);
        let (x_n, y_n, z_n): (i32, i32, i32) = (bin_r.x_n, bin_g.x_n, bin_b.x_n);
        let (dr, dg, db) = (bin_r.w, bin_g.w, bin_b.w);

        let c000 = r.fetch(x, y, z);

        let (c1, c2, c3, c4, c5) = if db > dr {
            let c001 = r.fetch(x, y, z_n);
            let c101 = r.fetch(x_n, y, z_n);
            let c010 = r.fetch(x, y_n, z);
            let c011 = r.fetch(x, y_n, z_n);
            let c111 = r.fetch(x_n, y_n, z_n);
            (
                c001 - c000,
                c101 - c001,
                c010 - c000,
                c000 - c010 - c001 + c011,
                c001 - c011 - c101 + c111,
            )
        } else {
            let c100 = r.fetch(x_n, y, z);
            let c101 = r.fetch(x_n, y, z_n);
            let c010 = r.fetch(x, y_n, z);
            let c110 = r.fetch(x_n, y_n, z);
            let c111 = r.fetch(x_n, y_n, z_n);
            (
                c101 - c100,
                c100 - c000,
                c010 - c000,
                c100 - c110 - c101 + c111,
                c000 - c010 - c100 + c110,
            )
        };

        c000.mla(c1, SseVector::from(db))
            .mla(c2, SseVector::from(dr))
            .mla(c3, SseVector::from(dg))
            .mla(c4, SseVector::from(dg * db))
            .mla(c5, SseVector::from(dr * dg))
    }
}
401
402impl<const GRID_SIZE: usize> TrilinearSse<'_, GRID_SIZE> {
403 #[inline(always)]
404 fn interpolate<U: AsPrimitive<usize>, const BINS: usize>(
405 &self,
406 in_r: U,
407 in_g: U,
408 in_b: U,
409 lut: &[BarycentricWeight<f32>; BINS],
410 r: impl Fetcher<SseVector>,
411 ) -> SseVector {
412 let lut_r = lut[in_r.as_()];
413 let lut_g = lut[in_g.as_()];
414 let lut_b = lut[in_b.as_()];
415
416 let x: i32 = lut_r.x;
417 let y: i32 = lut_g.x;
418 let z: i32 = lut_b.x;
419
420 let x_n: i32 = lut_r.x_n;
421 let y_n: i32 = lut_g.x_n;
422 let z_n: i32 = lut_b.x_n;
423
424 let dr = lut_r.w;
425 let dg = lut_g.w;
426 let db = lut_b.w;
427
428 let w0 = SseVector::from(dr);
429 let w1 = SseVector::from(dg);
430 let w2 = SseVector::from(db);
431
432 let c000 = r.fetch(x, y, z);
433 let c100 = r.fetch(x_n, y, z);
434 let c010 = r.fetch(x, y_n, z);
435 let c110 = r.fetch(x_n, y_n, z);
436 let c001 = r.fetch(x, y, z_n);
437 let c101 = r.fetch(x_n, y, z_n);
438 let c011 = r.fetch(x, y_n, z_n);
439 let c111 = r.fetch(x_n, y_n, z_n);
440
441 let dx = SseVector::from(1.0 - dr);
442
443 let c00 = (c000 * dx).mla(c100, w0);
444 let c10 = (c010 * dx).mla(c110, w0);
445 let c01 = (c001 * dx).mla(c101, w0);
446 let c11 = (c011 * dx).mla(c111, w0);
447
448 let dy = SseVector::from(1.0 - dg);
449
450 let c0 = (c00 * dy).mla(c10, w1);
451 let c1 = (c01 * dy).mla(c11, w1);
452
453 let dz = SseVector::from(1.0 - db);
454
455 (c0 * dz).mla(c1, w2)
456 }
457}