1use crate::conversions::interpolator::BarycentricWeight;
30use crate::math::FusedMultiplyAdd;
31#[cfg(target_arch = "x86")]
32use std::arch::x86::*;
33#[cfg(target_arch = "x86_64")]
34use std::arch::x86_64::*;
35use std::ops::{Add, Mul, Sub};
36
/// One lattice node: four `i16` channels stored with 8-byte alignment.
///
/// `repr(C)` plus `align(8)` makes every node exactly 8 bytes, so a single
/// unaligned 64-bit load covers one whole node and nothing beyond it.
#[repr(C, align(8))]
pub(crate) struct SseAlignedI16x4(pub(crate) [i16; 4]);
39
/// SSE tetrahedral interpolator for a lattice with `GRID_SIZE` nodes per axis
/// and Q0.15 weights.
#[cfg(feature = "options")]
pub(crate) struct TetrahedralSseQ0_15<const GRID_SIZE: usize> {}

/// SSE pyramidal interpolator for a lattice with `GRID_SIZE` nodes per axis
/// and Q0.15 weights.
#[cfg(feature = "options")]
pub(crate) struct PyramidalSseQ0_15<const GRID_SIZE: usize> {}

/// SSE prismatic interpolator for a lattice with `GRID_SIZE` nodes per axis
/// and Q0.15 weights.
#[cfg(feature = "options")]
pub(crate) struct PrismaticSseQ0_15<const GRID_SIZE: usize> {}

/// SSE trilinear interpolator for a lattice with `GRID_SIZE` nodes per axis
/// and Q0.15 weights. Unlike the other three, it is built without the
/// `options` feature.
pub(crate) struct TrilinearSseQ0_15<const GRID_SIZE: usize> {}
50
/// Source of lattice nodes addressed by integer grid coordinates.
///
/// The interpolators in this module call `fetch` with each coordinate set to
/// either the lower (`x`) or the upper (`x_n`) index of a `BarycentricWeight`
/// entry.
trait Fetcher<T> {
    /// Returns the node stored at grid position `(x, y, z)`.
    fn fetch(&self, x: i32, y: i32, z: i32) -> T;
}
54
/// Eight `i16` lanes held in one SSE register.
///
/// The operator impls in this module treat the lanes as Q0.15 fixed point:
/// `+` and `-` wrap, and `*` is the rounding high-half multiply
/// `_mm_mulhrs_epi16`.
#[repr(transparent)]
#[derive(Clone, Copy)]
pub(crate) struct SseVector {
    /// Raw register contents.
    pub(crate) v: __m128i,
}
60
61impl From<i16> for SseVector {
62 #[inline(always)]
63 fn from(v: i16) -> Self {
64 SseVector {
65 v: unsafe { _mm_set1_epi16(v) },
66 }
67 }
68}
69
70impl Sub<SseVector> for SseVector {
71 type Output = Self;
72 #[inline(always)]
73 fn sub(self, rhs: SseVector) -> Self::Output {
74 SseVector {
75 v: unsafe { _mm_sub_epi16(self.v, rhs.v) },
76 }
77 }
78}
79
80impl Add<SseVector> for SseVector {
81 type Output = Self;
82 #[inline(always)]
83 fn add(self, rhs: SseVector) -> Self::Output {
84 SseVector {
85 v: unsafe { _mm_add_epi16(self.v, rhs.v) },
86 }
87 }
88}
89
90impl Mul<SseVector> for SseVector {
91 type Output = Self;
92 #[inline(always)]
93 fn mul(self, rhs: SseVector) -> Self::Output {
94 SseVector {
95 v: unsafe { _mm_mulhrs_epi16(self.v, rhs.v) },
96 }
97 }
98}
99
100impl FusedMultiplyAdd<SseVector> for SseVector {
101 #[inline(always)]
102 fn mla(&self, b: SseVector, c: SseVector) -> SseVector {
103 SseVector {
104 v: unsafe { _mm_add_epi16(self.v, _mm_mulhrs_epi16(b.v, c.v)) },
105 }
106 }
107}
108
109struct TetrahedralSseQ0_15FetchVector<'a, const GRID_SIZE: usize> {
110 cube: &'a [SseAlignedI16x4],
111}
112
113impl<const GRID_SIZE: usize> Fetcher<SseVector> for TetrahedralSseQ0_15FetchVector<'_, GRID_SIZE> {
114 #[inline(always)]
115 fn fetch(&self, x: i32, y: i32, z: i32) -> SseVector {
116 let offset = (x as u32 * (GRID_SIZE as u32 * GRID_SIZE as u32)
117 + y as u32 * GRID_SIZE as u32
118 + z as u32) as usize;
119 let jx = unsafe { self.cube.get_unchecked(offset..) };
120 SseVector {
121 v: unsafe { _mm_loadu_si64(jx.as_ptr() as *const _) },
122 }
123 }
124}
125
126pub(crate) trait SseMdInterpolationQ0_15 {
127 fn inter3_sse(
128 &self,
129 cube: &[SseAlignedI16x4],
130 in_r: usize,
131 in_g: usize,
132 in_b: usize,
133 lut: &[BarycentricWeight<i16>],
134 ) -> SseVector;
135}
136
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> TetrahedralSseQ0_15<GRID_SIZE> {
    /// Tetrahedral interpolation of one lattice cell in Q0.15 fixed point.
    ///
    /// The cell is split into six tetrahedra. The ordering of the three
    /// weights selects the one containing the sample, and the result is
    /// `c000 + k1 * rx + k2 * ry + k3 * rz`, where each `k` is a node
    /// difference along an edge of that tetrahedron.
    ///
    /// # Safety
    /// Callers must ensure the CPU supports SSE4.1 and that `in_r`, `in_g` and
    /// `in_b` are valid indices into `lut` (they are read with
    /// `get_unchecked`).
    #[target_feature(enable = "sse4.1")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<i16>],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        // SAFETY: the caller guarantees the three indices are in bounds.
        let (wr, wg, wb) = unsafe {
            (
                *lut.get_unchecked(in_r),
                *lut.get_unchecked(in_g),
                *lut.get_unchecked(in_b),
            )
        };

        // Lower (0) and upper (1) grid coordinates on each axis.
        let (x0, y0, z0): (i32, i32, i32) = (wr.x, wg.x, wb.x);
        let (x1, y1, z1): (i32, i32, i32) = (wr.x_n, wg.x_n, wb.x_n);
        let (rx, ry, rz) = (wr.w, wg.w, wb.w);

        let c000 = r.fetch(x0, y0, z0);

        let (k1, k2, k3) = if rx >= ry {
            if ry >= rz {
                // rx >= ry >= rz
                (
                    r.fetch(x1, y0, z0) - c000,
                    r.fetch(x1, y1, z0) - r.fetch(x1, y0, z0),
                    r.fetch(x1, y1, z1) - r.fetch(x1, y1, z0),
                )
            } else if rx >= rz {
                // rx >= rz > ry
                (
                    r.fetch(x1, y0, z0) - c000,
                    r.fetch(x1, y1, z1) - r.fetch(x1, y0, z1),
                    r.fetch(x1, y0, z1) - r.fetch(x1, y0, z0),
                )
            } else {
                // rz > rx >= ry
                (
                    r.fetch(x1, y0, z1) - r.fetch(x0, y0, z1),
                    r.fetch(x1, y1, z1) - r.fetch(x1, y0, z1),
                    r.fetch(x0, y0, z1) - c000,
                )
            }
        } else if rx >= rz {
            // ry > rx >= rz
            (
                r.fetch(x1, y1, z0) - r.fetch(x0, y1, z0),
                r.fetch(x0, y1, z0) - c000,
                r.fetch(x1, y1, z1) - r.fetch(x1, y1, z0),
            )
        } else if ry >= rz {
            // ry >= rz > rx
            (
                r.fetch(x1, y1, z1) - r.fetch(x0, y1, z1),
                r.fetch(x0, y1, z0) - c000,
                r.fetch(x0, y1, z1) - r.fetch(x0, y1, z0),
            )
        } else {
            // rz > ry > rx
            (
                r.fetch(x1, y1, z1) - r.fetch(x0, y1, z1),
                r.fetch(x0, y1, z1) - r.fetch(x0, y0, z1),
                r.fetch(x0, y0, z1) - c000,
            )
        };

        c000.mla(k1, SseVector::from(rx))
            .mla(k2, SseVector::from(ry))
            .mla(k3, SseVector::from(rz))
    }
}
207
// Implements `SseMdInterpolationQ0_15` for `$interpolator` by forwarding to its
// inherent SSE4.1 `interpolate` method. Nodes are read through a
// `TetrahedralSseQ0_15FetchVector` view of `table`.
macro_rules! define_inter_sse {
    ($interpolator: ident) => {
        impl<const GRID_SIZE: usize> SseMdInterpolationQ0_15 for $interpolator<GRID_SIZE> {
            fn inter3_sse(
                &self,
                table: &[SseAlignedI16x4],
                in_r: usize,
                in_g: usize,
                in_b: usize,
                lut: &[BarycentricWeight<i16>],
            ) -> SseVector {
                // This safe method wraps an `unsafe` `#[target_feature]` function
                // that indexes `lut` without bounds checks. Soundness depends on
                // callers that pick the SSE path only after runtime feature
                // detection and that pass in-range LUT indices. Both conditions
                // are checked in debug builds, so a violation panics instead of
                // being UB.
                debug_assert!(
                    std::arch::is_x86_feature_detected!("sse4.1"),
                    "SSE interpolation used on a CPU without SSE4.1"
                );
                debug_assert!(
                    in_r < lut.len() && in_g < lut.len() && in_b < lut.len(),
                    "LUT index out of range"
                );
                // SAFETY: SSE4.1 support and in-range `in_*` indices are the
                // caller's contract (asserted above in debug builds), and
                // `interpolate` requires both.
                unsafe {
                    self.interpolate(
                        in_r,
                        in_g,
                        in_b,
                        lut,
                        TetrahedralSseQ0_15FetchVector::<GRID_SIZE> { cube: table },
                    )
                }
            }
        }
    };
}
232
233#[cfg(feature = "options")]
234define_inter_sse!(TetrahedralSseQ0_15);
235#[cfg(feature = "options")]
236define_inter_sse!(PyramidalSseQ0_15);
237#[cfg(feature = "options")]
238define_inter_sse!(PrismaticSseQ0_15);
239define_inter_sse!(TrilinearSseQ0_15);
240
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PyramidalSseQ0_15<GRID_SIZE> {
    /// Pyramidal interpolation of one lattice cell in Q0.15 fixed point.
    ///
    /// The cell is split into three pyramids, chosen by comparing the weights:
    /// `db` strictly smallest, then `dr` strictly smallest, and otherwise
    /// (including ties) the third. Each pyramid yields
    /// `c000 + k1*db + k2*dr + k3*dg + k4*p`, where `p` is `dr*dg`, `dg*db` or
    /// `db*dr` respectively.
    ///
    /// Corner locals are named `cXYZ`, where a `1` digit means the upper
    /// (`x_n`) coordinate on that axis.
    ///
    /// NOTE(review): `k4` is a four-term sum built with wrapping 16-bit
    /// add/sub. If node values can span most of the `i16` range it may wrap.
    /// Confirm the value range of the tables passed here.
    ///
    /// # Safety
    /// Callers must ensure the CPU supports SSE4.1 and that `in_r`, `in_g` and
    /// `in_b` are valid indices into `lut` (they are read with
    /// `get_unchecked`).
    #[target_feature(enable = "sse4.1")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<i16>],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        // SAFETY: the caller guarantees the three indices are in bounds.
        let (wr, wg, wb) = unsafe {
            (
                *lut.get_unchecked(in_r),
                *lut.get_unchecked(in_g),
                *lut.get_unchecked(in_b),
            )
        };

        let (x0, y0, z0): (i32, i32, i32) = (wr.x, wg.x, wb.x);
        let (x1, y1, z1): (i32, i32, i32) = (wr.x_n, wg.x_n, wb.x_n);
        let (dr, dg, db) = (wr.w, wg.w, wb.w);

        let c000 = r.fetch(x0, y0, z0);

        let (k1, k2, k3, k4, cross) = if dr > db && dg > db {
            let c111 = r.fetch(x1, y1, z1);
            let c110 = r.fetch(x1, y1, z0);
            let c100 = r.fetch(x1, y0, z0);
            let c010 = r.fetch(x0, y1, z0);
            (
                c111 - c110,
                c100 - c000,
                c010 - c000,
                c000 - c010 - c100 + c110,
                SseVector::from(dr) * SseVector::from(dg),
            )
        } else if db > dr && dg > dr {
            let c001 = r.fetch(x0, y0, z1);
            let c111 = r.fetch(x1, y1, z1);
            let c011 = r.fetch(x0, y1, z1);
            let c010 = r.fetch(x0, y1, z0);
            (
                c001 - c000,
                c111 - c011,
                c010 - c000,
                c000 - c010 - c001 + c011,
                SseVector::from(dg) * SseVector::from(db),
            )
        } else {
            let c001 = r.fetch(x0, y0, z1);
            let c100 = r.fetch(x1, y0, z0);
            let c101 = r.fetch(x1, y0, z1);
            let c111 = r.fetch(x1, y1, z1);
            (
                c001 - c000,
                c100 - c000,
                c111 - c101,
                c000 - c100 - c001 + c101,
                SseVector::from(db) * SseVector::from(dr),
            )
        };

        c000.mla(k1, SseVector::from(db))
            .mla(k2, SseVector::from(dr))
            .mla(k3, SseVector::from(dg))
            .mla(k4, cross)
    }
}
318
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PrismaticSseQ0_15<GRID_SIZE> {
    /// Prismatic interpolation of one lattice cell in Q0.15 fixed point.
    ///
    /// The cell is split into two prisms by comparing `db` with `dr`. Either
    /// way the sample is
    /// `c000 + k1*db + k2*dr + k3*dg + k4*(dg*db) + k5*(dr*dg)`.
    ///
    /// Corner locals are named `cXYZ`, where a `1` digit means the upper
    /// (`x_n`) coordinate on that axis.
    ///
    /// NOTE(review): `k4` and `k5` are four-term sums built with wrapping
    /// 16-bit add/sub. If node values can span most of the `i16` range they
    /// may wrap. Confirm the value range of the tables passed here.
    ///
    /// # Safety
    /// Callers must ensure the CPU supports SSE4.1 and that `in_r`, `in_g` and
    /// `in_b` are valid indices into `lut` (they are read with
    /// `get_unchecked`).
    #[target_feature(enable = "sse4.1")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<i16>],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        // SAFETY: the caller guarantees the three indices are in bounds.
        let (wr, wg, wb) = unsafe {
            (
                *lut.get_unchecked(in_r),
                *lut.get_unchecked(in_g),
                *lut.get_unchecked(in_b),
            )
        };

        let (x0, y0, z0): (i32, i32, i32) = (wr.x, wg.x, wb.x);
        let (x1, y1, z1): (i32, i32, i32) = (wr.x_n, wg.x_n, wb.x_n);
        let (dr, dg, db) = (wr.w, wg.w, wb.w);

        let c000 = r.fetch(x0, y0, z0);

        let (k1, k2, k3, k4, k5) = if db > dr {
            let c001 = r.fetch(x0, y0, z1);
            let c101 = r.fetch(x1, y0, z1);
            let c010 = r.fetch(x0, y1, z0);
            let c011 = r.fetch(x0, y1, z1);
            let c111 = r.fetch(x1, y1, z1);
            (
                c001 - c000,
                c101 - c001,
                c010 - c000,
                c000 - c010 - c001 + c011,
                c001 - c011 - c101 + c111,
            )
        } else {
            let c100 = r.fetch(x1, y0, z0);
            let c101 = r.fetch(x1, y0, z1);
            let c010 = r.fetch(x0, y1, z0);
            let c110 = r.fetch(x1, y1, z0);
            let c111 = r.fetch(x1, y1, z1);
            (
                c101 - c100,
                c100 - c000,
                c010 - c000,
                c100 - c110 - c101 + c111,
                c000 - c010 - c100 + c110,
            )
        };

        let (w_r, w_g, w_b) = (SseVector::from(dr), SseVector::from(dg), SseVector::from(db));
        c000.mla(k1, w_b)
            .mla(k2, w_r)
            .mla(k3, w_g)
            .mla(k4, w_g * w_b)
            .mla(k5, w_r * w_g)
    }
}
387
388impl<const GRID_SIZE: usize> TrilinearSseQ0_15<GRID_SIZE> {
389 #[target_feature(enable = "sse4.1")]
390 unsafe fn interpolate(
391 &self,
392 in_r: usize,
393 in_g: usize,
394 in_b: usize,
395 lut: &[BarycentricWeight<i16>],
396 r: impl Fetcher<SseVector>,
397 ) -> SseVector {
398 let lut_r = unsafe { *lut.get_unchecked(in_r) };
399 let lut_g = unsafe { *lut.get_unchecked(in_g) };
400 let lut_b = unsafe { *lut.get_unchecked(in_b) };
401
402 let x: i32 = lut_r.x;
403 let y: i32 = lut_g.x;
404 let z: i32 = lut_b.x;
405
406 let x_n: i32 = lut_r.x_n;
407 let y_n: i32 = lut_g.x_n;
408 let z_n: i32 = lut_b.x_n;
409
410 let dr = lut_r.w;
411 let dg = lut_g.w;
412 let db = lut_b.w;
413
414 const Q_MAX: i16 = ((1i32 << 15i32) - 1) as i16;
415 let q_max = SseVector::from(Q_MAX);
416 let w0 = SseVector::from(dr);
417 let w1 = SseVector::from(dg);
418 let w2 = SseVector::from(db);
419 let dx = q_max - SseVector::from(dr);
420 let dy = q_max - SseVector::from(dg);
421 let dz = q_max - SseVector::from(db);
422
423 let c000 = r.fetch(x, y, z);
424 let c100 = r.fetch(x_n, y, z);
425 let c010 = r.fetch(x, y_n, z);
426 let c110 = r.fetch(x_n, y_n, z);
427 let c001 = r.fetch(x, y, z_n);
428 let c101 = r.fetch(x_n, y, z_n);
429 let c011 = r.fetch(x, y_n, z_n);
430 let c111 = r.fetch(x_n, y_n, z_n);
431
432 let c00 = (c000 * dx).mla(c100, w0);
433 let c10 = (c010 * dx).mla(c110, w0);
434 let c01 = (c001 * dx).mla(c101, w0);
435 let c11 = (c011 * dx).mla(c111, w0);
436
437 let c0 = (c00 * dy).mla(c10, w1);
438 let c1 = (c01 * dy).mla(c11, w1);
439
440 (c0 * dz).mla(c1, w2)
441 }
442}