use super::super::alloc;
use super::super::alloc::{Allocator, SliceWrapper, SliceWrapperMut};
use super::find_stride;
use super::input_pair::{InputPair, InputReference, InputReferenceMut};
use super::interface;
pub use super::ir_interpret::{push_base, Context, IRInterpreter};
use super::util::{floatX, FastLog2u16};
use super::weights::{Weights, BLEND_FIXED_POINT_PRECISION};
use core;

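// Candidate adaptation parameters for the dynamic CDFs. Every nibble CDF is
// replicated NUM_SPEEDS_TO_TRY times so that all (speed, max) pairs can be
// scored in one pass over the literals; DEFAULT_CM_SPEED_INDEX picks the
// provisional context-map speed used when blending with the stride model.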
const DEFAULT_CM_SPEED_INDEX: usize = 8;
const NUM_SPEEDS_TO_TRY: usize = 16;
const SPEEDS_TO_SEARCH: [u16; NUM_SPEEDS_TO_TRY] = [
    0, 1, 1, 1, 2, 4, 8, 16, 16, 32, 64, 128, 128, 512, 1664, 1664,
];
const MAXES_TO_SEARCH: [u16; NUM_SPEEDS_TO_TRY] = [
    32, 32, 128, 16384, 1024, 1024, 8192, 48, 8192, 4096, 16384, 256, 16384, 16384, 16384, 16384,
];
const NIBBLE_PRIOR_SIZE: usize = 16 * NUM_SPEEDS_TO_TRY;
const CONTEXT_MAP_PRIOR_SIZE: usize = 256 * NIBBLE_PRIOR_SIZE * 17;
const STRIDE_PRIOR_SIZE: usize = 256 * 256 * NIBBLE_PRIOR_SIZE * 2;

#[derive(Clone, Copy, Debug)]
pub struct SpeedAndMax(pub u16, pub u16);

pub fn speed_to_tuple(inp: [SpeedAndMax; 2]) -> [(u16, u16); 2] {
    [(inp[0].0, inp[0].1), (inp[1].0, inp[1].1)]
}

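// Each prior table is a flat array of interleaved CDFs: one logical CDF takes
// 16 * NUM_SPEEDS_TO_TRY entries (16 nibble buckets, one slot per candidate
// speed). The helpers below slice out the bundle selected by the stride byte,
// the context-map prior, and (for low nibbles) the already-coded high nibble.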
fn get_stride_cdf_low(
    data: &mut [u16],
    stride_prior: u8,
    cm_prior: usize,
    high_nibble: u8,
) -> &mut [u16] {
    let index: usize =
        1 + 2 * (cm_prior | ((stride_prior as usize & 0xf) << 8) | ((high_nibble as usize) << 12));
    data.split_at_mut((NUM_SPEEDS_TO_TRY * index) << 4)
        .1
        .split_at_mut(16 * NUM_SPEEDS_TO_TRY)
        .0
}

fn get_stride_cdf_high(data: &mut [u16], stride_prior: u8, cm_prior: usize) -> &mut [u16] {
    let index: usize = 2 * (cm_prior | ((stride_prior as usize) << 8));
    data.split_at_mut((NUM_SPEEDS_TO_TRY * index) << 4)
        .1
        .split_at_mut(16 * NUM_SPEEDS_TO_TRY)
        .0
}

fn get_cm_cdf_low(data: &mut [u16], cm_prior: usize, high_nibble: u8) -> &mut [u16] {
    let index: usize = (high_nibble as usize + 1) + 17 * cm_prior;
    data.split_at_mut((NUM_SPEEDS_TO_TRY * index) << 4)
        .1
        .split_at_mut(16 * NUM_SPEEDS_TO_TRY)
        .0
}

fn get_cm_cdf_high(data: &mut [u16], cm_prior: usize) -> &mut [u16] {
    let index: usize = 17 * cm_prior;
    data.split_at_mut((NUM_SPEEDS_TO_TRY * index) << 4)
        .1
        .split_at_mut(16 * NUM_SPEEDS_TO_TRY)
        .0
}
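
/// Initializes every interleaved CDF to a uniform distribution: bucket `i`
/// holds the cumulative value `4 + 4 * i` for every candidate speed.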
fn init_cdfs(cdfs: &mut [u16]) {
    assert_eq!(cdfs.len() % (16 * NUM_SPEEDS_TO_TRY), 0);
    let mut total_index = 0usize;
    let len = cdfs.len();
    loop {
        for cdf_index in 0..16 {
            let vec = cdfs
                .split_at_mut(total_index)
                .1
                .split_at_mut(NUM_SPEEDS_TO_TRY)
                .0;
            for item in vec {
                *item = 4 + 4 * cdf_index as u16;
            }
            total_index += NUM_SPEEDS_TO_TRY;
        }
        if total_index == len {
            break;
        }
    }
}
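
/// Accumulates, for each candidate speed, the negative log2 cost of coding
/// `nibble_u8` with a fixed-point blend of that speed's stride CDF and the
/// provisional context-map CDF in `mixing_cdf`. The stride model gets a fixed
/// 1/4 weight here; the `_weights` argument is currently unused.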
fn compute_combined_cost(
    singleton_cost: &mut [floatX; NUM_SPEEDS_TO_TRY],
    cdfs: &[u16],
    mixing_cdf: [u16; 16],
    nibble_u8: u8,
    _weights: &mut [Weights; NUM_SPEEDS_TO_TRY],
) {
    assert_eq!(cdfs.len(), 16 * NUM_SPEEDS_TO_TRY);
    let nibble = nibble_u8 as usize & 0xf;
    let mut stride_pdf = [0u16; NUM_SPEEDS_TO_TRY];
    stride_pdf.clone_from_slice(
        cdfs.split_at(NUM_SPEEDS_TO_TRY * nibble)
            .1
            .split_at(NUM_SPEEDS_TO_TRY)
            .0,
    );
    let mut cm_pdf: u16 = mixing_cdf[nibble];
    if nibble_u8 != 0 {
        let mut tmp = [0u16; NUM_SPEEDS_TO_TRY];
        tmp.clone_from_slice(
            cdfs.split_at(NUM_SPEEDS_TO_TRY * (nibble - 1))
                .1
                .split_at(NUM_SPEEDS_TO_TRY)
                .0,
        );
        for i in 0..NUM_SPEEDS_TO_TRY {
            stride_pdf[i] -= tmp[i];
        }
        cm_pdf -= mixing_cdf[nibble - 1];
    }
    let mut stride_max = [0u16; NUM_SPEEDS_TO_TRY];
    stride_max.clone_from_slice(cdfs.split_at(NUM_SPEEDS_TO_TRY * 15).1);
    let cm_max = mixing_cdf[15];
    for i in 0..NUM_SPEEDS_TO_TRY {
        assert_ne!(stride_pdf[i], 0);
        assert_ne!(stride_max[i], 0);

        let w = 1 << (BLEND_FIXED_POINT_PRECISION - 2);
        let combined_pdf = w * u32::from(stride_pdf[i])
            + ((1 << BLEND_FIXED_POINT_PRECISION) - w) * u32::from(cm_pdf);
        let combined_max = w * u32::from(stride_max[i])
            + ((1 << BLEND_FIXED_POINT_PRECISION) - w) * u32::from(cm_max);
        let del = FastLog2u16((combined_pdf >> BLEND_FIXED_POINT_PRECISION) as u16)
            - FastLog2u16((combined_max >> BLEND_FIXED_POINT_PRECISION) as u16);
        singleton_cost[i] -= del;
    }
}
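
/// Accumulates, for each candidate speed, the negative log2 cost of coding
/// `nibble_u8` with that speed's copy of the CDF bundle alone (no blending).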
fn compute_cost(singleton_cost: &mut [floatX; NUM_SPEEDS_TO_TRY], cdfs: &[u16], nibble_u8: u8) {
    assert_eq!(cdfs.len(), 16 * NUM_SPEEDS_TO_TRY);
    let nibble = nibble_u8 as usize & 0xf;
    let mut pdf = [0u16; NUM_SPEEDS_TO_TRY];
    pdf.clone_from_slice(
        cdfs.split_at(NUM_SPEEDS_TO_TRY * nibble)
            .1
            .split_at(NUM_SPEEDS_TO_TRY)
            .0,
    );
    if nibble_u8 != 0 {
        let mut tmp = [0u16; NUM_SPEEDS_TO_TRY];
        tmp.clone_from_slice(
            cdfs.split_at(NUM_SPEEDS_TO_TRY * (nibble - 1))
                .1
                .split_at(NUM_SPEEDS_TO_TRY)
                .0,
        );
        for i in 0..NUM_SPEEDS_TO_TRY {
            pdf[i] -= tmp[i];
        }
    }
    let mut max = [0u16; NUM_SPEEDS_TO_TRY];
    max.clone_from_slice(cdfs.split_at(NUM_SPEEDS_TO_TRY * 15).1);
    for i in 0..NUM_SPEEDS_TO_TRY {
        assert_ne!(pdf[i], 0);
        assert_ne!(max[i], 0);
        let del = FastLog2u16(pdf[i]) - FastLog2u16(max[i]);
        singleton_cost[i] -= del;
    }
}
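
/// Adapts every candidate copy of the CDF after coding `nibble_u8`: buckets at
/// and above the nibble grow by that candidate's speed, and any copy whose
/// total reaches its candidate max is rescaled to roughly 3/4, with a small
/// per-bucket bias that keeps every symbol's probability nonzero.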
fn update_cdf(cdfs: &mut [u16], nibble_u8: u8) {
    assert_eq!(cdfs.len(), 16 * NUM_SPEEDS_TO_TRY);
    let mut overall_index = nibble_u8 as usize * NUM_SPEEDS_TO_TRY;
    for _nibble in (nibble_u8 as usize & 0xf)..16 {
        for speed_index in 0..NUM_SPEEDS_TO_TRY {
            cdfs[overall_index + speed_index] += SPEEDS_TO_SEARCH[speed_index];
        }
        overall_index += NUM_SPEEDS_TO_TRY;
    }
    overall_index = 0;
    for nibble in 0..16 {
        for speed_index in 0..NUM_SPEEDS_TO_TRY {
            if nibble == 0 {
                assert_ne!(cdfs[overall_index + speed_index], 0);
            } else {
                assert_ne!(
                    cdfs[overall_index + speed_index]
                        - cdfs[overall_index + speed_index - NUM_SPEEDS_TO_TRY],
                    0
                );
            }
        }
        overall_index += NUM_SPEEDS_TO_TRY;
    }
    for max_index in 0..NUM_SPEEDS_TO_TRY {
        if cdfs[15 * NUM_SPEEDS_TO_TRY + max_index] >= MAXES_TO_SEARCH[max_index] {
            const CDF_BIAS: [u16; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
            for nibble_index in 0..16 {
                let tmp = &mut cdfs[nibble_index * NUM_SPEEDS_TO_TRY + max_index];
                *tmp = (tmp.wrapping_add(CDF_BIAS[nibble_index]))
                    .wrapping_sub(tmp.wrapping_add(CDF_BIAS[nibble_index]) >> 2);
            }
        }
    }
    overall_index = 0;
    for nibble in 0..16 {
        for speed_index in 0..NUM_SPEEDS_TO_TRY {
            if nibble == 0 {
                assert_ne!(cdfs[overall_index + speed_index], 0);
            } else {
                assert_ne!(
                    cdfs[overall_index + speed_index]
                        - cdfs[overall_index + speed_index - NUM_SPEEDS_TO_TRY],
                    0
                );
            }
        }
        overall_index += NUM_SPEEDS_TO_TRY;
    }
}

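/// Copies the 16-entry CDF belonging to a single candidate speed out of an
/// interleaved bundle of NUM_SPEEDS_TO_TRY CDFs.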
fn extract_single_cdf(cdf_bundle: &[u16], index: usize) -> [u16; 16] {
    assert_eq!(cdf_bundle.len(), 16 * NUM_SPEEDS_TO_TRY);
    assert!(index < NUM_SPEEDS_TO_TRY);

    #[allow(clippy::identity_op)]
    [
        cdf_bundle[index + 0 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 1 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 2 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 3 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 4 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 5 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 6 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 7 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 8 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 9 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 10 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 11 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 12 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 13 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 14 * NUM_SPEEDS_TO_TRY],
        cdf_bundle[index + 15 * NUM_SPEEDS_TO_TRY],
    ]
}

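/// Index of the candidate speed with the lowest accumulated cost so far.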
fn min_cost_index_for_speed(cost: &[floatX]) -> usize {
    assert_eq!(cost.len(), NUM_SPEEDS_TO_TRY);
    let mut min_cost = cost[0];
    let mut best_choice = 0;
    for i in 1..NUM_SPEEDS_TO_TRY {
        if cost[i] < min_cost {
            best_choice = i;
            min_cost = cost[i];
        }
    }
    best_choice
}

fn min_cost_speed_max(cost: &[floatX]) -> SpeedAndMax {
    let best_choice = min_cost_index_for_speed(cost);
    SpeedAndMax(SPEEDS_TO_SEARCH[best_choice], MAXES_TO_SEARCH[best_choice])
}

fn min_cost_value(cost: &[floatX]) -> floatX {
    let best_choice = min_cost_index_for_speed(cost);
    cost[best_choice]
}

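// Indices into `singleton_costs`: context-map-only, stride-only, and the
// stride/context-map blend.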
const SINGLETON_COMBINED_STRATEGY: usize = 2;
const SINGLETON_STRIDE_STRATEGY: usize = 1;
const SINGLETON_CM_STRATEGY: usize = 0;

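/// Scans the command stream via `IRInterpreter` and, for each literal nibble,
/// scores every candidate (speed, max) pair under the context-map model, the
/// stride model, and their blend, so the caller can pick the cheapest
/// adaptation speeds once the scan is complete.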
pub struct ContextMapEntropy<
    'a,
    Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>,
> {
    input: InputPair<'a>,
    context_map: interface::PredictionModeContextMap<InputReferenceMut<'a>>,
    block_type: u8,
    cur_stride: u8,
    local_byte_offset: usize,
    weight: [[Weights; NUM_SPEEDS_TO_TRY]; 2],

    cm_priors: <Alloc as Allocator<u16>>::AllocatedMemory,
    stride_priors: <Alloc as Allocator<u16>>::AllocatedMemory,
    _stride_pyramid_leaves: [u8; find_stride::NUM_LEAF_NODES],
    singleton_costs: [[[floatX; NUM_SPEEDS_TO_TRY]; 2]; 3],
}

impl<'a, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>>
    ContextMapEntropy<'a, Alloc>
{
    pub fn new(
        m16: &mut Alloc,
        input: InputPair<'a>,
        stride: [u8; find_stride::NUM_LEAF_NODES],
        prediction_mode: interface::PredictionModeContextMap<InputReferenceMut<'a>>,
        cdf_detection_quality: u8,
    ) -> Self {
        let cdf_detect = cdf_detection_quality != 0;
        let mut ret = ContextMapEntropy::<Alloc> {
            input,
            context_map: prediction_mode,
            block_type: 0,
            cur_stride: 1,
            local_byte_offset: 0,
            cm_priors: if cdf_detect {
                <Alloc as Allocator<u16>>::alloc_cell(m16, CONTEXT_MAP_PRIOR_SIZE)
            } else {
                <Alloc as Allocator<u16>>::AllocatedMemory::default()
            },
            stride_priors: if cdf_detect {
                <Alloc as Allocator<u16>>::alloc_cell(m16, STRIDE_PRIOR_SIZE)
            } else {
                <Alloc as Allocator<u16>>::AllocatedMemory::default()
            },
            _stride_pyramid_leaves: stride,
            weight: [
                [Weights::new(); NUM_SPEEDS_TO_TRY],
                [Weights::new(); NUM_SPEEDS_TO_TRY],
            ],
            singleton_costs: [[[0.0 as floatX; NUM_SPEEDS_TO_TRY]; 2]; 3],
        };
        if cdf_detect {
            init_cdfs(ret.cm_priors.slice_mut());
            init_cdfs(ret.stride_priors.slice_mut());
        }
        ret
    }
    pub fn take_prediction_mode(
        &mut self,
    ) -> interface::PredictionModeContextMap<InputReferenceMut<'a>> {
        core::mem::replace(
            &mut self.context_map,
            interface::PredictionModeContextMap::<InputReferenceMut<'a>> {
                literal_context_map: InputReferenceMut::default(),
                predmode_speed_and_distance_context_map: InputReferenceMut::default(),
            },
        )
    }
    pub fn prediction_mode_mut(
        &mut self,
    ) -> &mut interface::PredictionModeContextMap<InputReferenceMut<'a>> {
        &mut self.context_map
    }
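    /// Best (speed, max) pair and its cost for the low (index 0) and high
    /// (index 1) nibble models; `combined` selects the blended costs and takes
    /// precedence over `cm`.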
    pub fn best_singleton_speeds(
        &self,
        cm: bool,
        combined: bool,
    ) -> ([SpeedAndMax; 2], [floatX; 2]) {
        let cost_type_index = if combined {
            2usize
        } else if cm {
            0usize
        } else {
            1
        };
        let mut ret_cost = [
            self.singleton_costs[cost_type_index][0][0],
            self.singleton_costs[cost_type_index][1][0],
        ];
        let mut best_indexes = [0, 0];
        for speed_index in 1..NUM_SPEEDS_TO_TRY {
            for highness in 0..2 {
                let cur_cost = self.singleton_costs[cost_type_index][highness][speed_index];
                if cur_cost < ret_cost[highness] {
                    best_indexes[highness] = speed_index;
                    ret_cost[highness] = cur_cost;
                }
            }
        }
        let ret_speed = [
            SpeedAndMax(
                SPEEDS_TO_SEARCH[best_indexes[0]],
                MAXES_TO_SEARCH[best_indexes[0]],
            ),
            SpeedAndMax(
                SPEEDS_TO_SEARCH[best_indexes[1]],
                MAXES_TO_SEARCH[best_indexes[1]],
            ),
        ];
        (ret_speed, ret_cost)
    }
    pub fn best_speeds(
        &mut self,
        cm: bool,
        combined: bool,
    ) -> [SpeedAndMax; 2] {
        let mut ret = [SpeedAndMax(SPEEDS_TO_SEARCH[0], MAXES_TO_SEARCH[0]); 2];
        let cost_type_index = if combined {
            2usize
        } else if cm {
            0usize
        } else {
            1
        };
        for high in 0..2 {
            ret[high] = min_cost_speed_max(&self.singleton_costs[cost_type_index][high][..]);
        }
        ret
    }
    pub fn best_speeds_costs(
        &mut self,
        cm: bool,
        combined: bool,
    ) -> [floatX; 2] {
        let cost_type_index = if combined {
            2usize
        } else if cm {
            0usize
        } else {
            1
        };
        let mut ret = [0.0 as floatX; 2];
        for high in 0..2 {
            ret[high] = min_cost_value(&self.singleton_costs[cost_type_index][high][..]);
        }
        ret
    }
    pub fn free(&mut self, alloc: &mut Alloc) {
        <Alloc as Allocator<u16>>::free_cell(alloc, core::mem::take(&mut self.cm_priors));
        <Alloc as Allocator<u16>>::free_cell(alloc, core::mem::take(&mut self.stride_priors));
    }
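    /// Per-literal scoring step: split the literal into nibbles, charge every
    /// candidate speed for both nibbles under the context-map, stride, and
    /// blended models (the blend uses the context-map CDF extracted at
    /// DEFAULT_CM_SPEED_INDEX), then adapt the stride and context-map priors.
    /// `_selected_bits` is part of the interpreter callback but unused here.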
    fn update_cost_base(
        &mut self,
        stride_prior: u8,
        _selected_bits: u8,
        cm_prior: usize,
        literal: u8,
    ) {
        let upper_nibble = literal >> 4;
        let lower_nibble = literal & 0xf;
        let provisional_cm_high_cdf: [u16; 16];
        let provisional_cm_low_cdf: [u16; 16];
        {
            let cm_cdf_high = get_cm_cdf_high(self.cm_priors.slice_mut(), cm_prior);
            compute_cost(
                &mut self.singleton_costs[SINGLETON_CM_STRATEGY][1],
                cm_cdf_high,
                upper_nibble,
            );
            let best_cm_index = DEFAULT_CM_SPEED_INDEX;
            provisional_cm_high_cdf = extract_single_cdf(cm_cdf_high, best_cm_index);
        }
        {
            let cm_cdf_low = get_cm_cdf_low(self.cm_priors.slice_mut(), cm_prior, upper_nibble);
            compute_cost(
                &mut self.singleton_costs[SINGLETON_CM_STRATEGY][0],
                cm_cdf_low,
                lower_nibble,
            );
            let best_cm_index = DEFAULT_CM_SPEED_INDEX;
            provisional_cm_low_cdf = extract_single_cdf(cm_cdf_low, best_cm_index);
        }
        {
            let stride_cdf_high =
                get_stride_cdf_high(self.stride_priors.slice_mut(), stride_prior, cm_prior);
            compute_combined_cost(
                &mut self.singleton_costs[SINGLETON_COMBINED_STRATEGY][1],
                stride_cdf_high,
                provisional_cm_high_cdf,
                upper_nibble,
                &mut self.weight[1],
            );
            compute_cost(
                &mut self.singleton_costs[SINGLETON_STRIDE_STRATEGY][1],
                stride_cdf_high,
                upper_nibble,
            );
            update_cdf(stride_cdf_high, upper_nibble);
        }
        {
            let stride_cdf_low = get_stride_cdf_low(
                self.stride_priors.slice_mut(),
                stride_prior,
                cm_prior,
                upper_nibble,
            );
            compute_combined_cost(
                &mut self.singleton_costs[SINGLETON_COMBINED_STRATEGY][0],
                stride_cdf_low,
                provisional_cm_low_cdf,
                lower_nibble,
                &mut self.weight[0],
            );
            compute_cost(
                &mut self.singleton_costs[SINGLETON_STRIDE_STRATEGY][0],
                stride_cdf_low,
                lower_nibble,
            );
            update_cdf(stride_cdf_low, lower_nibble);
        }
        {
            let cm_cdf_high = get_cm_cdf_high(self.cm_priors.slice_mut(), cm_prior);
            update_cdf(cm_cdf_high, upper_nibble);
        }
        {
            let cm_cdf_low = get_cm_cdf_low(self.cm_priors.slice_mut(), cm_prior, upper_nibble);
            update_cdf(cm_cdf_low, lower_nibble);
        }
    }
}

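// Command processing is delegated to `push_base`, which replays each command
// through the `IRInterpreter` implementation below.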
impl<'a, 'b, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>>
    interface::CommandProcessor<'b> for ContextMapEntropy<'a, Alloc>
{
    fn push(&mut self, val: interface::Command<InputReference<'b>>) {
        push_base(self, val)
    }
}

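// IRInterpreter wiring: tracks byte offsets and block types, exposes the
// literal context map, and feeds each literal together with the byte one
// stride back into `update_cost_base`.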
impl<'a, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>>
    IRInterpreter for ContextMapEntropy<'a, Alloc>
{
    fn inc_local_byte_offset(&mut self, inc: usize) {
        self.local_byte_offset += inc;
    }
    fn local_byte_offset(&self) -> usize {
        self.local_byte_offset
    }
    fn update_block_type(&mut self, new_type: u8, stride: u8) {
        self.block_type = new_type;
        self.cur_stride = stride;
    }
    fn block_type(&self) -> u8 {
        self.block_type
    }
    fn literal_data_at_offset(&self, index: usize) -> u8 {
        self.input[index]
    }
    fn literal_context_map(&self) -> &[u8] {
        self.context_map.literal_context_map.slice()
    }
    fn prediction_mode(&self) -> ::interface::LiteralPredictionModeNibble {
        self.context_map.literal_prediction_mode()
    }
    fn update_cost(
        &mut self,
        stride_prior: [u8; 8],
        stride_prior_offset: usize,
        selected_bits: u8,
        cm_prior: usize,
        literal: u8,
    ) {
        let stride = self.cur_stride as usize;
        self.update_cost_base(
            stride_prior[stride_prior_offset.wrapping_sub(stride) & 7],
            selected_bits,
            cm_prior,
            literal,
        )
    }
}