1use crate::conversions::avx::interpolator::AvxVectorSse;
30use crate::math::{FusedMultiplyAdd, FusedMultiplyNegAdd};
31use std::arch::x86_64::*;
32use std::ops::{Add, Mul, Sub};
33
/// Interpolator over a flattened 3D lookup table using AVX/FMA-backed SSE
/// vectors. Supports trilinear lookup, plus pyramidal, tetrahedral and
/// prismatic variants behind the `options` feature.
pub(crate) struct CubeAvxFma<'a> {
    // Flattened grid data; `fetch` indexes it as (x * x_stride + y * y_stride + z) * 3,
    // so z is the fastest-varying axis and each grid point holds 3 f32 components.
    array: &'a [f32],
    // Stride, in grid points, between consecutive x coordinates.
    x_stride: u32,
    // Stride, in grid points, between consecutive y coordinates.
    y_stride: u32,
    // Number of grid points along the x, y and z axes respectively.
    grid_size: [u8; 3],
}
43
/// Fetches one 3-component f32 grid point from a flattened 3D table as an
/// `AvxVectorSse` (see the `CubeFetch` impl below).
struct HexahedronFetch3<'a> {
    // Flattened grid data, 3 f32 components per grid point; z varies fastest.
    array: &'a [f32],
    // Stride, in grid points, between consecutive x coordinates.
    x_stride: u32,
    // Stride, in grid points, between consecutive y coordinates.
    y_stride: u32,
}
49
/// Abstraction over reading one grid point at integer coordinates (x, y, z)
/// as a vector value `T`; lets the interpolation kernels stay generic over
/// the vector representation.
trait CubeFetch<T> {
    /// Returns the value stored at grid coordinate (x, y, z).
    /// Coordinates are expected to be in-range for the underlying grid.
    fn fetch(&self, x: i32, y: i32, z: i32) -> T;
}
53
impl CubeFetch<AvxVectorSse> for HexahedronFetch3<'_> {
    /// Loads the 3 f32 components at grid coordinate (x, y, z) into the low
    /// three lanes of an SSE register (the fourth lane is zeroed by the
    /// 64-bit load).
    #[inline(always)]
    fn fetch(&self, x: i32, y: i32, z: i32) -> AvxVectorSse {
        // Flattened element index: z is the fastest-varying axis, and each
        // grid point stores 3 consecutive f32 values.
        let start = (x as u32 * self.x_stride + y as u32 * self.y_stride + z as u32) as usize * 3;
        // SAFETY: indexing is unchecked; soundness relies on the caller passing
        // in-range (x, y, z) so that start..start + 3 lies within `array`
        // (the table length is asserted at construction in CubeAvxFma::new).
        unsafe {
            let k = self.array.get_unchecked(start..);
            // Load components 0 and 1 as one 64-bit chunk (upper lanes zeroed).
            let lo = _mm_loadu_si64(k.as_ptr() as *const _);
            // Insert the bit pattern of component 2 into integer lane 2;
            // read_unaligned because `start * 3` need not be 16-byte aligned.
            let hi = _mm_insert_epi32::<2>(
                lo,
                k.get_unchecked(2..).as_ptr().read_unaligned().to_bits() as i32,
            );
            // Reinterpret the assembled integer lanes as packed f32.
            AvxVectorSse {
                v: _mm_castsi128_ps(hi),
            }
        }
    }
}
71
72impl<'a> CubeAvxFma<'a> {
73 pub(crate) fn new(arr: &'a [f32], grid: [u8; 3], components: usize) -> Self {
74 assert_eq!(
77 grid[0] as usize * grid[1] as usize * grid[2] as usize * components,
78 arr.len()
79 );
80 let y_stride = grid[1] as u32;
81 let x_stride = y_stride * grid[0] as u32;
82 CubeAvxFma {
83 array: arr,
84 x_stride,
85 y_stride,
86 grid_size: grid,
87 }
88 }
89
90 #[inline(always)]
91 fn trilinear<
92 T: Copy
93 + From<f32>
94 + Sub<T, Output = T>
95 + Mul<T, Output = T>
96 + Add<T, Output = T>
97 + FusedMultiplyNegAdd<T>
98 + FusedMultiplyAdd<T>,
99 >(
100 &self,
101 lin_x: f32,
102 lin_y: f32,
103 lin_z: f32,
104 fetch: impl CubeFetch<T>,
105 ) -> T {
106 let lin_x = lin_x.max(0.0).min(1.0);
107 let lin_y = lin_y.max(0.0).min(1.0);
108 let lin_z = lin_z.max(0.0).min(1.0);
109
110 let scale_x = (self.grid_size[0] as i32 - 1) as f32;
111 let scale_y = (self.grid_size[1] as i32 - 1) as f32;
112 let scale_z = (self.grid_size[2] as i32 - 1) as f32;
113
114 let x = (lin_x * scale_x).floor() as i32;
115 let y = (lin_y * scale_y).floor() as i32;
116 let z = (lin_z * scale_z).floor() as i32;
117
118 let x_n = (lin_x * scale_x).ceil() as i32;
119 let y_n = (lin_y * scale_y).ceil() as i32;
120 let z_n = (lin_z * scale_z).ceil() as i32;
121
122 let x_d = T::from(lin_x * scale_x - x as f32);
123 let y_d = T::from(lin_y * scale_y - y as f32);
124 let z_d = T::from(lin_z * scale_z - z as f32);
125
126 let c000 = fetch.fetch(x, y, z);
127 let c100 = fetch.fetch(x_n, y, z);
128 let c010 = fetch.fetch(x, y_n, z);
129 let c110 = fetch.fetch(x_n, y_n, z);
130 let c001 = fetch.fetch(x, y, z_n);
131 let c101 = fetch.fetch(x_n, y, z_n);
132 let c011 = fetch.fetch(x, y_n, z_n);
133 let c111 = fetch.fetch(x_n, y_n, z_n);
134
135 let c00 = c000.neg_mla(c000, x_d).mla(c100, x_d);
136 let c10 = c010.neg_mla(c010, x_d).mla(c110, x_d);
137 let c01 = c001.neg_mla(c001, x_d).mla(c101, x_d);
138 let c11 = c011.neg_mla(c011, x_d).mla(c111, x_d);
139
140 let c0 = c00.neg_mla(c00, y_d).mla(c10, y_d);
141 let c1 = c01.neg_mla(c01, y_d).mla(c11, y_d);
142
143 c0.neg_mla(c0, z_d).mla(c1, z_d)
144 }
145
    /// Pyramidal interpolation: the sample is reconstructed from five grid
    /// points of the enclosing cell, with the branch chosen by comparing the
    /// fractional offsets (dr, dg, db) along x, y and z.
    #[cfg(feature = "options")]
    #[inline]
    fn pyramid<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl CubeFetch<T>,
    ) -> T {
        // Clamp inputs to [0, 1]; max/min maps NaN to 0.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        // Scale from normalized coordinates to grid-coordinate space.
        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        // Lower (floor) corner of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        // Upper (ceil) corner of the enclosing cell.
        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        // Fractional offsets inside the cell along x, y and z.
        let dr = lin_x * scale_x - x as f32;
        let dg = lin_y * scale_y - y as f32;
        let db = lin_z * scale_z - z as f32;

        let c0 = fetch.fetch(x, y, z);

        // Branch on which axis offset is smallest; each branch evaluates
        // c0 + c1*db + c2*dr + c3*dg + c4*(cross term) via chained FMAs.
        if dr > db && dg > db {
            // x and y offsets both exceed the z offset.
            let x0 = fetch.fetch(x_n, y_n, z_n);
            let x1 = fetch.fetch(x_n, y_n, z);
            let x2 = fetch.fetch(x_n, y, z);
            let x3 = fetch.fetch(x, y_n, z);

            let c1 = x0 - x1;
            let c2 = x2 - c0;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x2 + x1;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dr * dg))
        } else if db > dr && dg > dr {
            // y and z offsets both exceed the x offset.
            let x0 = fetch.fetch(x, y, z_n);
            let x1 = fetch.fetch(x_n, y_n, z_n);
            let x2 = fetch.fetch(x, y_n, z_n);
            let x3 = fetch.fetch(x, y_n, z);

            let c1 = x0 - c0;
            let c2 = x1 - x2;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x0 + x2;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(dg * db))
        } else {
            // Remaining region (includes ties between the offsets).
            let x0 = fetch.fetch(x, y, z_n);
            let x1 = fetch.fetch(x_n, y, z);
            let x2 = fetch.fetch(x_n, y, z_n);
            let x3 = fetch.fetch(x_n, y_n, z_n);

            let c1 = x0 - c0;
            let c2 = x1 - c0;
            let c3 = x3 - x2;
            let c4 = c0 - x1 - x0 + x2;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            s2.mla(c4, T::from(db * dr))
        }
    }
231
    /// Tetrahedral interpolation: the cell is split into six tetrahedra along
    /// its main diagonal; the ordering of the fractional offsets (rx, ry, rz)
    /// selects the containing tetrahedron, and the sample is reconstructed
    /// from its four vertices as c0 + c1*rx + c2*ry + c3*rz.
    #[cfg(feature = "options")]
    #[inline]
    fn tetra<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl CubeFetch<T>,
    ) -> T {
        // Clamp inputs to [0, 1]; max/min maps NaN to 0.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        // Scale from normalized coordinates to grid-coordinate space.
        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        // Lower (floor) corner of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        // Upper (ceil) corner of the enclosing cell.
        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        // Fractional offsets inside the cell along x, y and z.
        let rx = lin_x * scale_x - x as f32;
        let ry = lin_y * scale_y - y as f32;
        let rz = lin_z * scale_z - z as f32;

        let c0 = fetch.fetch(x, y, z);
        let c2;
        let c1;
        let c3;
        // Six cases covering the orderings of rx, ry, rz; each builds the
        // three edge deltas of the selected tetrahedron.
        if rx >= ry {
            if ry >= rz {
                // rx >= ry >= rz
                c1 = fetch.fetch(x_n, y, z) - c0;
                c2 = fetch.fetch(x_n, y_n, z) - fetch.fetch(x_n, y, z);
                c3 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y_n, z);
            } else if rx >= rz {
                // rx >= rz > ry
                c1 = fetch.fetch(x_n, y, z) - c0;
                c2 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y, z_n);
                c3 = fetch.fetch(x_n, y, z_n) - fetch.fetch(x_n, y, z);
            } else {
                // rz > rx >= ry
                c1 = fetch.fetch(x_n, y, z_n) - fetch.fetch(x, y, z_n);
                c2 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y, z_n);
                c3 = fetch.fetch(x, y, z_n) - c0;
            }
        } else if rx >= rz {
            // ry > rx >= rz
            c1 = fetch.fetch(x_n, y_n, z) - fetch.fetch(x, y_n, z);
            c2 = fetch.fetch(x, y_n, z) - c0;
            c3 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x_n, y_n, z);
        } else if ry >= rz {
            // ry >= rz > rx
            c1 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x, y_n, z_n);
            c2 = fetch.fetch(x, y_n, z) - c0;
            c3 = fetch.fetch(x, y_n, z_n) - fetch.fetch(x, y_n, z);
        } else {
            // rz > ry > rx
            c1 = fetch.fetch(x_n, y_n, z_n) - fetch.fetch(x, y_n, z_n);
            c2 = fetch.fetch(x, y_n, z_n) - fetch.fetch(x, y, z_n);
            c3 = fetch.fetch(x, y, z_n) - c0;
        }
        // Accumulate c0 + c1*rx + c2*ry + c3*rz with fused multiply-adds.
        let s0 = c0.mla(c1, T::from(rx));
        let s1 = s0.mla(c2, T::from(ry));
        s1.mla(c3, T::from(rz))
    }
309
    /// Prismatic interpolation: the cell is split into two triangular prisms
    /// by the db >= dr plane; the sample is reconstructed from six grid
    /// points of the selected prism via chained fused multiply-adds.
    #[cfg(feature = "options")]
    #[inline]
    fn prism<
        T: Copy
            + From<f32>
            + Sub<T, Output = T>
            + Mul<T, Output = T>
            + Add<T, Output = T>
            + FusedMultiplyAdd<T>,
    >(
        &self,
        lin_x: f32,
        lin_y: f32,
        lin_z: f32,
        fetch: impl CubeFetch<T>,
    ) -> T {
        // Clamp inputs to [0, 1]; max/min maps NaN to 0.
        let lin_x = lin_x.max(0.0).min(1.0);
        let lin_y = lin_y.max(0.0).min(1.0);
        let lin_z = lin_z.max(0.0).min(1.0);

        // Scale from normalized coordinates to grid-coordinate space.
        let scale_x = (self.grid_size[0] as i32 - 1) as f32;
        let scale_y = (self.grid_size[1] as i32 - 1) as f32;
        let scale_z = (self.grid_size[2] as i32 - 1) as f32;

        // Lower (floor) corner of the enclosing cell.
        let x = (lin_x * scale_x).floor() as i32;
        let y = (lin_y * scale_y).floor() as i32;
        let z = (lin_z * scale_z).floor() as i32;

        // Upper (ceil) corner of the enclosing cell.
        let x_n = (lin_x * scale_x).ceil() as i32;
        let y_n = (lin_y * scale_y).ceil() as i32;
        let z_n = (lin_z * scale_z).ceil() as i32;

        // Fractional offsets inside the cell along x, y and z.
        let dr = lin_x * scale_x - x as f32;
        let dg = lin_y * scale_y - y as f32;
        let db = lin_z * scale_z - z as f32;

        let c0 = fetch.fetch(x, y, z);

        // Each branch evaluates
        // c0 + c1*db + c2*dr + c3*dg + c4*(dg*db) + c5*(dr*dg).
        if db >= dr {
            // Prism on the db >= dr side of the cell diagonal.
            let x0 = fetch.fetch(x, y, z_n);
            let x1 = fetch.fetch(x_n, y, z_n);
            let x2 = fetch.fetch(x, y_n, z);
            let x3 = fetch.fetch(x, y_n, z_n);
            let x4 = fetch.fetch(x_n, y_n, z_n);

            let c1 = x0 - c0;
            let c2 = x1 - x0;
            let c3 = x2 - c0;
            let c4 = c0 - x2 - x0 + x3;
            let c5 = x0 - x3 - x1 + x4;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        } else {
            // Prism on the dr > db side of the cell diagonal.
            let x0 = fetch.fetch(x_n, y, z);
            let x1 = fetch.fetch(x_n, y, z_n);
            let x2 = fetch.fetch(x, y_n, z);
            let x3 = fetch.fetch(x_n, y_n, z);
            let x4 = fetch.fetch(x_n, y_n, z_n);

            let c1 = x1 - x0;
            let c2 = x0 - c0;
            let c3 = x2 - c0;
            let c4 = x0 - x3 - x1 + x4;
            let c5 = c0 - x2 - x0 + x3;

            let s0 = c0.mla(c1, T::from(db));
            let s1 = s0.mla(c2, T::from(dr));
            let s2 = s1.mla(c3, T::from(dg));
            let s3 = s2.mla(c4, T::from(dg * db));
            s3.mla(c5, T::from(dr * dg))
        }
    }
386
387 #[inline]
388 pub(crate) fn trilinear_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> AvxVectorSse {
389 self.trilinear(
390 lin_x,
391 lin_y,
392 lin_z,
393 HexahedronFetch3 {
394 array: self.array,
395 x_stride: self.x_stride,
396 y_stride: self.y_stride,
397 },
398 )
399 }
400
401 #[cfg(feature = "options")]
402 #[inline]
403 pub(crate) fn prism_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> AvxVectorSse {
404 self.prism(
405 lin_x,
406 lin_y,
407 lin_z,
408 HexahedronFetch3 {
409 array: self.array,
410 x_stride: self.x_stride,
411 y_stride: self.y_stride,
412 },
413 )
414 }
415
416 #[cfg(feature = "options")]
417 #[inline]
418 pub(crate) fn pyramid_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> AvxVectorSse {
419 self.pyramid(
420 lin_x,
421 lin_y,
422 lin_z,
423 HexahedronFetch3 {
424 array: self.array,
425 x_stride: self.x_stride,
426 y_stride: self.y_stride,
427 },
428 )
429 }
430
431 #[cfg(feature = "options")]
432 #[inline]
433 pub(crate) fn tetra_vec3(&self, lin_x: f32, lin_y: f32, lin_z: f32) -> AvxVectorSse {
434 self.tetra(
435 lin_x,
436 lin_y,
437 lin_z,
438 HexahedronFetch3 {
439 array: self.array,
440 x_stride: self.x_stride,
441 y_stride: self.y_stride,
442 },
443 )
444 }
445}