From 36c38e58bcb48fda3c8d3ca25ce529ed216c08e8 Mon Sep 17 00:00:00 2001 From: Halil Durak Date: Wed, 26 Feb 2025 01:08:17 +0300 Subject: [PATCH] docs: document the AVX2 instructions (#203) --- src/simd/avx2.rs | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/src/simd/avx2.rs b/src/simd/avx2.rs index 078c365..367d312 100644 --- a/src/simd/avx2.rs +++ b/src/simd/avx2.rs @@ -21,6 +21,8 @@ pub unsafe fn match_uri_vectored(bytes: &mut Bytes) { #[allow(non_snake_case, overflowing_literals)] #[allow(unused)] unsafe fn match_url_char_32_avx(buf: &[u8]) -> usize { + // NOTE: This check might not be necessary since this function is only used in + // `match_uri_vectored` where buffer overflow is taken care of. debug_assert!(buf.len() >= 32); #[cfg(target_arch = "x86")] @@ -28,18 +30,53 @@ unsafe fn match_url_char_32_avx(buf: &[u8]) -> usize { #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; + // pointer to buffer let ptr = buf.as_ptr(); // %x21-%x7e %x80-%xff + // + // The character ranges allowed by this function can also be interpreted as: + // 33 <= x <= 255 and x != 127 + // + // Create a vector full of DEL (0x7f) characters. let DEL: __m256i = _mm256_set1_epi8(0x7f); + // Create a vector full of exclamation mark (!) (0x21) characters. + // Used as the lower threshold; characters in URLs cannot be smaller than this. let LOW: __m256i = _mm256_set1_epi8(0x21); + // Load a chunk of 32 bytes from `ptr` as a vector. + // We can check at most 32 bytes in parallel with AVX2 since + // YMM registers are only 256 bits wide. let dat = _mm256_lddqu_si256(ptr as *const _); + // unsigned comparison dat >= LOW + // + // We create a new vector via `_mm256_max_epu8` which compares vectors `dat` and `LOW` + // and picks the larger value at each index. + // So if a byte in `dat` is <= 32, it'll be represented as 33 + // which is the smallest valid character. + // + // Then, we compare the new vector with `dat` for equality. 
+ // + // `_mm256_cmpeq_epi8` returns a new vector where: + // * matching bytes are set to 0xFF (all bits set), + // * nonmatching bytes are set to 0 (no bits set). let low = _mm256_cmpeq_epi8(_mm256_max_epu8(dat, LOW), dat); + // Similar to what we did before, but now invalid characters are set to 0xFF. let del = _mm256_cmpeq_epi8(dat, DEL); + + // We combine both comparisons via `_mm256_andnot_si256`. + // + // Since the representation of truthy/falsy differs between these comparisons + // (`low` marks valid bytes, `del` marks invalid ones), we cannot use a plain AND; + // we need a bitwise NOT on `del` first, which `andnot` does: ~del & low. let bit = _mm256_andnot_si256(del, low); + // This creates a bitmask from the most significant bit of each byte. + // Simply, we're converting a vector value to a scalar value here. let res = _mm256_movemask_epi8(bit) as u32; + + // Count trailing zeros to find the first encountered invalid character. + // Bitwise NOT is required once again to flip truthiness. // TODO: use .trailing_ones() once MSRV >= 1.46 (!res).trailing_zeros() as usize } @@ -72,17 +109,38 @@ unsafe fn match_header_value_char_32_avx(buf: &[u8]) -> usize { let ptr = buf.as_ptr(); // %x09 %x20-%x7e %x80-%xff + // Create a vector full of horizontal tab (\t) (0x09) characters. let TAB: __m256i = _mm256_set1_epi8(0x09); + // Create a vector full of DEL (0x7f) characters. let DEL: __m256i = _mm256_set1_epi8(0x7f); + // Create a vector full of space (0x20) characters. let LOW: __m256i = _mm256_set1_epi8(0x20); + // Load a chunk of 32 bytes from `ptr` as a vector. let dat = _mm256_lddqu_si256(ptr as *const _); + // unsigned comparison dat >= LOW + // + // Same as what we do in `match_url_char_32_avx`. + // This time the lower threshold is set to the space character though. let low = _mm256_cmpeq_epi8(_mm256_max_epu8(dat, LOW), dat); + // Check if `dat` includes `TAB` characters. let tab = _mm256_cmpeq_epi8(dat, TAB); + // Check if `dat` includes `DEL` characters. 
let del = _mm256_cmpeq_epi8(dat, DEL); + + // Combine all comparisons together, notice that we're also using OR + // to connect `low` and `tab` but flip bits of `del`. + // + // In the end, this is simply: + // ~del & (low | tab) let bit = _mm256_andnot_si256(del, _mm256_or_si256(low, tab)); + // This creates a bitmask from the most significant bit of each byte. + // Creates a scalar value from vector value. let res = _mm256_movemask_epi8(bit) as u32; + + // Count trailing zeros to find the first encountered invalid character. + // Bitwise NOT is required once again to flip truthiness. // TODO: use .trailing_ones() once MSRV >= 1.46 (!res).trailing_zeros() as usize }