Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize string parsing #1161

Merged
merged 9 commits into from
Aug 11, 2024
Merged
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@
clippy::wildcard_imports,
// things are often more readable this way
clippy::cast_lossless,
clippy::items_after_statements,
clippy::module_name_repetitions,
clippy::redundant_else,
clippy::shadow_unrelated,
Expand Down
106 changes: 68 additions & 38 deletions src/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::error::{Error, ErrorCode, Result};
use alloc::vec::Vec;
use core::char;
use core::cmp;
use core::mem;
use core::ops::Deref;
use core::str;

Expand Down Expand Up @@ -221,7 +222,7 @@ where
{
loop {
let ch = tri!(next_or_eof(self));
if !ESCAPE[ch as usize] {
if !is_escape(ch, true) {
scratch.push(ch);
continue;
}
Expand Down Expand Up @@ -342,7 +343,7 @@ where
fn ignore_str(&mut self) -> Result<()> {
loop {
let ch = tri!(next_or_eof(self));
if !ESCAPE[ch as usize] {
if !is_escape(ch, true) {
continue;
}
match ch {
Expand Down Expand Up @@ -425,6 +426,65 @@ impl<'a> SliceRead<'a> {
}
}

/// Moves `self.index` forward to the next byte of `self.slice` that cannot be
/// part of a plain run of string content: `"` or `\`, and, when
/// `forbid_control_characters` is true, any control character (`0x00..=0x1F`).
///
/// If `self.index` is already at such a byte, or at the end of the slice, it
/// is left unchanged. If no such byte remains, `self.index` ends up equal to
/// `self.slice.len()`.
fn skip_to_escape(&mut self, forbid_control_characters: bool) {
    // Immediately bail out on empty strings and consecutive escapes (e.g. \u041b\u0435)
    if self.index == self.slice.len()
        || is_escape(self.slice[self.index], forbid_control_characters)
    {
        return;
    }
    self.index += 1;

    let rest = &self.slice[self.index..];

    if !forbid_control_characters {
        self.index += memchr::memchr2(b'"', b'\\', rest).unwrap_or(rest.len());
        return;
    }

    // We wish to find the first byte in range 0x00..=0x1F or " or \. Ideally, we'd use
    // something akin to memchr3, but the memchr crate does not support this at the moment.
    // Therefore, we use a variation on Mycroft's algorithm [1] to provide performance better
    // than a naive loop. It runs faster than equivalent two-pass memchr2+SWAR code on
    // benchmarks and it's cross-platform, so probably the right fit.
    // [1]: https://groups.google.com/forum/#!original/comp.lang.c/2HtQXvg7iKc/xOJeipH6KLMJ
    type Chunk = usize;
    const STEP: usize = mem::size_of::<Chunk>();
    const ONE_BYTES: Chunk = Chunk::MAX / 255; // 0x0101...01

    for chunk in rest.chunks_exact(STEP) {
        // The chunk is always loaded as little-endian, so that the byte at the lowest address
        // is the least significant one. The subtractions below propagate borrows towards more
        // significant bytes, which means a byte *above* a real match may be flagged
        // spuriously, but the least significant flagged byte is always a real match. With a
        // native-endian load on a big-endian target, a spurious flag could sit at a lower
        // address than the real match (e.g. `#"`, where '#' ^ '"' == 0x01) and be picked
        // instead.
        let chars = Chunk::from_le_bytes(chunk.try_into().unwrap());
        // High bit of each byte is set where the byte is < 0x20.
        let contains_ctrl = chars.wrapping_sub(ONE_BYTES * 0x20) & !chars;
        // XOR turns the searched-for byte into 0x00; then apply the zero-byte test.
        let chars_quote = chars ^ (ONE_BYTES * Chunk::from(b'"'));
        let contains_quote = chars_quote.wrapping_sub(ONE_BYTES) & !chars_quote;
        let chars_backslash = chars ^ (ONE_BYTES * Chunk::from(b'\\'));
        let contains_backslash = chars_backslash.wrapping_sub(ONE_BYTES) & !chars_backslash;
        // Review note: a fun followup would be to throw this arithmetic expression into a
        // superoptimizer and see whether a shorter equivalent exists. Per
        // https://rust.godbolt.org/z/66fGEqY6c, LLVM already rewrites it as
        //     (tmp_ctrl | tmp_quote | tmp_backslash) & !chars & (ONE_BYTES << 7)
        // where each `tmp_*` is the corresponding `wrapping_sub` result, because
        // `A & !(chars ^ B) & C` is equivalent to `A & !chars & C` whenever `B & C == 0`,
        // i.e. the `^` only affects bits that are later erased by the second `&`.
        let masked = (contains_ctrl | contains_quote | contains_backslash) & (ONE_BYTES << 7);
        if masked != 0 {
            // Lowest set flag bit -> index of the first matching byte within the chunk.
            let offset_in_chunk = masked.trailing_zeros() as usize / 8;
            // SAFETY: `chunk` is a subslice of `rest`, which is itself a subslice of
            // `self.slice`, so both pointers belong to the same allocation and the chunk
            // does not start before the slice.
            self.index = unsafe { chunk.as_ptr().offset_from(self.slice.as_ptr()) } as usize
                + offset_in_chunk;
            return;
        }
    }

    // Every full chunk was clean; scan the remaining (< STEP) bytes one at a time.
    self.index += rest.len() / STEP * STEP;
    self.skip_to_escape_slow();
}

/// Byte-at-a-time scan: advances `self.index` until it reaches the end of
/// `self.slice` or a byte for which `is_escape(byte, true)` holds.
#[cold]
#[inline(never)]
fn skip_to_escape_slow(&mut self) {
    while let Some(&byte) = self.slice.get(self.index) {
        if is_escape(byte, true) {
            break;
        }
        self.index += 1;
    }
}

/// The big optimization here over IoRead is that if the string contains no
/// backslash escape sequences, the returned &str is a slice of the raw JSON
/// data so we avoid copying into the scratch space.
Expand All @@ -442,9 +502,7 @@ impl<'a> SliceRead<'a> {
let mut start = self.index;

loop {
while self.index < self.slice.len() && !ESCAPE[self.slice[self.index] as usize] {
self.index += 1;
}
self.skip_to_escape(validate);
if self.index == self.slice.len() {
return error(self, ErrorCode::EofWhileParsingString);
}
Expand All @@ -470,9 +528,7 @@ impl<'a> SliceRead<'a> {
}
_ => {
self.index += 1;
if validate {
return error(self, ErrorCode::ControlCharacterWhileParsingString);
}
return error(self, ErrorCode::ControlCharacterWhileParsingString);
}
}
}
Expand Down Expand Up @@ -538,9 +594,7 @@ impl<'a> Read<'a> for SliceRead<'a> {

fn ignore_str(&mut self) -> Result<()> {
loop {
while self.index < self.slice.len() && !ESCAPE[self.slice[self.index] as usize] {
self.index += 1;
}
self.skip_to_escape(true);
if self.index == self.slice.len() {
return error(self, ErrorCode::EofWhileParsingString);
}
Expand Down Expand Up @@ -779,33 +833,9 @@ pub trait Fused: private::Sealed {}
impl<'a> Fused for SliceRead<'a> {}
impl<'a> Fused for StrRead<'a> {}

// Lookup table of bytes that must be escaped. `ESCAPE[i]` is true when byte `i`
// requires an escape sequence in the input: control characters \x00..=\x1F,
// the quote \x22 and the backslash \x5C. Every other byte is allowed unescaped.
static ESCAPE: [bool; 256] = {
    let mut table = [false; 256];

    // Control characters \x00..=\x1F.
    let mut byte: usize = 0x00;
    while byte <= 0x1F {
        table[byte] = true;
        byte += 1;
    }

    // Quote and backslash.
    table[b'"' as usize] = true;
    table[b'\\' as usize] = true;

    table
};
/// Returns true if `ch` is a quote or a backslash, or, when
/// `including_control_characters` is set, a control character below 0x20.
fn is_escape(ch: u8, including_control_characters: bool) -> bool {
    match ch {
        b'"' | b'\\' => true,
        0x00..=0x1F => including_control_characters,
        _ => false,
    }
}

fn next_or_eof<'de, R>(read: &mut R) -> Result<u8>
where
Expand Down
19 changes: 19 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2497,3 +2497,22 @@ fn hash_positive_and_negative_zero() {
assert_eq!(rand.hash_one(k1), rand.hash_one(k2));
}
}

/// Checks that an unescaped control character inside a string is reported at
/// the right position, for a range of string lengths before and after it and
/// for several control characters in a row.
#[test]
fn test_control_character_search() {
    const NEWLINE_ERROR: &str =
        "control character (\\u0000-\\u001F) found while parsing a string at line 2 column 0";
    const TAB_ERROR: &str =
        "control character (\\u0000-\\u001F) found while parsing a string at line 1 column 2";

    // A newline preceded by `prefix_len` and followed by `suffix_len` plain characters
    for prefix_len in 0..16 {
        for suffix_len in 0..16 {
            let prefix = ".".repeat(prefix_len);
            let suffix = ".".repeat(suffix_len);
            let input = format!("\"{}\n{}\"", prefix, suffix);
            test_parse_err::<String>(&[(&input, NEWLINE_ERROR)]);
        }
    }

    // Multiple occurrences: the first control character is the one reported
    test_parse_err::<String>(&[("\"\t\n\r\"", TAB_ERROR)]);
}