moxcms/conversions/katana/md3x3.rs

1/*
2 * // Copyright (c) Radzivon Bartoshyk 6/2025. All rights reserved.
3 * //
4 * // Redistribution and use in source and binary forms, with or without modification,
5 * // are permitted provided that the following conditions are met:
6 * //
7 * // 1.  Redistributions of source code must retain the above copyright notice, this
8 * // list of conditions and the following disclaimer.
9 * //
10 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
11 * // this list of conditions and the following disclaimer in the documentation
12 * // and/or other materials provided with the distribution.
13 * //
14 * // 3.  Neither the name of the copyright holder nor the names of its
15 * // contributors may be used to endorse or promote products derived from
16 * // this software without specific prior written permission.
17 * //
18 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29use crate::conversions::katana::{KatanaFinalStage, KatanaInitialStage};
30use crate::mlaf::mlaf;
31use crate::safe_math::SafeMul;
32use crate::trc::lut_interp_linear_float;
33use crate::{
34    CmsError, Cube, DataColorSpace, InterpolationMethod, LutMultidimensionalType, MalformedSize,
35    Matrix3d, Matrix3f, PointeeSizeExpressible, TransformOptions, Vector3d, Vector3f,
36};
37use num_traits::AsPrimitive;
38use std::marker::PhantomData;
39
/// Direction in which a 3x3 multidimensional LUT pipeline is evaluated.
///
/// Variant order matters: the derived `Ord` makes `DeviceToPcs < PcsToDevice`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub(crate) enum MultidimensionalDirection {
    /// Device samples in, PCS values out; required by the `to_pcs` stage.
    DeviceToPcs,
    /// PCS values in, device samples out; required by the `to_output` stage.
    PcsToDevice,
}
45
46struct Multidimensional3x3<
47    T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
48> {
49    a_curves: Option<Box<[Vec<f32>; 3]>>,
50    m_curves: Option<Box<[Vec<f32>; 3]>>,
51    b_curves: Option<Box<[Vec<f32>; 3]>>,
52    clut: Option<Vec<f32>>,
53    matrix: Matrix3f,
54    bias: Vector3f,
55    direction: MultidimensionalDirection,
56    options: TransformOptions,
57    pcs: DataColorSpace,
58    grid_size: [u8; 3],
59    _phantom: PhantomData<T>,
60    bit_depth: usize,
61}
62
63impl<T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync>
64    Multidimensional3x3<T>
65{
66    fn execute_matrix_stage(&self, dst: &mut [f32]) {
67        let m = self.matrix;
68        let b = self.bias;
69
70        if !m.test_equality(Matrix3f::IDENTITY) || !b.eq(&Vector3f::default()) {
71            for dst in dst.chunks_exact_mut(3) {
72                let x = dst[0];
73                let y = dst[1];
74                let z = dst[2];
75                dst[0] = mlaf(mlaf(mlaf(b.v[0], x, m.v[0][0]), y, m.v[0][1]), z, m.v[0][2]);
76                dst[1] = mlaf(mlaf(mlaf(b.v[1], x, m.v[1][0]), y, m.v[1][1]), z, m.v[1][2]);
77                dst[2] = mlaf(mlaf(mlaf(b.v[2], x, m.v[2][0]), y, m.v[2][1]), z, m.v[2][2]);
78            }
79        }
80    }
81
82    fn execute_simple_curves(&self, dst: &mut [f32], curves: &[Vec<f32>; 3]) {
83        let curve0 = &curves[0];
84        let curve1 = &curves[1];
85        let curve2 = &curves[2];
86
87        for dst in dst.chunks_exact_mut(3) {
88            let a0 = dst[0];
89            let a1 = dst[1];
90            let a2 = dst[2];
91            let b0 = lut_interp_linear_float(a0, curve0);
92            let b1 = lut_interp_linear_float(a1, curve1);
93            let b2 = lut_interp_linear_float(a2, curve2);
94            dst[0] = b0;
95            dst[1] = b1;
96            dst[2] = b2;
97        }
98    }
99
100    fn to_pcs_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
101        &self,
102        input: &[T],
103        dst: &mut [f32],
104        fetch: Fetch,
105    ) -> Result<(), CmsError> {
106        let norm_value = if T::FINITE {
107            1.0 / ((1u32 << self.bit_depth) - 1) as f32
108        } else {
109            1.0
110        };
111        assert_eq!(
112            self.direction,
113            MultidimensionalDirection::DeviceToPcs,
114            "PCS to device cannot be used on `to pcs` stage"
115        );
116
117        // A -> B
118        // OR B - A A - curves stage
119
120        if let (Some(a_curves), Some(clut)) = (self.a_curves.as_ref(), self.clut.as_ref()) {
121            if !clut.is_empty() {
122                let curve0 = &a_curves[0];
123                let curve1 = &a_curves[1];
124                let curve2 = &a_curves[2];
125                for (src, dst) in input.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
126                    let b0 = lut_interp_linear_float(src[0].as_() * norm_value, curve0);
127                    let b1 = lut_interp_linear_float(src[1].as_() * norm_value, curve1);
128                    let b2 = lut_interp_linear_float(src[2].as_() * norm_value, curve2);
129                    let interpolated = fetch(b0, b1, b2);
130                    dst[0] = interpolated.v[0];
131                    dst[1] = interpolated.v[1];
132                    dst[2] = interpolated.v[2];
133                }
134            } else {
135                for (src, dst) in input.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
136                    dst[0] = src[0].as_() * norm_value;
137                    dst[1] = src[1].as_() * norm_value;
138                    dst[2] = src[2].as_() * norm_value;
139                }
140            }
141        } else {
142            for (src, dst) in input.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
143                dst[0] = src[0].as_() * norm_value;
144                dst[1] = src[1].as_() * norm_value;
145                dst[2] = src[2].as_() * norm_value;
146            }
147        }
148
149        // Matrix stage
150
151        if let Some(m_curves) = self.m_curves.as_ref() {
152            self.execute_simple_curves(dst, m_curves);
153            self.execute_matrix_stage(dst);
154        }
155
156        // B-curves is mandatory
157        if let Some(b_curves) = &self.b_curves.as_ref() {
158            self.execute_simple_curves(dst, b_curves);
159        }
160
161        Ok(())
162    }
163}
164
165impl<T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync>
166    KatanaInitialStage<f32, T> for Multidimensional3x3<T>
167{
168    fn to_pcs(&self, input: &[T]) -> Result<Vec<f32>, CmsError> {
169        if input.len() % 3 != 0 {
170            return Err(CmsError::LaneMultipleOfChannels);
171        }
172        let fixed_new_clut = Vec::new();
173        let new_clut = self.clut.as_ref().unwrap_or(&fixed_new_clut);
174        let lut = Cube::new_cube(new_clut, self.grid_size);
175
176        let mut new_dst = vec![0f32; input.len()];
177
178        // If PCS is LAB then linear interpolation should be used
179        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
180            self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.trilinear_vec3(x, y, z))?;
181            return Ok(new_dst);
182        }
183
184        match self.options.interpolation_method {
185            #[cfg(feature = "options")]
186            InterpolationMethod::Tetrahedral => {
187                self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.tetra_vec3(x, y, z))?;
188            }
189            #[cfg(feature = "options")]
190            InterpolationMethod::Pyramid => {
191                self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.pyramid_vec3(x, y, z))?;
192            }
193            #[cfg(feature = "options")]
194            InterpolationMethod::Prism => {
195                self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.prism_vec3(x, y, z))?;
196            }
197            InterpolationMethod::Linear => {
198                self.to_pcs_impl(input, &mut new_dst, |x, y, z| lut.trilinear_vec3(x, y, z))?;
199            }
200        }
201        Ok(new_dst)
202    }
203}
204
205impl<T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync>
206    Multidimensional3x3<T>
207where
208    f32: AsPrimitive<T>,
209{
210    fn to_output_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
211        &self,
212        src: &mut [f32],
213        dst: &mut [T],
214        fetch: Fetch,
215    ) -> Result<(), CmsError> {
216        let norm_value = if T::FINITE {
217            ((1u32 << self.bit_depth) - 1) as f32
218        } else {
219            1.0
220        };
221        assert_eq!(
222            self.direction,
223            MultidimensionalDirection::PcsToDevice,
224            "Device to PCS cannot be used on `to output` stage"
225        );
226
227        if let Some(b_curves) = &self.b_curves.as_ref() {
228            self.execute_simple_curves(src, b_curves);
229        }
230
231        // Matrix stage
232
233        if let Some(m_curves) = self.m_curves.as_ref() {
234            self.execute_matrix_stage(src);
235            self.execute_simple_curves(src, m_curves);
236        }
237
238        if let (Some(a_curves), Some(clut)) = (self.a_curves.as_ref(), self.clut.as_ref()) {
239            if !clut.is_empty() {
240                let curve0 = &a_curves[0];
241                let curve1 = &a_curves[1];
242                let curve2 = &a_curves[2];
243                for (src, dst) in src.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
244                    let b0 = lut_interp_linear_float(src[0], curve0);
245                    let b1 = lut_interp_linear_float(src[1], curve1);
246                    let b2 = lut_interp_linear_float(src[2], curve2);
247                    let interpolated = fetch(b0, b1, b2);
248                    if T::FINITE {
249                        dst[0] = (interpolated.v[0] * norm_value)
250                            .round()
251                            .max(0.0)
252                            .min(norm_value)
253                            .as_();
254                        dst[1] = (interpolated.v[1] * norm_value)
255                            .round()
256                            .max(0.0)
257                            .min(norm_value)
258                            .as_();
259                        dst[2] = (interpolated.v[2] * norm_value)
260                            .round()
261                            .max(0.0)
262                            .min(norm_value)
263                            .as_();
264                    } else {
265                        dst[0] = interpolated.v[0].as_();
266                        dst[1] = interpolated.v[1].as_();
267                        dst[2] = interpolated.v[2].as_();
268                    }
269                }
270            } else {
271                for (src, dst) in src.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
272                    if T::FINITE {
273                        dst[0] = (src[0] * norm_value).round().max(0.0).min(norm_value).as_();
274                        dst[1] = (src[1] * norm_value).round().max(0.0).min(norm_value).as_();
275                        dst[2] = (src[2] * norm_value).round().max(0.0).min(norm_value).as_();
276                    } else {
277                        dst[0] = src[0].as_();
278                        dst[1] = src[1].as_();
279                        dst[2] = src[2].as_();
280                    }
281                }
282            }
283        } else {
284            for (src, dst) in src.chunks_exact(3).zip(dst.chunks_exact_mut(3)) {
285                if T::FINITE {
286                    dst[0] = (src[0] * norm_value).round().max(0.0).min(norm_value).as_();
287                    dst[1] = (src[1] * norm_value).round().max(0.0).min(norm_value).as_();
288                    dst[2] = (src[2] * norm_value).round().max(0.0).min(norm_value).as_();
289                } else {
290                    dst[0] = src[0].as_();
291                    dst[1] = src[1].as_();
292                    dst[2] = src[2].as_();
293                }
294            }
295        }
296
297        Ok(())
298    }
299}
300
301impl<T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync>
302    KatanaFinalStage<f32, T> for Multidimensional3x3<T>
303where
304    f32: AsPrimitive<T>,
305{
306    fn to_output(&self, src: &mut [f32], dst: &mut [T]) -> Result<(), CmsError> {
307        if src.len() % 3 != 0 {
308            return Err(CmsError::LaneMultipleOfChannels);
309        }
310        if dst.len() % 3 != 0 {
311            return Err(CmsError::LaneMultipleOfChannels);
312        }
313        if src.len() != dst.len() {
314            return Err(CmsError::LaneSizeMismatch);
315        }
316        let fixed_new_clut = Vec::new();
317        let new_clut = self.clut.as_ref().unwrap_or(&fixed_new_clut);
318        let lut = Cube::new_cube(new_clut, self.grid_size);
319
320        // If PCS is LAB then linear interpolation should be used
321        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
322            return self.to_output_impl(src, dst, |x, y, z| lut.trilinear_vec3(x, y, z));
323        }
324
325        match self.options.interpolation_method {
326            #[cfg(feature = "options")]
327            InterpolationMethod::Tetrahedral => {
328                self.to_output_impl(src, dst, |x, y, z| lut.tetra_vec3(x, y, z))?;
329            }
330            #[cfg(feature = "options")]
331            InterpolationMethod::Pyramid => {
332                self.to_output_impl(src, dst, |x, y, z| lut.pyramid_vec3(x, y, z))?;
333            }
334            #[cfg(feature = "options")]
335            InterpolationMethod::Prism => {
336                self.to_output_impl(src, dst, |x, y, z| lut.prism_vec3(x, y, z))?;
337            }
338            InterpolationMethod::Linear => {
339                self.to_output_impl(src, dst, |x, y, z| lut.trilinear_vec3(x, y, z))?;
340            }
341        }
342        Ok(())
343    }
344}
345
346fn make_multidimensional_3x3<
347    T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
348>(
349    mab: &LutMultidimensionalType,
350    options: TransformOptions,
351    pcs: DataColorSpace,
352    direction: MultidimensionalDirection,
353    bit_depth: usize,
354) -> Result<Multidimensional3x3<T>, CmsError> {
355    if mab.num_input_channels != 3 && mab.num_output_channels != 3 {
356        return Err(CmsError::UnsupportedProfileConnection);
357    }
358    if mab.b_curves.is_empty() || mab.b_curves.len() != 3 {
359        return Err(CmsError::InvalidAtoBLut);
360    }
361
362    let grid_size = [mab.grid_points[0], mab.grid_points[1], mab.grid_points[2]];
363
364    let clut: Option<Vec<f32>> = if mab.a_curves.len() == 3 && mab.clut.is_some() {
365        let clut = mab.clut.as_ref().map(|x| x.to_clut_f32()).unwrap();
366        let lut_grid = (mab.grid_points[0] as usize)
367            .safe_mul(mab.grid_points[1] as usize)?
368            .safe_mul(mab.grid_points[2] as usize)?
369            .safe_mul(mab.num_output_channels as usize)?;
370        if clut.len() != lut_grid {
371            return Err(CmsError::MalformedCurveLutTable(MalformedSize {
372                size: clut.len(),
373                expected: lut_grid,
374            }));
375        }
376        Some(clut)
377    } else {
378        None
379    };
380
381    let a_curves: Option<Box<[Vec<f32>; 3]>> = if mab.a_curves.len() == 3 && mab.clut.is_some() {
382        let mut arr = Box::<[Vec<f32>; 3]>::default();
383        for (a_curve, dst) in mab.a_curves.iter().zip(arr.iter_mut()) {
384            *dst = a_curve.to_clut()?;
385        }
386        Some(arr)
387    } else {
388        None
389    };
390
391    let b_curves: Option<Box<[Vec<f32>; 3]>> = if mab.b_curves.len() == 3 {
392        let mut arr = Box::<[Vec<f32>; 3]>::default();
393        let all_curves_linear = mab.b_curves.iter().all(|curve| curve.is_linear());
394        if all_curves_linear {
395            None
396        } else {
397            for (c_curve, dst) in mab.b_curves.iter().zip(arr.iter_mut()) {
398                *dst = c_curve.to_clut()?;
399            }
400            Some(arr)
401        }
402    } else {
403        return Err(CmsError::InvalidAtoBLut);
404    };
405
406    let matrix = mab.matrix.to_f32();
407
408    let m_curves: Option<Box<[Vec<f32>; 3]>> = if mab.m_curves.len() == 3 {
409        let all_curves_linear = mab.m_curves.iter().all(|curve| curve.is_linear());
410        if !all_curves_linear
411            || !mab.matrix.test_equality(Matrix3d::IDENTITY)
412            || mab.bias.ne(&Vector3d::default())
413        {
414            let mut arr = Box::<[Vec<f32>; 3]>::default();
415            for (curve, dst) in mab.m_curves.iter().zip(arr.iter_mut()) {
416                *dst = curve.to_clut()?;
417            }
418            Some(arr)
419        } else {
420            None
421        }
422    } else {
423        None
424    };
425
426    let bias = mab.bias.cast();
427
428    let transform = Multidimensional3x3::<T> {
429        a_curves,
430        b_curves,
431        m_curves,
432        matrix,
433        direction,
434        options,
435        clut,
436        pcs,
437        grid_size,
438        bias,
439        _phantom: PhantomData,
440        bit_depth,
441    };
442
443    Ok(transform)
444}
445
446pub(crate) fn multi_dimensional_3x3_to_pcs<
447    T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
448>(
449    mab: &LutMultidimensionalType,
450    options: TransformOptions,
451    pcs: DataColorSpace,
452    bit_depth: usize,
453) -> Result<Box<dyn KatanaInitialStage<f32, T> + Send + Sync>, CmsError> {
454    let transform = make_multidimensional_3x3::<T>(
455        mab,
456        options,
457        pcs,
458        MultidimensionalDirection::DeviceToPcs,
459        bit_depth,
460    )?;
461    Ok(Box::new(transform))
462}
463
464pub(crate) fn multi_dimensional_3x3_to_device<
465    T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
466>(
467    mab: &LutMultidimensionalType,
468    options: TransformOptions,
469    pcs: DataColorSpace,
470    bit_depth: usize,
471) -> Result<Box<dyn KatanaFinalStage<f32, T> + Send + Sync>, CmsError>
472where
473    f32: AsPrimitive<T>,
474{
475    let transform = make_multidimensional_3x3::<T>(
476        mab,
477        options,
478        pcs,
479        MultidimensionalDirection::PcsToDevice,
480        bit_depth,
481    )?;
482    Ok(Box::new(transform))
483}