moxcms/conversions/katana/
md_pipeline.rs

1/*
2 * // Copyright (c) Radzivon Bartoshyk 6/2025. All rights reserved.
3 * //
4 * // Redistribution and use in source and binary forms, with or without modification,
5 * // are permitted provided that the following conditions are met:
6 * //
7 * // 1.  Redistributions of source code must retain the above copyright notice, this
8 * // list of conditions and the following disclaimer.
9 * //
10 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
11 * // this list of conditions and the following disclaimer in the documentation
12 * // and/or other materials provided with the distribution.
13 * //
14 * // 3.  Neither the name of the copyright holder nor the names of its
15 * // contributors may be used to endorse or promote products derived from
16 * // this software without specific prior written permission.
17 * //
18 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29use crate::conversions::katana::md_nx3::interpolate_out_function;
30use crate::conversions::katana::{KatanaFinalStage, KatanaInitialStage};
31use crate::conversions::md_lut::{MultidimensionalLut, tetra_3i_to_any_vec};
32use crate::profile::LutDataType;
33use crate::safe_math::{SafeMul, SafePowi};
34use crate::trc::lut_interp_linear_float;
35use crate::{
36    CmsError, DataColorSpace, Layout, MalformedSize, PointeeSizeExpressible, TransformOptions,
37};
38use num_traits::AsPrimitive;
39use std::array::from_fn;
40use std::marker::PhantomData;
41
/// Device → PCS stage built from a `LutDataType` table with `input_inks` inputs and
/// three outputs.
#[derive(Default)]
struct KatanaLutNx3<T> {
    /// Number of interleaved input channels consumed per pixel.
    input_inks: usize,
    /// One linearization curve per input channel, applied before the CLUT lookup.
    linearization: Vec<Vec<f32>>,
    /// Number of grid points along each populated CLUT dimension.
    grid_size: u8,
    /// Flattened CLUT samples converted to `f32`.
    clut: Vec<f32>,
    /// Curves applied to each of the three CLUT outputs.
    output: [Vec<f32>; 3],
    /// Bit depth of integer samples; used to normalize inputs when `T` is an integer type.
    bit_depth: usize,
    _phantom: PhantomData<T>,
}
52
53struct KatanaLut3xN<T> {
54    linearization: [Vec<f32>; 3],
55    clut: Vec<f32>,
56    grid_size: u8,
57    output_inks: usize,
58    output: Vec<Vec<f32>>,
59    dst_layout: Layout,
60    target_color_space: DataColorSpace,
61    _phantom: PhantomData<T>,
62    bit_depth: usize,
63}
64
65impl<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>> KatanaLutNx3<T> {
66    fn to_pcs_impl(&self, input: &[T]) -> Result<Vec<f32>, CmsError> {
67        if input.len() % self.input_inks != 0 {
68            return Err(CmsError::LaneMultipleOfChannels);
69        }
70        let norm_value = if T::FINITE {
71            1.0 / ((1u32 << self.bit_depth) - 1) as f32
72        } else {
73            1.0
74        };
75
76        let grid_sizes: [u8; 16] = from_fn(|i| {
77            if i < self.input_inks {
78                self.grid_size
79            } else {
80                0
81            }
82        });
83
84        let md_lut = MultidimensionalLut::new(grid_sizes, self.input_inks, 3);
85
86        let layout = Layout::from_inks(self.input_inks);
87
88        let mut inks = vec![0.; self.input_inks];
89
90        let mut dst = vec![0.; (input.len() / layout.channels()) * 3];
91
92        let fetcher = interpolate_out_function(layout);
93
94        for (dest, src) in dst
95            .chunks_exact_mut(3)
96            .zip(input.chunks_exact(layout.channels()))
97        {
98            for ((ink, src_ink), curve) in inks.iter_mut().zip(src).zip(self.linearization.iter()) {
99                *ink = lut_interp_linear_float(src_ink.as_() * norm_value, curve);
100            }
101
102            let clut = fetcher(&md_lut, &self.clut, &inks);
103
104            let pcs_x = lut_interp_linear_float(clut.v[0], &self.output[0]);
105            let pcs_y = lut_interp_linear_float(clut.v[1], &self.output[1]);
106            let pcs_z = lut_interp_linear_float(clut.v[2], &self.output[2]);
107
108            dest[0] = pcs_x;
109            dest[1] = pcs_y;
110            dest[2] = pcs_z;
111        }
112        Ok(dst)
113    }
114}
115
116impl<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>> KatanaInitialStage<f32, T>
117    for KatanaLutNx3<T>
118{
119    fn to_pcs(&self, input: &[T]) -> Result<Vec<f32>, CmsError> {
120        if input.len() % self.input_inks != 0 {
121            return Err(CmsError::LaneMultipleOfChannels);
122        }
123
124        self.to_pcs_impl(input)
125    }
126}
127
128impl<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>> KatanaFinalStage<f32, T>
129    for KatanaLut3xN<T>
130where
131    f32: AsPrimitive<T>,
132{
133    fn to_output(&self, src: &mut [f32], dst: &mut [T]) -> Result<(), CmsError> {
134        if src.len() % 3 != 0 {
135            return Err(CmsError::LaneMultipleOfChannels);
136        }
137
138        let grid_sizes: [u8; 16] = from_fn(|i| {
139            if i < self.output_inks {
140                self.grid_size
141            } else {
142                0
143            }
144        });
145
146        let md_lut = MultidimensionalLut::new(grid_sizes, 3, self.output_inks);
147
148        let scale_value = if T::FINITE {
149            ((1u32 << self.bit_depth) - 1) as f32
150        } else {
151            1.0
152        };
153
154        let mut working = vec![0.; self.output_inks];
155
156        for (dest, src) in dst
157            .chunks_exact_mut(self.dst_layout.channels())
158            .zip(src.chunks_exact(3))
159        {
160            let x = lut_interp_linear_float(src[0], &self.linearization[0]);
161            let y = lut_interp_linear_float(src[1], &self.linearization[1]);
162            let z = lut_interp_linear_float(src[2], &self.linearization[2]);
163
164            tetra_3i_to_any_vec(&md_lut, &self.clut, x, y, z, &mut working, self.output_inks);
165
166            for (ink, curve) in working.iter_mut().zip(self.output.iter()) {
167                *ink = lut_interp_linear_float(*ink, curve);
168            }
169
170            if T::FINITE {
171                for (dst, ink) in dest.iter_mut().zip(working.iter()) {
172                    *dst = (*ink * scale_value).round().max(0.).min(scale_value).as_();
173                }
174            } else {
175                for (dst, ink) in dest.iter_mut().zip(working.iter()) {
176                    *dst = (*ink * scale_value).as_();
177                }
178            }
179        }
180
181        if self.dst_layout == Layout::Rgba && self.target_color_space == DataColorSpace::Rgb {
182            for dst in dst.chunks_exact_mut(self.dst_layout.channels()) {
183                dst[3] = scale_value.as_();
184            }
185        }
186
187        Ok(())
188    }
189}
190
191fn katana_make_lut_nx3<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>>(
192    inks: usize,
193    lut: &LutDataType,
194    _: TransformOptions,
195    _: DataColorSpace,
196    bit_depth: usize,
197) -> Result<KatanaLutNx3<T>, CmsError> {
198    if inks != lut.num_input_channels as usize {
199        return Err(CmsError::UnsupportedProfileConnection);
200    }
201    if lut.num_output_channels != 3 {
202        return Err(CmsError::UnsupportedProfileConnection);
203    }
204    let clut_length: usize = (lut.num_clut_grid_points as usize)
205        .safe_powi(lut.num_input_channels as u32)?
206        .safe_mul(lut.num_output_channels as usize)?;
207
208    let clut_table = lut.clut_table.to_clut_f32();
209    if clut_table.len() != clut_length {
210        return Err(CmsError::MalformedClut(MalformedSize {
211            size: clut_table.len(),
212            expected: clut_length,
213        }));
214    }
215
216    let linearization_table = lut.input_table.to_clut_f32();
217
218    if linearization_table.len() < lut.num_input_table_entries as usize * inks {
219        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
220            size: linearization_table.len(),
221            expected: lut.num_input_table_entries as usize * inks,
222        }));
223    }
224
225    let linearization = (0..inks)
226        .map(|x| {
227            linearization_table[x * lut.num_input_table_entries as usize
228                ..(x + 1) * lut.num_input_table_entries as usize]
229                .to_vec()
230        })
231        .collect::<_>();
232
233    let gamma_table = lut.output_table.to_clut_f32();
234
235    if gamma_table.len() < lut.num_output_table_entries as usize * 3 {
236        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
237            size: gamma_table.len(),
238            expected: lut.num_output_table_entries as usize * 3,
239        }));
240    }
241
242    let gamma_curve0 = gamma_table[..lut.num_output_table_entries as usize].to_vec();
243    let gamma_curve1 = gamma_table
244        [lut.num_output_table_entries as usize..lut.num_output_table_entries as usize * 2]
245        .to_vec();
246    let gamma_curve2 = gamma_table
247        [lut.num_output_table_entries as usize * 2..lut.num_output_table_entries as usize * 3]
248        .to_vec();
249
250    let transform = KatanaLutNx3::<T> {
251        linearization,
252        clut: clut_table,
253        grid_size: lut.num_clut_grid_points,
254        output: [gamma_curve0, gamma_curve1, gamma_curve2],
255        input_inks: inks,
256        _phantom: PhantomData,
257        bit_depth,
258    };
259    Ok(transform)
260}
261
262fn katana_make_lut_3xn<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>>(
263    inks: usize,
264    dst_layout: Layout,
265    lut: &LutDataType,
266    _: TransformOptions,
267    target_color_space: DataColorSpace,
268    bit_depth: usize,
269) -> Result<KatanaLut3xN<T>, CmsError> {
270    if lut.num_input_channels as usize != 3 {
271        return Err(CmsError::UnsupportedProfileConnection);
272    }
273    if target_color_space == DataColorSpace::Rgb {
274        if lut.num_output_channels != 3 || lut.num_output_channels != 4 {
275            return Err(CmsError::InvalidInksCountForProfile);
276        }
277        if dst_layout != Layout::Rgb || dst_layout != Layout::Rgba {
278            return Err(CmsError::InvalidInksCountForProfile);
279        }
280    } else if lut.num_output_channels as usize != dst_layout.channels() {
281        return Err(CmsError::InvalidInksCountForProfile);
282    }
283    let clut_length: usize = (lut.num_clut_grid_points as usize)
284        .safe_powi(lut.num_input_channels as u32)?
285        .safe_mul(lut.num_output_channels as usize)?;
286
287    let clut_table = lut.clut_table.to_clut_f32();
288    if clut_table.len() != clut_length {
289        return Err(CmsError::MalformedClut(MalformedSize {
290            size: clut_table.len(),
291            expected: clut_length,
292        }));
293    }
294
295    let linearization_table = lut.input_table.to_clut_f32();
296
297    if linearization_table.len() < lut.num_input_table_entries as usize * 3 {
298        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
299            size: linearization_table.len(),
300            expected: lut.num_input_table_entries as usize * 3,
301        }));
302    }
303
304    let linear_curve0 = linearization_table[..lut.num_input_table_entries as usize].to_vec();
305    let linear_curve1 = linearization_table
306        [lut.num_input_table_entries as usize..lut.num_input_table_entries as usize * 2]
307        .to_vec();
308    let linear_curve2 = linearization_table
309        [lut.num_input_table_entries as usize * 2..lut.num_input_table_entries as usize * 3]
310        .to_vec();
311
312    let gamma_table = lut.output_table.to_clut_f32();
313
314    if gamma_table.len() < lut.num_output_table_entries as usize * inks {
315        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
316            size: gamma_table.len(),
317            expected: lut.num_output_table_entries as usize * inks,
318        }));
319    }
320
321    let gamma = (0..inks)
322        .map(|x| {
323            gamma_table[x * lut.num_output_table_entries as usize
324                ..(x + 1) * lut.num_output_table_entries as usize]
325                .to_vec()
326        })
327        .collect::<_>();
328
329    let transform = KatanaLut3xN::<T> {
330        linearization: [linear_curve0, linear_curve1, linear_curve2],
331        clut: clut_table,
332        grid_size: lut.num_clut_grid_points,
333        output: gamma,
334        output_inks: inks,
335        _phantom: PhantomData,
336        target_color_space,
337        dst_layout,
338        bit_depth,
339    };
340    Ok(transform)
341}
342
343pub(crate) fn katana_input_make_lut_nx3<
344    T: Copy + PointeeSizeExpressible + AsPrimitive<f32> + Send + Sync,
345>(
346    src_layout: Layout,
347    inks: usize,
348    lut: &LutDataType,
349    options: TransformOptions,
350    pcs: DataColorSpace,
351    bit_depth: usize,
352) -> Result<Box<dyn KatanaInitialStage<f32, T> + Send + Sync>, CmsError> {
353    if pcs == DataColorSpace::Rgb {
354        if lut.num_input_channels != 3 {
355            return Err(CmsError::InvalidAtoBLut);
356        }
357        if src_layout != Layout::Rgba && src_layout != Layout::Rgb {
358            return Err(CmsError::InvalidInksCountForProfile);
359        }
360    } else if lut.num_input_channels != src_layout.channels() as u8 {
361        return Err(CmsError::InvalidInksCountForProfile);
362    }
363    let z0 = katana_make_lut_nx3::<T>(inks, lut, options, pcs, bit_depth)?;
364    Ok(Box::new(z0))
365}
366
367pub(crate) fn katana_output_make_lut_3xn<
368    T: Copy + PointeeSizeExpressible + AsPrimitive<f32> + Send + Sync,
369>(
370    dst_layout: Layout,
371    lut: &LutDataType,
372    options: TransformOptions,
373    target_color_space: DataColorSpace,
374    bit_depth: usize,
375) -> Result<Box<dyn KatanaFinalStage<f32, T> + Send + Sync>, CmsError>
376where
377    f32: AsPrimitive<T>,
378{
379    let real_inks = if target_color_space == DataColorSpace::Rgb {
380        3
381    } else {
382        dst_layout.channels()
383    };
384    let z0 = katana_make_lut_3xn::<T>(
385        real_inks,
386        dst_layout,
387        lut,
388        options,
389        target_color_space,
390        bit_depth,
391    )?;
392    Ok(Box::new(z0))
393}