1use crate::conversions::katana::{KatanaFinalStage, KatanaInitialStage};
30use crate::err::MalformedSize;
31use crate::profile::LutDataType;
32use crate::safe_math::{SafeMul, SafePowi};
33use crate::trc::lut_interp_linear_float;
34use crate::{
35 CmsError, Cube, DataColorSpace, InterpolationMethod, PointeeSizeExpressible, Stage,
36 TransformOptions, Vector3f,
37};
38use num_traits::AsPrimitive;
39
40#[derive(Default)]
41struct Lut3x3 {
42 input: [Vec<f32>; 3],
43 clut: Vec<f32>,
44 grid_size: u8,
45 gamma: [Vec<f32>; 3],
46 interpolation_method: InterpolationMethod,
47 pcs: DataColorSpace,
48}
49
50#[derive(Default)]
51struct KatanaLut3x3<T: Copy + Default> {
52 input: [Vec<f32>; 3],
53 clut: Vec<f32>,
54 grid_size: u8,
55 gamma: [Vec<f32>; 3],
56 interpolation_method: InterpolationMethod,
57 pcs: DataColorSpace,
58 _phantom: std::marker::PhantomData<T>,
59 bit_depth: usize,
60}
61
62fn make_lut_3x3(
63 lut: &LutDataType,
64 options: TransformOptions,
65 pcs: DataColorSpace,
66) -> Result<Lut3x3, CmsError> {
67 let clut_length: usize = (lut.num_clut_grid_points as usize)
68 .safe_powi(lut.num_input_channels as u32)?
69 .safe_mul(lut.num_output_channels as usize)?;
70
71 let lin_table = lut.input_table.to_clut_f32();
72
73 if lin_table.len() < lut.num_input_table_entries as usize * 3 {
74 return Err(CmsError::MalformedCurveLutTable(MalformedSize {
75 size: lin_table.len(),
76 expected: lut.num_input_table_entries as usize * 3,
77 }));
78 }
79
80 let lin_curve0 = lin_table[..lut.num_input_table_entries as usize].to_vec();
81 let lin_curve1 = lin_table
82 [lut.num_input_table_entries as usize..lut.num_input_table_entries as usize * 2]
83 .to_vec();
84 let lin_curve2 = lin_table
85 [lut.num_input_table_entries as usize * 2..lut.num_input_table_entries as usize * 3]
86 .to_vec();
87
88 let clut_table = lut.clut_table.to_clut_f32();
89 if clut_table.len() != clut_length {
90 return Err(CmsError::MalformedClut(MalformedSize {
91 size: clut_table.len(),
92 expected: clut_length,
93 }));
94 }
95
96 let gamma_curves = lut.output_table.to_clut_f32();
97
98 if gamma_curves.len() < lut.num_output_table_entries as usize * 3 {
99 return Err(CmsError::MalformedCurveLutTable(MalformedSize {
100 size: gamma_curves.len(),
101 expected: lut.num_output_table_entries as usize * 3,
102 }));
103 }
104
105 let gamma_curve0 = gamma_curves[..lut.num_output_table_entries as usize].to_vec();
106 let gamma_curve1 = gamma_curves
107 [lut.num_output_table_entries as usize..lut.num_output_table_entries as usize * 2]
108 .to_vec();
109 let gamma_curve2 = gamma_curves
110 [lut.num_output_table_entries as usize * 2..lut.num_output_table_entries as usize * 3]
111 .to_vec();
112
113 let transform = Lut3x3 {
114 input: [lin_curve0, lin_curve1, lin_curve2],
115 gamma: [gamma_curve0, gamma_curve1, gamma_curve2],
116 interpolation_method: options.interpolation_method,
117 clut: clut_table,
118 grid_size: lut.num_clut_grid_points,
119 pcs,
120 };
121
122 Ok(transform)
123}
124
125fn stage_lut_3x3(
126 lut: &LutDataType,
127 options: TransformOptions,
128 pcs: DataColorSpace,
129) -> Result<Box<dyn Stage>, CmsError> {
130 let lut = make_lut_3x3(lut, options, pcs)?;
131
132 let transform = Lut3x3 {
133 input: lut.input,
134 gamma: lut.gamma,
135 interpolation_method: lut.interpolation_method,
136 clut: lut.clut,
137 grid_size: lut.grid_size,
138 pcs: lut.pcs,
139 };
140
141 Ok(Box::new(transform))
142}
143
144pub(crate) fn katana_input_stage_lut_3x3<
145 T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
146>(
147 lut: &LutDataType,
148 options: TransformOptions,
149 pcs: DataColorSpace,
150 bit_depth: usize,
151) -> Result<Box<dyn KatanaInitialStage<f32, T> + Send + Sync>, CmsError>
152where
153 f32: AsPrimitive<T>,
154{
155 let lut = make_lut_3x3(lut, options, pcs)?;
156
157 let transform = KatanaLut3x3::<T> {
158 input: lut.input,
159 gamma: lut.gamma,
160 interpolation_method: lut.interpolation_method,
161 clut: lut.clut,
162 grid_size: lut.grid_size,
163 pcs: lut.pcs,
164 _phantom: std::marker::PhantomData,
165 bit_depth,
166 };
167
168 Ok(Box::new(transform))
169}
170
171pub(crate) fn katana_output_stage_lut_3x3<
172 T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
173>(
174 lut: &LutDataType,
175 options: TransformOptions,
176 pcs: DataColorSpace,
177 bit_depth: usize,
178) -> Result<Box<dyn KatanaFinalStage<f32, T> + Send + Sync>, CmsError>
179where
180 f32: AsPrimitive<T>,
181{
182 let lut = make_lut_3x3(lut, options, pcs)?;
183
184 let transform = KatanaLut3x3::<T> {
185 input: lut.input,
186 gamma: lut.gamma,
187 interpolation_method: lut.interpolation_method,
188 clut: lut.clut,
189 grid_size: lut.grid_size,
190 pcs: lut.pcs,
191 _phantom: std::marker::PhantomData,
192 bit_depth,
193 };
194
195 Ok(Box::new(transform))
196}
197
198impl Lut3x3 {
199 fn transform_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
200 &self,
201 src: &[f32],
202 dst: &mut [f32],
203 fetch: Fetch,
204 ) -> Result<(), CmsError> {
205 let linearization_0 = &self.input[0];
206 let linearization_1 = &self.input[1];
207 let linearization_2 = &self.input[2];
208 for (dest, src) in dst.chunks_exact_mut(3).zip(src.chunks_exact(3)) {
209 debug_assert!(self.grid_size as i32 >= 1);
210 let linear_x = lut_interp_linear_float(src[0], linearization_0);
211 let linear_y = lut_interp_linear_float(src[1], linearization_1);
212 let linear_z = lut_interp_linear_float(src[2], linearization_2);
213
214 let clut = fetch(linear_x, linear_y, linear_z);
215
216 let pcs_x = lut_interp_linear_float(clut.v[0], &self.gamma[0]);
217 let pcs_y = lut_interp_linear_float(clut.v[1], &self.gamma[1]);
218 let pcs_z = lut_interp_linear_float(clut.v[2], &self.gamma[2]);
219 dest[0] = pcs_x;
220 dest[1] = pcs_y;
221 dest[2] = pcs_z;
222 }
223 Ok(())
224 }
225}
226
227impl Stage for Lut3x3 {
228 fn transform(&self, src: &[f32], dst: &mut [f32]) -> Result<(), CmsError> {
229 let l_tbl = Cube::new(&self.clut, self.grid_size as usize);
230
231 if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
233 return self.transform_impl(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z));
234 }
235
236 match self.interpolation_method {
237 #[cfg(feature = "options")]
238 InterpolationMethod::Tetrahedral => {
239 self.transform_impl(src, dst, |x, y, z| l_tbl.tetra_vec3(x, y, z))?;
240 }
241 #[cfg(feature = "options")]
242 InterpolationMethod::Pyramid => {
243 self.transform_impl(src, dst, |x, y, z| l_tbl.pyramid_vec3(x, y, z))?;
244 }
245 #[cfg(feature = "options")]
246 InterpolationMethod::Prism => {
247 self.transform_impl(src, dst, |x, y, z| l_tbl.prism_vec3(x, y, z))?;
248 }
249 InterpolationMethod::Linear => {
250 self.transform_impl(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z))?;
251 }
252 }
253 Ok(())
254 }
255}
256
257impl<T: Copy + Default + PointeeSizeExpressible + AsPrimitive<f32>> KatanaLut3x3<T>
258where
259 f32: AsPrimitive<T>,
260{
261 fn to_pcs_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
262 &self,
263 input: &[T],
264 fetch: Fetch,
265 ) -> Result<Vec<f32>, CmsError> {
266 if input.len() % 3 != 0 {
267 return Err(CmsError::LaneMultipleOfChannels);
268 }
269 let normalizing_value = if T::FINITE {
270 1.0 / ((1u32 << self.bit_depth) - 1) as f32
271 } else {
272 1.0
273 };
274 let mut dst = vec![0.; input.len()];
275 let linearization_0 = &self.input[0];
276 let linearization_1 = &self.input[1];
277 let linearization_2 = &self.input[2];
278 for (dest, src) in dst.chunks_exact_mut(3).zip(input.chunks_exact(3)) {
279 let linear_x =
280 lut_interp_linear_float(src[0].as_() * normalizing_value, linearization_0);
281 let linear_y =
282 lut_interp_linear_float(src[1].as_() * normalizing_value, linearization_1);
283 let linear_z =
284 lut_interp_linear_float(src[2].as_() * normalizing_value, linearization_2);
285
286 let clut = fetch(linear_x, linear_y, linear_z);
287
288 let pcs_x = lut_interp_linear_float(clut.v[0], &self.gamma[0]);
289 let pcs_y = lut_interp_linear_float(clut.v[1], &self.gamma[1]);
290 let pcs_z = lut_interp_linear_float(clut.v[2], &self.gamma[2]);
291 dest[0] = pcs_x;
292 dest[1] = pcs_y;
293 dest[2] = pcs_z;
294 }
295 Ok(dst)
296 }
297
298 fn to_output<Fetch: Fn(f32, f32, f32) -> Vector3f>(
299 &self,
300 src: &[f32],
301 dst: &mut [T],
302 fetch: Fetch,
303 ) -> Result<(), CmsError> {
304 if src.len() % 3 != 0 {
305 return Err(CmsError::LaneMultipleOfChannels);
306 }
307 if dst.len() % 3 != 0 {
308 return Err(CmsError::LaneMultipleOfChannels);
309 }
310 if dst.len() != src.len() {
311 return Err(CmsError::LaneSizeMismatch);
312 }
313 let norm_value = if T::FINITE {
314 ((1u32 << self.bit_depth) - 1) as f32
315 } else {
316 1.0
317 };
318
319 let linearization_0 = &self.input[0];
320 let linearization_1 = &self.input[1];
321 let linearization_2 = &self.input[2];
322 for (dest, src) in dst.chunks_exact_mut(3).zip(src.chunks_exact(3)) {
323 let linear_x = lut_interp_linear_float(src[0], linearization_0);
324 let linear_y = lut_interp_linear_float(src[1], linearization_1);
325 let linear_z = lut_interp_linear_float(src[2], linearization_2);
326
327 let clut = fetch(linear_x, linear_y, linear_z);
328
329 let pcs_x = lut_interp_linear_float(clut.v[0], &self.gamma[0]);
330 let pcs_y = lut_interp_linear_float(clut.v[1], &self.gamma[1]);
331 let pcs_z = lut_interp_linear_float(clut.v[2], &self.gamma[2]);
332
333 if T::FINITE {
334 dest[0] = (pcs_x * norm_value).round().max(0.0).min(norm_value).as_();
335 dest[1] = (pcs_y * norm_value).round().max(0.0).min(norm_value).as_();
336 dest[2] = (pcs_z * norm_value).round().max(0.0).min(norm_value).as_();
337 } else {
338 dest[0] = pcs_x.as_();
339 dest[1] = pcs_y.as_();
340 dest[2] = pcs_z.as_();
341 }
342 }
343 Ok(())
344 }
345}
346
347impl<T: Copy + Default + PointeeSizeExpressible + AsPrimitive<f32>> KatanaInitialStage<f32, T>
348 for KatanaLut3x3<T>
349where
350 f32: AsPrimitive<T>,
351{
352 fn to_pcs(&self, input: &[T]) -> Result<Vec<f32>, CmsError> {
353 let l_tbl = Cube::new(&self.clut, self.grid_size as usize);
354
355 if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
357 return self.to_pcs_impl(input, |x, y, z| l_tbl.trilinear_vec3(x, y, z));
358 }
359
360 match self.interpolation_method {
361 #[cfg(feature = "options")]
362 InterpolationMethod::Tetrahedral => {
363 self.to_pcs_impl(input, |x, y, z| l_tbl.tetra_vec3(x, y, z))
364 }
365 #[cfg(feature = "options")]
366 InterpolationMethod::Pyramid => {
367 self.to_pcs_impl(input, |x, y, z| l_tbl.pyramid_vec3(x, y, z))
368 }
369 #[cfg(feature = "options")]
370 InterpolationMethod::Prism => {
371 self.to_pcs_impl(input, |x, y, z| l_tbl.prism_vec3(x, y, z))
372 }
373 InterpolationMethod::Linear => {
374 self.to_pcs_impl(input, |x, y, z| l_tbl.trilinear_vec3(x, y, z))
375 }
376 }
377 }
378}
379
380impl<T: Copy + Default + PointeeSizeExpressible + AsPrimitive<f32>> KatanaFinalStage<f32, T>
381 for KatanaLut3x3<T>
382where
383 f32: AsPrimitive<T>,
384{
385 fn to_output(&self, src: &mut [f32], dst: &mut [T]) -> Result<(), CmsError> {
386 let l_tbl = Cube::new(&self.clut, self.grid_size as usize);
387
388 if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
390 return self.to_output(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z));
391 }
392
393 match self.interpolation_method {
394 #[cfg(feature = "options")]
395 InterpolationMethod::Tetrahedral => {
396 self.to_output(src, dst, |x, y, z| l_tbl.tetra_vec3(x, y, z))
397 }
398 #[cfg(feature = "options")]
399 InterpolationMethod::Pyramid => {
400 self.to_output(src, dst, |x, y, z| l_tbl.pyramid_vec3(x, y, z))
401 }
402 #[cfg(feature = "options")]
403 InterpolationMethod::Prism => {
404 self.to_output(src, dst, |x, y, z| l_tbl.prism_vec3(x, y, z))
405 }
406 InterpolationMethod::Linear => {
407 self.to_output(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z))
408 }
409 }
410 }
411}
412
413pub(crate) fn create_lut3x3(
414 lut: &LutDataType,
415 src: &[f32],
416 options: TransformOptions,
417 pcs: DataColorSpace,
418) -> Result<Vec<f32>, CmsError> {
419 if lut.num_input_channels != 3 || lut.num_output_channels != 3 {
420 return Err(CmsError::UnsupportedProfileConnection);
421 }
422
423 let mut dest = vec![0.; src.len()];
424
425 let lut_stage = stage_lut_3x3(lut, options, pcs)?;
426 lut_stage.transform(src, &mut dest)?;
427 Ok(dest)
428}