1use crate::conversions::katana::{KatanaFinalStage, KatanaInitialStage};
30use crate::mlaf::mlaf;
31use crate::safe_math::SafeMul;
32use crate::trc::lut_interp_linear_float;
33use crate::{
34 CmsError, Cube, DataColorSpace, InterpolationMethod, LutMultidimensionalType, MalformedSize,
35 Matrix3d, Matrix3f, PointeeSizeExpressible, TransformOptions, Vector3d, Vector3f,
36};
37use num_traits::AsPrimitive;
38use std::marker::PhantomData;
39
/// Direction in which a `Multidimensional3x3` stage is evaluated.
///
/// The value is chosen when the stage is built and is checked (via `assert_eq!`) by the
/// stage's entry points, so a stage can only be driven in the direction it was made for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum MultidimensionalDirection {
    /// Device samples in, PCS values out (used by `to_pcs`).
    DeviceToPcs,
    /// PCS values in, device samples out (used by `to_output`).
    PcsToDevice,
}
45
46struct Multidimensional3x3<
47 T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
48> {
49 a_curves: Option<Box<[Vec<f32>; 3]>>,
50 m_curves: Option<Box<[Vec<f32>; 3]>>,
51 b_curves: Option<Box<[Vec<f32>; 3]>>,
52 clut: Option<Vec<f32>>,
53 matrix: Matrix3f,
54 bias: Vector3f,
55 direction: MultidimensionalDirection,
56 options: TransformOptions,
57 pcs: DataColorSpace,
58 grid_size: [u8; 3],
59 _phantom: PhantomData<T>,
60 bit_depth: usize,
61}
62
63impl<T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync>
64 Multidimensional3x3<T>
65{
66 fn execute_matrix_stage(&self, dst: &mut [f32]) {
67 let m = self.matrix;
68 let b = self.bias;
69
70 if !m.test_equality(Matrix3f::IDENTITY) || !b.eq(&Vector3f::default()) {
71 for dst in dst.chunks_exact_mut(3) {
72 let x = dst[0];
73 let y = dst[1];
74 let z = dst[2];
75 dst[0] = mlaf(mlaf(mlaf(b.v[0], x, m.v[0][0]), y, m.v[0][1]), z, m.v[0][2]);
76 dst[1] = mlaf(mlaf(mlaf(b.v[1], x, m.v[1][0]), y, m.v[1][1]), z, m.v[1][2]);
77 dst[2] = mlaf(mlaf(mlaf(b.v[2], x, m.v[2][0]), y, m.v[2][1]), z, m.v[2][2]);
78 }
79 }
80 }
81
82 fn execute_simple_curves(&self, dst: &mut [f32], curves: &[Vec<f32>; 3]) {
83 let curve0 = &curves[0];
84 let curve1 = &curves[1];
85 let curve2 = &curves[2];
86
87 for dst in dst.chunks_exact_mut(3) {
88 let a0 = dst[0];
89 let a1 = dst[1];
90 let a2 = dst[2];
91 let b0 = lut_interp_linear_float(a0, curve0);
92 let b1 = lut_interp_linear_float(a1, curve1);
93 let b2 = lut_interp_linear_float(a2, curve2);
94 dst[0] = b0;
95 dst[1] = b1;
96 dst[2] = b2;
97 }
98 }
99
100 fn to_pcs_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
101 &self,
102 input: &[T],
103 dst: &mut [f32],
104 fetch: Fetch,
105 ) -> Result<(), CmsError> {
106 let norm_value = if T::FINITE {
107 1.0 / ((1u32 << self.bit_depth) - 1) as f32
108 } else {
109 1.0
110 };
111 assert_eq!(
112 self.direction,
113 MultidimensionalDirection::DeviceToPcs,
114 "PCS to device cannot be used on `to pcs` stage"
115 );
116
117 if let (Some(a_curves), Some(clut)) = (self.a_curves.as_ref(), self.clut.as_ref()) {
121 if !clut.is_empty() {
122 let curve0 = &a_curves[0];
123 let curve1 = &a_curves[1];
124 let curve2 = &a_curves[2];
125 for (src, dst) in input.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
126 let b0 = lut_interp_linear_float(src[0].as_() * norm_value, curve0);
127 let b1 = lut_interp_linear_float(src[1].as_() * norm_value, curve1);
128 let b2 = lut_interp_linear_float(src[2].as_() * norm_value, curve2);
129 let interpolated = fetch(b0, b1, b2);
130 dst[0] = interpolated.v[0];
131 dst[1] = interpolated.v[1];
132 dst[2] = interpolated.v[2];
133 }
134 } else {
135 for (src, dst) in input.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
136 dst[0] = src[0].as_() * norm_value;
137 dst[1] = src[1].as_() * norm_value;
138 dst[2] = src[2].as_() * norm_value;
139 }
140 }
141 } else {
142 for (src, dst) in input.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
143 dst[0] = src[0].as_() * norm_value;
144 dst[1] = src[1].as_() * norm_value;
145 dst[2] = src[2].as_() * norm_value;
146 }
147 }
148
149 if let Some(m_curves) = self.m_curves.as_ref() {
152 self.execute_simple_curves(dst, m_curves);
153 self.execute_matrix_stage(dst);
154 }
155
156 if let Some(b_curves) = &self.b_curves.as_ref() {
158 self.execute_simple_curves(dst, b_curves);
159 }
160
161 Ok(())
162 }
163}
164
165impl<T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync>
166 KatanaInitialStage<f32, T> for Multidimensional3x3<T>
167{
168 fn to_pcs(&self, input: &[T]) -> Result<Vec<f32>, CmsError> {
169 if input.len() % 3 != 0 {
170 return Err(CmsError::LaneMultipleOfChannels);
171 }
172 let fixed_new_clut = Vec::new();
173 let new_clut = self.clut.as_ref().unwrap_or(&fixed_new_clut);
174 let lut = Cube::new_cube(new_clut, self.grid_size);
175
176 let mut new_dst = vec![0f32; input.len()];
177
178 if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
180 self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.trilinear_vec3(x, y, z))?;
181 return Ok(new_dst);
182 }
183
184 match self.options.interpolation_method {
185 #[cfg(feature = "options")]
186 InterpolationMethod::Tetrahedral => {
187 self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.tetra_vec3(x, y, z))?;
188 }
189 #[cfg(feature = "options")]
190 InterpolationMethod::Pyramid => {
191 self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.pyramid_vec3(x, y, z))?;
192 }
193 #[cfg(feature = "options")]
194 InterpolationMethod::Prism => {
195 self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.prism_vec3(x, y, z))?;
196 }
197 InterpolationMethod::Linear => {
198 self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.trilinear_vec3(x, y, z))?;
199 }
200 }
201 Ok(new_dst)
202 }
203}
204
205impl<T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync>
206 Multidimensional3x3<T>
207where
208 f32: AsPrimitive<T>,
209{
210 fn to_output_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
211 &self,
212 src: &mut [f32],
213 dst: &mut [T],
214 fetch: Fetch,
215 ) -> Result<(), CmsError> {
216 let norm_value = if T::FINITE {
217 ((1u32 << self.bit_depth) - 1) as f32
218 } else {
219 1.0
220 };
221 assert_eq!(
222 self.direction,
223 MultidimensionalDirection::PcsToDevice,
224 "Device to PCS cannot be used on `to output` stage"
225 );
226
227 if let Some(b_curves) = &self.b_curves.as_ref() {
228 self.execute_simple_curves(src, b_curves);
229 }
230
231 if let Some(m_curves) = self.m_curves.as_ref() {
234 self.execute_matrix_stage(src);
235 self.execute_simple_curves(src, m_curves);
236 }
237
238 if let (Some(a_curves), Some(clut)) = (self.a_curves.as_ref(), self.clut.as_ref()) {
239 if !clut.is_empty() {
240 let curve0 = &a_curves[0];
241 let curve1 = &a_curves[1];
242 let curve2 = &a_curves[2];
243 for (src, dst) in src.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
244 let b0 = lut_interp_linear_float(src[0], curve0);
245 let b1 = lut_interp_linear_float(src[1], curve1);
246 let b2 = lut_interp_linear_float(src[2], curve2);
247 let interpolated = fetch(b0, b1, b2);
248 if T::FINITE {
249 dst[0] = (interpolated.v[0] * norm_value)
250 .round()
251 .max(0.0)
252 .min(norm_value)
253 .as_();
254 dst[1] = (interpolated.v[1] * norm_value)
255 .round()
256 .max(0.0)
257 .min(norm_value)
258 .as_();
259 dst[2] = (interpolated.v[2] * norm_value)
260 .round()
261 .max(0.0)
262 .min(norm_value)
263 .as_();
264 } else {
265 dst[0] = interpolated.v[0].as_();
266 dst[1] = interpolated.v[1].as_();
267 dst[2] = interpolated.v[2].as_();
268 }
269 }
270 } else {
271 for (src, dst) in src.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
272 if T::FINITE {
273 dst[0] = (src[0] * norm_value).round().max(0.0).min(norm_value).as_();
274 dst[1] = (src[1] * norm_value).round().max(0.0).min(norm_value).as_();
275 dst[2] = (src[2] * norm_value).round().max(0.0).min(norm_value).as_();
276 } else {
277 dst[0] = src[0].as_();
278 dst[1] = src[1].as_();
279 dst[2] = src[2].as_();
280 }
281 }
282 }
283 } else {
284 for (src, dst) in src.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
285 if T::FINITE {
286 dst[0] = (src[0] * norm_value).round().max(0.0).min(norm_value).as_();
287 dst[1] = (src[1] * norm_value).round().max(0.0).min(norm_value).as_();
288 dst[2] = (src[2] * norm_value).round().max(0.0).min(norm_value).as_();
289 } else {
290 dst[0] = src[0].as_();
291 dst[1] = src[1].as_();
292 dst[2] = src[2].as_();
293 }
294 }
295 }
296
297 Ok(())
298 }
299}
300
301impl<T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync>
302 KatanaFinalStage<f32, T> for Multidimensional3x3<T>
303where
304 f32: AsPrimitive<T>,
305{
306 fn to_output(&self, src: &mut [f32], dst: &mut [T]) -> Result<(), CmsError> {
307 if src.len() % 3 != 0 {
308 return Err(CmsError::LaneMultipleOfChannels);
309 }
310 if dst.len() % 3 != 0 {
311 return Err(CmsError::LaneMultipleOfChannels);
312 }
313 if src.len() != dst.len() {
314 return Err(CmsError::LaneSizeMismatch);
315 }
316 let fixed_new_clut = Vec::new();
317 let new_clut = self.clut.as_ref().unwrap_or(&fixed_new_clut);
318 let lut = Cube::new_cube(new_clut, self.grid_size);
319
320 if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
322 return self.to_output_impl(src, dst, |x, y, z| lut.trilinear_vec3(x, y, z));
323 }
324
325 match self.options.interpolation_method {
326 #[cfg(feature = "options")]
327 InterpolationMethod::Tetrahedral => {
328 self.to_output_impl(src, dst, |x, y, z| lut.tetra_vec3(x, y, z))?;
329 }
330 #[cfg(feature = "options")]
331 InterpolationMethod::Pyramid => {
332 self.to_output_impl(src, dst, |x, y, z| lut.pyramid_vec3(x, y, z))?;
333 }
334 #[cfg(feature = "options")]
335 InterpolationMethod::Prism => {
336 self.to_output_impl(src, dst, |x, y, z| lut.prism_vec3(x, y, z))?;
337 }
338 InterpolationMethod::Linear => {
339 self.to_output_impl(src, dst, |x, y, z| lut.trilinear_vec3(x, y, z))?;
340 }
341 }
342 Ok(())
343 }
344}
345
346fn make_multidimensional_3x3<
347 T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
348>(
349 mab: &LutMultidimensionalType,
350 options: TransformOptions,
351 pcs: DataColorSpace,
352 direction: MultidimensionalDirection,
353 bit_depth: usize,
354) -> Result<Multidimensional3x3<T>, CmsError> {
355 if mab.num_input_channels != 3 && mab.num_output_channels != 3 {
356 return Err(CmsError::UnsupportedProfileConnection);
357 }
358 if mab.b_curves.is_empty() || mab.b_curves.len() != 3 {
359 return Err(CmsError::InvalidAtoBLut);
360 }
361
362 let grid_size = [mab.grid_points[0], mab.grid_points[1], mab.grid_points[2]];
363
364 let clut: Option<Vec<f32>> = if mab.a_curves.len() == 3 && mab.clut.is_some() {
365 let clut = mab.clut.as_ref().map(|x| x.to_clut_f32()).unwrap();
366 let lut_grid = (mab.grid_points[0] as usize)
367 .safe_mul(mab.grid_points[1] as usize)?
368 .safe_mul(mab.grid_points[2] as usize)?
369 .safe_mul(mab.num_output_channels as usize)?;
370 if clut.len() != lut_grid {
371 return Err(CmsError::MalformedCurveLutTable(MalformedSize {
372 size: clut.len(),
373 expected: lut_grid,
374 }));
375 }
376 Some(clut)
377 } else {
378 None
379 };
380
381 let a_curves: Option<Box<[Vec<f32>; 3]>> = if mab.a_curves.len() == 3 && mab.clut.is_some() {
382 let mut arr = Box::<[Vec<f32>; 3]>::default();
383 for (a_curve, dst) in mab.a_curves.iter().zip(arr.iter_mut()) {
384 *dst = a_curve.to_clut()?;
385 }
386 Some(arr)
387 } else {
388 None
389 };
390
391 let b_curves: Option<Box<[Vec<f32>; 3]>> = if mab.b_curves.len() == 3 {
392 let mut arr = Box::<[Vec<f32>; 3]>::default();
393 let all_curves_linear = mab.b_curves.iter().all(|curve| curve.is_linear());
394 if all_curves_linear {
395 None
396 } else {
397 for (c_curve, dst) in mab.b_curves.iter().zip(arr.iter_mut()) {
398 *dst = c_curve.to_clut()?;
399 }
400 Some(arr)
401 }
402 } else {
403 return Err(CmsError::InvalidAtoBLut);
404 };
405
406 let matrix = mab.matrix.to_f32();
407
408 let m_curves: Option<Box<[Vec<f32>; 3]>> = if mab.m_curves.len() == 3 {
409 let all_curves_linear = mab.m_curves.iter().all(|curve| curve.is_linear());
410 if !all_curves_linear
411 || !mab.matrix.test_equality(Matrix3d::IDENTITY)
412 || mab.bias.ne(&Vector3d::default())
413 {
414 let mut arr = Box::<[Vec<f32>; 3]>::default();
415 for (curve, dst) in mab.m_curves.iter().zip(arr.iter_mut()) {
416 *dst = curve.to_clut()?;
417 }
418 Some(arr)
419 } else {
420 None
421 }
422 } else {
423 None
424 };
425
426 let bias = mab.bias.cast();
427
428 let transform = Multidimensional3x3::<T> {
429 a_curves,
430 b_curves,
431 m_curves,
432 matrix,
433 direction,
434 options,
435 clut,
436 pcs,
437 grid_size,
438 bias,
439 _phantom: PhantomData,
440 bit_depth,
441 };
442
443 Ok(transform)
444}
445
446pub(crate) fn multi_dimensional_3x3_to_pcs<
447 T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
448>(
449 mab: &LutMultidimensionalType,
450 options: TransformOptions,
451 pcs: DataColorSpace,
452 bit_depth: usize,
453) -> Result<Box<dyn KatanaInitialStage<f32, T> + Send + Sync>, CmsError> {
454 let transform = make_multidimensional_3x3::<T>(
455 mab,
456 options,
457 pcs,
458 MultidimensionalDirection::DeviceToPcs,
459 bit_depth,
460 )?;
461 Ok(Box::new(transform))
462}
463
464pub(crate) fn multi_dimensional_3x3_to_device<
465 T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
466>(
467 mab: &LutMultidimensionalType,
468 options: TransformOptions,
469 pcs: DataColorSpace,
470 bit_depth: usize,
471) -> Result<Box<dyn KatanaFinalStage<f32, T> + Send + Sync>, CmsError>
472where
473 f32: AsPrimitive<T>,
474{
475 let transform = make_multidimensional_3x3::<T>(
476 mab,
477 options,
478 pcs,
479 MultidimensionalDirection::PcsToDevice,
480 bit_depth,
481 )?;
482 Ok(Box::new(transform))
483}