Skip to content

Commit

Permalink
fix: shift right overflow in ACIR with unknown var now returns zero (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
asterite authored Feb 28, 2025
1 parent ebaff44 commit ca21820
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 0 deletions.
48 changes: 48 additions & 0 deletions compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ impl Context<'_> {
let lhs_typ = self.function.dfg.type_of_value(lhs).unwrap_numeric();
let base = self.field_constant(FieldElement::from(2_u128));
let pow = self.pow(base, rhs);
let pow = self.pow_or_max_for_bit_size(pow, rhs, bit_size, lhs_typ);
let pow = self.insert_cast(pow, lhs_typ);
if lhs_typ.is_unsigned() {
// unsigned right bit shift is just a normal division
Expand Down Expand Up @@ -205,6 +206,53 @@ impl Context<'_> {
}
}

/// Returns `pow` or the maximum value allowed for `typ` if 2^rhs is guaranteed to exceed that maximum.
fn pow_or_max_for_bit_size(
&mut self,
pow: ValueId,
rhs: ValueId,
bit_size: u32,
typ: NumericType,
) -> ValueId {
let max = if typ.is_unsigned() {
if bit_size == 128 { u128::MAX } else { (1_u128 << bit_size) - 1 }
} else {
1_u128 << (bit_size - 1)
};
let max = self.field_constant(FieldElement::from(max));

// Here we check whether rhs is less than the bit_size: if it's not then it will overflow.
// Then we do:
//
// rhs_is_less_than_bit_size = lt rhs, bit_size
// rhs_is_not_less_than_bit_size = not rhs_is_less_than_bit_size
// pow_when_is_less_than_bit_size = rhs_is_less_than_bit_size * pow
// pow_when_is_not_less_than_bit_size = rhs_is_not_less_than_bit_size * max
// pow = add pow_when_is_less_than_bit_size, pow_when_is_not_less_than_bit_size
//
// All operations here are unchecked because they work on field types.
let rhs_typ = self.function.dfg.type_of_value(rhs).unwrap_numeric();
let bit_size = self.numeric_constant(bit_size as u128, rhs_typ);
let rhs_is_less_than_bit_size = self.insert_binary(rhs, BinaryOp::Lt, bit_size);
let rhs_is_not_less_than_bit_size = self.insert_not(rhs_is_less_than_bit_size);
let rhs_is_less_than_bit_size =
self.insert_cast(rhs_is_less_than_bit_size, NumericType::NativeField);
let rhs_is_not_less_than_bit_size =
self.insert_cast(rhs_is_not_less_than_bit_size, NumericType::NativeField);
let pow_when_is_less_than_bit_size =
self.insert_binary(rhs_is_less_than_bit_size, BinaryOp::Mul { unchecked: true }, pow);
let pow_when_is_not_less_than_bit_size = self.insert_binary(
rhs_is_not_less_than_bit_size,
BinaryOp::Mul { unchecked: true },
max,
);
self.insert_binary(
pow_when_is_less_than_bit_size,
BinaryOp::Add { unchecked: true },
pow_when_is_not_less_than_bit_size,
)
}

/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
/// Pseudo-code of the computation:
/// let mut r = 1;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
name = "shift_right_overflow"
type = "bin"
authors = [""]
[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = 9
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn main(x: u8) {
// This would previously overflow in ACIR. Now it returns zero.
let value = 1 >> x;
assert_eq(value, 0);
}

0 comments on commit ca21820

Please sign in to comment.