moxcms/conversions/
lut3x3.rs

1/*
2 * // Copyright (c) Radzivon Bartoshyk 3/2025. All rights reserved.
3 * //
4 * // Redistribution and use in source and binary forms, with or without modification,
5 * // are permitted provided that the following conditions are met:
6 * //
7 * // 1.  Redistributions of source code must retain the above copyright notice, this
8 * // list of conditions and the following disclaimer.
9 * //
10 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
11 * // this list of conditions and the following disclaimer in the documentation
12 * // and/or other materials provided with the distribution.
13 * //
14 * // 3.  Neither the name of the copyright holder nor the names of its
15 * // contributors may be used to endorse or promote products derived from
16 * // this software without specific prior written permission.
17 * //
18 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29use crate::conversions::katana::{KatanaFinalStage, KatanaInitialStage};
30use crate::err::MalformedSize;
31use crate::profile::LutDataType;
32use crate::safe_math::{SafeMul, SafePowi};
33use crate::trc::lut_interp_linear_float;
34use crate::{
35    CmsError, Cube, DataColorSpace, InterpolationMethod, PointeeSizeExpressible, Stage,
36    TransformOptions, Vector3f,
37};
38use num_traits::AsPrimitive;
39
40#[derive(Default)]
41struct Lut3x3 {
42    input: [Vec<f32>; 3],
43    clut: Vec<f32>,
44    grid_size: u8,
45    gamma: [Vec<f32>; 3],
46    interpolation_method: InterpolationMethod,
47    pcs: DataColorSpace,
48}
49
50#[derive(Default)]
51struct KatanaLut3x3<T: Copy + Default> {
52    input: [Vec<f32>; 3],
53    clut: Vec<f32>,
54    grid_size: u8,
55    gamma: [Vec<f32>; 3],
56    interpolation_method: InterpolationMethod,
57    pcs: DataColorSpace,
58    _phantom: std::marker::PhantomData<T>,
59    bit_depth: usize,
60}
61
62fn make_lut_3x3(
63    lut: &LutDataType,
64    options: TransformOptions,
65    pcs: DataColorSpace,
66) -> Result<Lut3x3, CmsError> {
67    let clut_length: usize = (lut.num_clut_grid_points as usize)
68        .safe_powi(lut.num_input_channels as u32)?
69        .safe_mul(lut.num_output_channels as usize)?;
70
71    let lin_table = lut.input_table.to_clut_f32();
72
73    if lin_table.len() < lut.num_input_table_entries as usize * 3 {
74        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
75            size: lin_table.len(),
76            expected: lut.num_input_table_entries as usize * 3,
77        }));
78    }
79
80    let lin_curve0 = lin_table[..lut.num_input_table_entries as usize].to_vec();
81    let lin_curve1 = lin_table
82        [lut.num_input_table_entries as usize..lut.num_input_table_entries as usize * 2]
83        .to_vec();
84    let lin_curve2 = lin_table
85        [lut.num_input_table_entries as usize * 2..lut.num_input_table_entries as usize * 3]
86        .to_vec();
87
88    let clut_table = lut.clut_table.to_clut_f32();
89    if clut_table.len() != clut_length {
90        return Err(CmsError::MalformedClut(MalformedSize {
91            size: clut_table.len(),
92            expected: clut_length,
93        }));
94    }
95
96    let gamma_curves = lut.output_table.to_clut_f32();
97
98    if gamma_curves.len() < lut.num_output_table_entries as usize * 3 {
99        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
100            size: gamma_curves.len(),
101            expected: lut.num_output_table_entries as usize * 3,
102        }));
103    }
104
105    let gamma_curve0 = gamma_curves[..lut.num_output_table_entries as usize].to_vec();
106    let gamma_curve1 = gamma_curves
107        [lut.num_output_table_entries as usize..lut.num_output_table_entries as usize * 2]
108        .to_vec();
109    let gamma_curve2 = gamma_curves
110        [lut.num_output_table_entries as usize * 2..lut.num_output_table_entries as usize * 3]
111        .to_vec();
112
113    let transform = Lut3x3 {
114        input: [lin_curve0, lin_curve1, lin_curve2],
115        gamma: [gamma_curve0, gamma_curve1, gamma_curve2],
116        interpolation_method: options.interpolation_method,
117        clut: clut_table,
118        grid_size: lut.num_clut_grid_points,
119        pcs,
120    };
121
122    Ok(transform)
123}
124
125fn stage_lut_3x3(
126    lut: &LutDataType,
127    options: TransformOptions,
128    pcs: DataColorSpace,
129) -> Result<Box<dyn Stage>, CmsError> {
130    let lut = make_lut_3x3(lut, options, pcs)?;
131
132    let transform = Lut3x3 {
133        input: lut.input,
134        gamma: lut.gamma,
135        interpolation_method: lut.interpolation_method,
136        clut: lut.clut,
137        grid_size: lut.grid_size,
138        pcs: lut.pcs,
139    };
140
141    Ok(Box::new(transform))
142}
143
144pub(crate) fn katana_input_stage_lut_3x3<
145    T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
146>(
147    lut: &LutDataType,
148    options: TransformOptions,
149    pcs: DataColorSpace,
150    bit_depth: usize,
151) -> Result<Box<dyn KatanaInitialStage<f32, T> + Send + Sync>, CmsError>
152where
153    f32: AsPrimitive<T>,
154{
155    let lut = make_lut_3x3(lut, options, pcs)?;
156
157    let transform = KatanaLut3x3::<T> {
158        input: lut.input,
159        gamma: lut.gamma,
160        interpolation_method: lut.interpolation_method,
161        clut: lut.clut,
162        grid_size: lut.grid_size,
163        pcs: lut.pcs,
164        _phantom: std::marker::PhantomData,
165        bit_depth,
166    };
167
168    Ok(Box::new(transform))
169}
170
171pub(crate) fn katana_output_stage_lut_3x3<
172    T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
173>(
174    lut: &LutDataType,
175    options: TransformOptions,
176    pcs: DataColorSpace,
177    bit_depth: usize,
178) -> Result<Box<dyn KatanaFinalStage<f32, T> + Send + Sync>, CmsError>
179where
180    f32: AsPrimitive<T>,
181{
182    let lut = make_lut_3x3(lut, options, pcs)?;
183
184    let transform = KatanaLut3x3::<T> {
185        input: lut.input,
186        gamma: lut.gamma,
187        interpolation_method: lut.interpolation_method,
188        clut: lut.clut,
189        grid_size: lut.grid_size,
190        pcs: lut.pcs,
191        _phantom: std::marker::PhantomData,
192        bit_depth,
193    };
194
195    Ok(Box::new(transform))
196}
197
198impl Lut3x3 {
199    fn transform_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
200        &self,
201        src: &[f32],
202        dst: &mut [f32],
203        fetch: Fetch,
204    ) -> Result<(), CmsError> {
205        let linearization_0 = &self.input[0];
206        let linearization_1 = &self.input[1];
207        let linearization_2 = &self.input[2];
208        for (dest, src) in dst.chunks_exact_mut(3).zip(src.chunks_exact(3)) {
209            debug_assert!(self.grid_size as i32 >= 1);
210            let linear_x = lut_interp_linear_float(src[0], linearization_0);
211            let linear_y = lut_interp_linear_float(src[1], linearization_1);
212            let linear_z = lut_interp_linear_float(src[2], linearization_2);
213
214            let clut = fetch(linear_x, linear_y, linear_z);
215
216            let pcs_x = lut_interp_linear_float(clut.v[0], &self.gamma[0]);
217            let pcs_y = lut_interp_linear_float(clut.v[1], &self.gamma[1]);
218            let pcs_z = lut_interp_linear_float(clut.v[2], &self.gamma[2]);
219            dest[0] = pcs_x;
220            dest[1] = pcs_y;
221            dest[2] = pcs_z;
222        }
223        Ok(())
224    }
225}
226
227impl Stage for Lut3x3 {
228    fn transform(&self, src: &[f32], dst: &mut [f32]) -> Result<(), CmsError> {
229        let l_tbl = Cube::new(&self.clut, self.grid_size as usize);
230
231        // If PCS is LAB then linear interpolation should be used
232        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
233            return self.transform_impl(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z));
234        }
235
236        match self.interpolation_method {
237            #[cfg(feature = "options")]
238            InterpolationMethod::Tetrahedral => {
239                self.transform_impl(src, dst, |x, y, z| l_tbl.tetra_vec3(x, y, z))?;
240            }
241            #[cfg(feature = "options")]
242            InterpolationMethod::Pyramid => {
243                self.transform_impl(src, dst, |x, y, z| l_tbl.pyramid_vec3(x, y, z))?;
244            }
245            #[cfg(feature = "options")]
246            InterpolationMethod::Prism => {
247                self.transform_impl(src, dst, |x, y, z| l_tbl.prism_vec3(x, y, z))?;
248            }
249            InterpolationMethod::Linear => {
250                self.transform_impl(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z))?;
251            }
252        }
253        Ok(())
254    }
255}
256
257impl<T: Copy + Default + PointeeSizeExpressible + AsPrimitive<f32>> KatanaLut3x3<T>
258where
259    f32: AsPrimitive<T>,
260{
261    fn to_pcs_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
262        &self,
263        input: &[T],
264        fetch: Fetch,
265    ) -> Result<Vec<f32>, CmsError> {
266        if input.len() % 3 != 0 {
267            return Err(CmsError::LaneMultipleOfChannels);
268        }
269        let normalizing_value = if T::FINITE {
270            1.0 / ((1u32 << self.bit_depth) - 1) as f32
271        } else {
272            1.0
273        };
274        let mut dst = vec![0.; input.len()];
275        let linearization_0 = &self.input[0];
276        let linearization_1 = &self.input[1];
277        let linearization_2 = &self.input[2];
278        for (dest, src) in dst.chunks_exact_mut(3).zip(input.chunks_exact(3)) {
279            let linear_x =
280                lut_interp_linear_float(src[0].as_() * normalizing_value, linearization_0);
281            let linear_y =
282                lut_interp_linear_float(src[1].as_() * normalizing_value, linearization_1);
283            let linear_z =
284                lut_interp_linear_float(src[2].as_() * normalizing_value, linearization_2);
285
286            let clut = fetch(linear_x, linear_y, linear_z);
287
288            let pcs_x = lut_interp_linear_float(clut.v[0], &self.gamma[0]);
289            let pcs_y = lut_interp_linear_float(clut.v[1], &self.gamma[1]);
290            let pcs_z = lut_interp_linear_float(clut.v[2], &self.gamma[2]);
291            dest[0] = pcs_x;
292            dest[1] = pcs_y;
293            dest[2] = pcs_z;
294        }
295        Ok(dst)
296    }
297
298    fn to_output<Fetch: Fn(f32, f32, f32) -> Vector3f>(
299        &self,
300        src: &[f32],
301        dst: &mut [T],
302        fetch: Fetch,
303    ) -> Result<(), CmsError> {
304        if src.len() % 3 != 0 {
305            return Err(CmsError::LaneMultipleOfChannels);
306        }
307        if dst.len() % 3 != 0 {
308            return Err(CmsError::LaneMultipleOfChannels);
309        }
310        if dst.len() != src.len() {
311            return Err(CmsError::LaneSizeMismatch);
312        }
313        let norm_value = if T::FINITE {
314            ((1u32 << self.bit_depth) - 1) as f32
315        } else {
316            1.0
317        };
318
319        let linearization_0 = &self.input[0];
320        let linearization_1 = &self.input[1];
321        let linearization_2 = &self.input[2];
322        for (dest, src) in dst.chunks_exact_mut(3).zip(src.chunks_exact(3)) {
323            let linear_x = lut_interp_linear_float(src[0], linearization_0);
324            let linear_y = lut_interp_linear_float(src[1], linearization_1);
325            let linear_z = lut_interp_linear_float(src[2], linearization_2);
326
327            let clut = fetch(linear_x, linear_y, linear_z);
328
329            let pcs_x = lut_interp_linear_float(clut.v[0], &self.gamma[0]);
330            let pcs_y = lut_interp_linear_float(clut.v[1], &self.gamma[1]);
331            let pcs_z = lut_interp_linear_float(clut.v[2], &self.gamma[2]);
332
333            if T::FINITE {
334                dest[0] = (pcs_x * norm_value).round().max(0.0).min(norm_value).as_();
335                dest[1] = (pcs_y * norm_value).round().max(0.0).min(norm_value).as_();
336                dest[2] = (pcs_z * norm_value).round().max(0.0).min(norm_value).as_();
337            } else {
338                dest[0] = pcs_x.as_();
339                dest[1] = pcs_y.as_();
340                dest[2] = pcs_z.as_();
341            }
342        }
343        Ok(())
344    }
345}
346
347impl<T: Copy + Default + PointeeSizeExpressible + AsPrimitive<f32>> KatanaInitialStage<f32, T>
348    for KatanaLut3x3<T>
349where
350    f32: AsPrimitive<T>,
351{
352    fn to_pcs(&self, input: &[T]) -> Result<Vec<f32>, CmsError> {
353        let l_tbl = Cube::new(&self.clut, self.grid_size as usize);
354
355        // If PCS is LAB then linear interpolation should be used
356        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
357            return self.to_pcs_impl(input, |x, y, z| l_tbl.trilinear_vec3(x, y, z));
358        }
359
360        match self.interpolation_method {
361            #[cfg(feature = "options")]
362            InterpolationMethod::Tetrahedral => {
363                self.to_pcs_impl(input, |x, y, z| l_tbl.tetra_vec3(x, y, z))
364            }
365            #[cfg(feature = "options")]
366            InterpolationMethod::Pyramid => {
367                self.to_pcs_impl(input, |x, y, z| l_tbl.pyramid_vec3(x, y, z))
368            }
369            #[cfg(feature = "options")]
370            InterpolationMethod::Prism => {
371                self.to_pcs_impl(input, |x, y, z| l_tbl.prism_vec3(x, y, z))
372            }
373            InterpolationMethod::Linear => {
374                self.to_pcs_impl(input, |x, y, z| l_tbl.trilinear_vec3(x, y, z))
375            }
376        }
377    }
378}
379
380impl<T: Copy + Default + PointeeSizeExpressible + AsPrimitive<f32>> KatanaFinalStage<f32, T>
381    for KatanaLut3x3<T>
382where
383    f32: AsPrimitive<T>,
384{
385    fn to_output(&self, src: &mut [f32], dst: &mut [T]) -> Result<(), CmsError> {
386        let l_tbl = Cube::new(&self.clut, self.grid_size as usize);
387
388        // If PCS is LAB then linear interpolation should be used
389        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
390            return self.to_output(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z));
391        }
392
393        match self.interpolation_method {
394            #[cfg(feature = "options")]
395            InterpolationMethod::Tetrahedral => {
396                self.to_output(src, dst, |x, y, z| l_tbl.tetra_vec3(x, y, z))
397            }
398            #[cfg(feature = "options")]
399            InterpolationMethod::Pyramid => {
400                self.to_output(src, dst, |x, y, z| l_tbl.pyramid_vec3(x, y, z))
401            }
402            #[cfg(feature = "options")]
403            InterpolationMethod::Prism => {
404                self.to_output(src, dst, |x, y, z| l_tbl.prism_vec3(x, y, z))
405            }
406            InterpolationMethod::Linear => {
407                self.to_output(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z))
408            }
409        }
410    }
411}
412
413pub(crate) fn create_lut3x3(
414    lut: &LutDataType,
415    src: &[f32],
416    options: TransformOptions,
417    pcs: DataColorSpace,
418) -> Result<Vec<f32>, CmsError> {
419    if lut.num_input_channels != 3 || lut.num_output_channels != 3 {
420        return Err(CmsError::UnsupportedProfileConnection);
421    }
422
423    let mut dest = vec![0.; src.len()];
424
425    let lut_stage = stage_lut_3x3(lut, options, pcs)?;
426    lut_stage.transform(src, &mut dest)?;
427    Ok(dest)
428}