memchr/vector.rs
/// A trait for describing vector operations used by vectorized searchers.
///
/// The trait is highly constrained to the low level vector operations needed
/// here. In general, it was invented mostly to be generic over x86's __m128i
/// and __m256i types. At time of writing, it also supports wasm and aarch64
/// 128-bit vector types.
///
/// # Safety
///
/// All methods are unsafe since they are intended to be implemented using
/// vendor intrinsics, which are themselves unsafe. Callers must ensure that
/// the appropriate target features are enabled in the calling function, and
/// that the current CPU supports them. All implementations should avoid
/// marking the routines with #[target_feature] and instead mark them as
/// #[inline(always)] to ensure they get appropriately inlined. (inline(always)
/// cannot be used with target_feature.)
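///
/// # Example
///
/// A minimal sketch of the intended calling pattern. The function names below
/// are hypothetical (they are not part of this crate): the generic routine is
/// marked #[inline(always)], while the outer caller carries #[target_feature]
/// and is only invoked after a runtime CPU feature check.
///
/// ```ignore
/// #[inline(always)]
/// unsafe fn find_generic<V: Vector>(haystack: &[u8], needle: u8) -> Option<usize> {
///     // ... a vectorized search written against the `Vector` trait ...
///     unimplemented!()
/// }
///
/// #[target_feature(enable = "avx2")]
/// unsafe fn find_avx2(haystack: &[u8], needle: u8) -> Option<usize> {
///     // Monomorphizes and inlines `find_generic` into a function compiled
///     // with AVX2 enabled.
///     find_generic::<core::arch::x86_64::__m256i>(haystack, needle)
/// }
/// ```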
pub(crate) trait Vector: Copy + core::fmt::Debug {
    /// The number of bytes in the vector. That is, this is the size of the
    /// vector in memory.
    const BYTES: usize;
    /// The bits that must be zero in order for a `*const u8` pointer to be
    /// correctly aligned to read vector values.
    const ALIGN: usize;

    /// The type of the value returned by `Vector::movemask`.
    ///
    /// This supports abstracting over the specific representation used in
    /// order to accommodate different representations in different ISAs.
    type Mask: MoveMask;

    /// Create a vector with 8-bit lanes with the given byte repeated into each
    /// lane.
    unsafe fn splat(byte: u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// must be aligned to the size of the vector.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data` and that `data` is aligned to a `BYTES` boundary.
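    ///
    /// As an illustration (not a requirement imposed by the trait itself),
    /// the alignment precondition can be checked with the `ALIGN` mask, which
    /// the implementations in this file define as `BYTES - 1`:
    ///
    /// ```ignore
    /// // `data` is aligned to a `BYTES` boundary when its low bits are zero,
    /// // where `V` is the implementing vector type.
    /// debug_assert_eq!(data as usize & V::ALIGN, 0);
    /// ```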
    unsafe fn load_aligned(data: *const u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// does not need to be aligned.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data`.
    unsafe fn load_unaligned(data: *const u8) -> Self;

    /// _mm_movemask_epi8 or _mm256_movemask_epi8
    unsafe fn movemask(self) -> Self::Mask;
    /// _mm_cmpeq_epi8 or _mm256_cmpeq_epi8
    unsafe fn cmpeq(self, vector2: Self) -> Self;
    /// _mm_and_si128 or _mm256_and_si256
    unsafe fn and(self, vector2: Self) -> Self;
    /// _mm_or_si128 or _mm256_or_si256
    unsafe fn or(self, vector2: Self) -> Self;
    /// Returns true if and only if `Self::movemask` would return a mask that
    /// contains at least one non-zero bit.
    unsafe fn movemask_will_have_non_zero(self) -> bool {
        self.movemask().has_non_zero()
    }
}

/// A trait that abstracts over a vector-to-scalar operation called
/// "move mask."
///
/// On x86-64, this is `_mm_movemask_epi8` for SSE2 and `_mm256_movemask_epi8`
/// for AVX2. It takes a vector of `u8` lanes and returns a scalar where the
/// `i`th bit is set if and only if the most significant bit in the `i`th lane
/// of the vector is set. The simd128 ISA for wasm32 also supports this
/// exact same operation natively.
///
/// ... But aarch64 doesn't. So we have to fake it with more instructions and
/// a slightly different representation. We could do extra work to unify the
/// representations, but that would require additional costs in the hot path
/// for `memchr` and `packedpair`. So instead, we abstract over the specific
/// representation with this trait and define the operations we actually need.
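///
/// As a concrete illustration of the "sensible" representation (see
/// `SensibleMoveMask` below): if only lanes 1 and 4 of a vector have their
/// most significant bits set, the x86-64/wasm32 move mask is the scalar
/// `0b10010`, and so (sketch, little endian assumed):
///
/// ```ignore
/// let mask = SensibleMoveMask(0b10010);
/// assert!(mask.has_non_zero());
/// assert_eq!(mask.first_offset(), 1);
/// assert_eq!(mask.last_offset(), 4);
/// ```
///
/// The aarch64 representation encodes the same information, just spread over
/// 4 bits per lane.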
pub(crate) trait MoveMask: Copy + core::fmt::Debug {
    /// Return a mask that is all zeros except for the least significant `n`
    /// lanes in a corresponding vector.
    fn all_zeros_except_least_significant(n: usize) -> Self;

    /// Returns true if and only if this mask has a non-zero bit anywhere.
    fn has_non_zero(self) -> bool;

    /// Returns the number of bits set to 1 in this mask.
    fn count_ones(self) -> usize;

    /// Does a bitwise `and` operation between `self` and `other`.
    fn and(self, other: Self) -> Self;

    /// Does a bitwise `or` operation between `self` and `other`.
    fn or(self, other: Self) -> Self;

    /// Returns a mask that is equivalent to `self` but with the least
    /// significant 1-bit set to 0.
    fn clear_least_significant_bit(self) -> Self;

    /// Returns the offset of the first non-zero lane this mask represents.
    fn first_offset(self) -> usize;

    /// Returns the offset of the last non-zero lane this mask represents.
    fn last_offset(self) -> usize;
}

/// This is a "sensible" movemask implementation where each bit represents
/// whether the most significant bit is set in each corresponding lane of a
/// vector. This is used on x86-64 and wasm, but such a mask is more expensive
/// to get on aarch64 so we use something a little different.
///
/// We call this "sensible" because this is what we get using native sse/avx
/// movemask instructions. But neon has no such native equivalent.
#[derive(Clone, Copy, Debug)]
pub(crate) struct SensibleMoveMask(u32);

impl SensibleMoveMask {
    /// Get the mask in a form suitable for computing offsets.
    ///
    /// Basically, this normalizes to little endian. On big endian, this swaps
    /// the bytes.
    #[inline(always)]
    fn get_for_offset(self) -> u32 {
        #[cfg(target_endian = "big")]
        {
            self.0.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            self.0
        }
    }
}

impl MoveMask for SensibleMoveMask {
    #[inline(always)]
    fn all_zeros_except_least_significant(n: usize) -> SensibleMoveMask {
        debug_assert!(n < 32);
        SensibleMoveMask(!((1 << n) - 1))
    }

    #[inline(always)]
    fn has_non_zero(self) -> bool {
        self.0 != 0
    }

    #[inline(always)]
    fn count_ones(self) -> usize {
        self.0.count_ones() as usize
    }

    #[inline(always)]
    fn and(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & other.0)
    }

    #[inline(always)]
    fn or(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 | other.0)
    }

    #[inline(always)]
    fn clear_least_significant_bit(self) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & (self.0 - 1))
    }

    #[inline(always)]
    fn first_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte
        // is at a higher address. That means the least significant bit that
        // is set corresponds to the position of our first matching byte.
        // That position corresponds to the number of zeros after the least
        // significant bit.
        self.get_for_offset().trailing_zeros() as usize
    }

    #[inline(always)]
    fn last_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte is
        // at a higher address. That means the most significant bit that is set
        // corresponds to the position of our last matching byte. The position
        // from the end of the mask is therefore the number of leading zeros
        // in a 32 bit integer, and the position from the start of the mask is
        // therefore 32 - (leading zeros) - 1.
        32 - self.get_for_offset().leading_zeros() as usize - 1
    }
}
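
// The test below is an illustrative sketch (not part of the original file):
// it exercises `SensibleMoveMask` with a hand-written mask to show how the
// offset helpers behave. The module name is arbitrary, and the test is gated
// to little endian targets since it constructs the raw `u32` directly rather
// than going through a real `movemask`.
#[cfg(all(test, target_endian = "little"))]
mod sensible_move_mask_example {
    use super::{MoveMask, SensibleMoveMask};

    #[test]
    fn offsets_from_a_hand_written_mask() {
        // Lanes 1 and 4 have their most significant bits set.
        let mask = SensibleMoveMask(0b10010);
        assert!(mask.has_non_zero());
        assert_eq!(mask.count_ones(), 2);
        assert_eq!(mask.first_offset(), 1);
        assert_eq!(mask.last_offset(), 4);
        // Clearing the least significant set bit "consumes" the first match.
        assert_eq!(mask.clear_least_significant_bit().first_offset(), 4);
    }
}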

#[cfg(target_arch = "x86_64")]
mod x86sse2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m128i {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m128i {
            _mm_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m128i {
            _mm_load_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m128i {
            _mm_loadu_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m128i {
            _mm_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m128i {
            _mm_and_si128(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m128i {
            _mm_or_si128(self, vector2)
        }
    }
}

#[cfg(target_arch = "x86_64")]
mod x86avx2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m256i {
        const BYTES: usize = 32;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m256i {
            _mm256_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m256i {
            _mm256_load_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m256i {
            _mm256_loadu_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm256_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m256i {
            _mm256_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m256i {
            _mm256_and_si256(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m256i {
            _mm256_or_si256(self, vector2)
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64neon {
    use core::arch::aarch64::*;

    use super::{MoveMask, Vector};

    impl Vector for uint8x16_t {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = NeonMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> uint8x16_t {
            vdupq_n_u8(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> uint8x16_t {
            // I've tried `data.cast::<uint8x16_t>().read()` instead, but
            // couldn't observe any benchmark differences.
            Self::load_unaligned(data)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> uint8x16_t {
            vld1q_u8(data)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> NeonMoveMask {
            // The "shift right and narrow" trick: treat the vector as eight
            // 16-bit lanes, shift each right by 4 and narrow back to 8 bits.
            // That packs one nibble from every input byte into a single
            // 64-bit scalar. Since this is only ever called on comparison
            // results (each byte is 0x00 or 0xFF), keeping bit 3 of each
            // nibble leaves exactly one bit per matching lane, spread 4 bits
            // apart.
            let asu16s = vreinterpretq_u16_u8(self);
            let mask = vshrn_n_u16(asu16s, 4);
            let asu64 = vreinterpret_u64_u8(mask);
            let scalar64 = vget_lane_u64(asu64, 0);
            NeonMoveMask(scalar64 & 0x8888888888888888)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> uint8x16_t {
            vceqq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> uint8x16_t {
            vandq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> uint8x16_t {
            vorrq_u8(self, vector2)
        }

        /// This is the only interesting implementation of this routine.
        /// Basically, instead of doing the "shift right narrow" dance, we use
        /// adjacent folding max to determine whether there are any non-zero
        /// bytes in our mask. If there are, *then* we'll do the "shift right
        /// narrow" dance. In benchmarks, this does lead to slightly better
        /// throughput, but the win doesn't appear huge.
        #[inline(always)]
        unsafe fn movemask_will_have_non_zero(self) -> bool {
            let low = vreinterpretq_u64_u8(vpmaxq_u8(self, self));
            vgetq_lane_u64(low, 0) != 0
        }
    }

    /// Neon doesn't have a `movemask` that works like the one in x86-64, so we
    /// wind up using a different method[1]. The different method also produces
    /// a mask, but 4 bits are set in the neon case instead of a single bit set
    /// in the x86-64 case. We do an extra step to zero out 3 of the 4 bits,
    /// but we still wind up with at least 3 zeroes between each set bit. This
    /// generally means that we need to do some division by 4 before extracting
    /// offsets.
    ///
    /// In fact, the existence of this type is the entire reason that we have
    /// the `MoveMask` trait in the first place. This basically lets us keep
    /// the different representations of masks without being forced to unify
    /// them into a single representation, which could result in extra and
    /// unnecessary work.
    ///
    /// [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
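    ///
    /// As a small illustrative sketch: if only lane 3 of a 128-bit vector
    /// matches, then only bit `4 * 3 + 3 = 15` of the underlying `u64` is
    /// set, and the offset helpers recover the lane index by dividing the
    /// bit position by 4:
    ///
    /// ```ignore
    /// let mask = NeonMoveMask(1 << 15);
    /// assert_eq!(mask.first_offset(), 3);
    /// assert_eq!(mask.last_offset(), 3);
    /// ```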
    #[derive(Clone, Copy, Debug)]
    pub(crate) struct NeonMoveMask(u64);

    impl NeonMoveMask {
        /// Get the mask in a form suitable for computing offsets.
        ///
        /// The mask is always already in host-endianness, so this is a no-op.
        #[inline(always)]
        fn get_for_offset(self) -> u64 {
            self.0
        }
    }

    impl MoveMask for NeonMoveMask {
        #[inline(always)]
        fn all_zeros_except_least_significant(n: usize) -> NeonMoveMask {
            debug_assert!(n < 16);
            NeonMoveMask(!(((1 << n) << 2) - 1))
        }

        #[inline(always)]
        fn has_non_zero(self) -> bool {
            self.0 != 0
        }

        #[inline(always)]
        fn count_ones(self) -> usize {
            self.0.count_ones() as usize
        }

        #[inline(always)]
        fn and(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 & other.0)
        }

        #[inline(always)]
        fn or(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 | other.0)
        }

        #[inline(always)]
        fn clear_least_significant_bit(self) -> NeonMoveMask {
            NeonMoveMask(self.0 & (self.0 - 1))
        }

        #[inline(always)]
        fn first_offset(self) -> usize {
            // We are dealing with little endian here (and if we aren't,
            // we swap the bytes so we are in practice), where the most
            // significant byte is at a higher address. That means the least
            // significant bit that is set corresponds to the position of our
            // first matching byte. That position corresponds to the number of
            // zeros after the least significant bit.
            //
            // Note that unlike `SensibleMoveMask`, this mask has its bits
            // spread out over 64 bits instead of 16 bits (for a 128 bit
            // vector). Namely, whereas x86-64 will turn
            //
            // 0x00 0xFF 0x00 0x00 0xFF
            //
            // into 10010, our neon approach will turn it into
            //
            // 10000000000010000000
            //
            // And this happens because neon doesn't have a native `movemask`
            // instruction, so we kind of fake it[1]. Thus, we divide the
            // number of trailing zeros by 4 to get the "real" offset.
            //
            // [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
            (self.get_for_offset().trailing_zeros() >> 2) as usize
        }

        #[inline(always)]
        fn last_offset(self) -> usize {
            // See comment in `first_offset` above. This is basically the same,
            // but coming from the other direction.
            16 - (self.get_for_offset().leading_zeros() >> 2) as usize - 1
        }
    }
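
    // The test below is an illustrative sketch (not part of the original
    // file): it exercises `NeonMoveMask` with a hand-written mask in the
    // representation described above, where a match in lane `i` sets bit
    // `4 * i + 3`. The module name is arbitrary.
    #[cfg(test)]
    mod neon_move_mask_example {
        use super::{MoveMask, NeonMoveMask};

        #[test]
        fn offsets_from_a_hand_written_mask() {
            // Lanes 1 and 4 match, so bits 7 and 19 are set.
            let mask = NeonMoveMask((1 << 7) | (1 << 19));
            assert!(mask.has_non_zero());
            assert_eq!(mask.count_ones(), 2);
            assert_eq!(mask.first_offset(), 1);
            assert_eq!(mask.last_offset(), 4);
            // Clearing the least significant set bit "consumes" the first
            // match.
            assert_eq!(mask.clear_least_significant_bit().first_offset(), 4);
        }
    }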
}

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
mod wasm_simd128 {
    use core::arch::wasm32::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for v128 {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> v128 {
            u8x16_splat(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> v128 {
            *data.cast()
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> v128 {
            v128_load(data.cast())
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(u8x16_bitmask(self).into())
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> v128 {
            u8x16_eq(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> v128 {
            v128_and(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> v128 {
            v128_or(self, vector2)
        }
    }
}