diff --git a/src/engine/symbolic_state.rs b/src/engine/symbolic_state.rs index c855fe40..69005e23 100644 --- a/src/engine/symbolic_state.rs +++ b/src/engine/symbolic_state.rs @@ -411,6 +411,10 @@ impl<'a> Formula for FormulaView<'a> { !matches!(self.data_flow[NodeIndex::new(sym)], Symbol::Operator(_)) } + fn is_input(&self, sym: SymbolId) -> bool { + matches!(self.data_flow[NodeIndex::new(sym)], Symbol::Input(_)) + } + fn traverse(&self, n: SymbolId, visit_map: &mut HashMap, v: &mut V) -> R where V: FormulaVisitor, diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 53c91759..4d3821a4 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -155,6 +155,8 @@ pub trait Formula: Index { fn is_operand(&self, sym: SymbolId) -> bool; + fn is_input(&self, sym: SymbolId) -> bool; + fn traverse(&self, n: SymbolId, visit_map: &mut HashMap, v: &mut V) -> R where V: FormulaVisitor, diff --git a/src/solver/monster.rs b/src/solver/monster.rs index 220207a5..c8fde1a2 100644 --- a/src/solver/monster.rs +++ b/src/solver/monster.rs @@ -77,7 +77,7 @@ fn is_invertible(op: BVOperator, s: BitVector, t: BitVector, d: OperandSide) -> } }, BVOperator::Remu => match d { - OperandSide::Lhs => !(s <= t), + OperandSide::Lhs => !(-s) >= t, OperandSide::Rhs => { if s == t { true @@ -154,9 +154,9 @@ fn select( (rhs, lhs, OperandSide::Rhs) } else if is_constant(f, rhs) { (lhs, rhs, OperandSide::Lhs) - } else if is_essential(f, lhs, OperandSide::Lhs, rhs, t, ab) { + } else if is_essential(f, idx, lhs, OperandSide::Lhs, t, ab) { (lhs, rhs, OperandSide::Lhs) - } else if is_essential(f, rhs, OperandSide::Rhs, lhs, t, ab) { + } else if is_essential(f, idx, rhs, OperandSide::Rhs, t, ab) { (rhs, lhs, OperandSide::Rhs) } else if random() { (rhs, lhs, OperandSide::Rhs) @@ -176,25 +176,29 @@ fn compute_inverse_value(op: BVOperator, s: BitVector, t: BitVector, d: OperandS OperandSide::Rhs => s - t, }, BVOperator::Mul => { - let y = s >> s.ctz(); + if s == BitVector(0) { + BitVector(random::()) + } else { + let y = s >> s.ctz(); - let y_inv = y - .modinverse() - .expect("a modular inverse has to exist iff operator is invertible"); + let y_inv = y + .modinverse() + .expect("a modular inverse has to exist iff operator is invertible"); - let result = (t >> s.ctz()) * y_inv; + let result = (t >> s.ctz()) * y_inv; - let to_shift = 64 - s.ctz(); + let to_shift = 64 - s.ctz(); - let arbitrary_bit_mask = if to_shift == 64 { - BitVector(0) - } else { - BitVector::ones() << to_shift - }; + let arbitrary_bit_mask = if to_shift == 64 { + BitVector(0) + } else { + BitVector::ones() << to_shift + }; - let arbitrary_bits = BitVector(random::()) & arbitrary_bit_mask; + let arbitrary_bits = BitVector(random::()) & arbitrary_bit_mask; - result | arbitrary_bits + result | arbitrary_bits + } } BVOperator::Sltu => match d { OperandSide::Lhs => { @@ -223,34 +227,23 @@ fn compute_inverse_value(op: BVOperator, s: BitVector, t: BitVector, d: OperandS if (t == BitVector::ones()) && (s == BitVector(1)) { BitVector::ones() } else { - let range_start = t * s; - if range_start.0.overflowing_add(s.0 - 1).1 { - BitVector( - thread_rng() - .sample(Uniform::new_inclusive(range_start.0, u64::max_value())), - ) - } else { - BitVector(thread_rng().sample(Uniform::new_inclusive( - range_start.0, - range_start.0 + (s.0 - 1), - ))) - } + let range_start = (t * s).0; + let range_end = range_start.saturating_add(s.0 - 1); + + BitVector(thread_rng().sample(Uniform::new_inclusive(range_start, range_end))) } } OperandSide::Rhs => { - if (t == s) && t == BitVector::ones() { - BitVector(thread_rng().sample(Uniform::new_inclusive(0, 1))) - } else if (t == BitVector::ones()) && (s != BitVector::ones()) { - BitVector(0) - } else { - s / t - } + let range_start = s / (t + BitVector(1)) + BitVector(1); + let range_end = s / t; + + BitVector(thread_rng().sample(Uniform::new_inclusive(range_start.0, range_end.0))) } }, BVOperator::Remu => match d { OperandSide::Lhs => { let y = BitVector( - thread_rng().sample(Uniform::new_inclusive(1, ((BitVector::ones() - t) / s).0)), + thread_rng().sample(Uniform::new_inclusive(0, ((BitVector::ones() - t) / s).0)), ); // below computation cannot overflow due to how `y` was chosen assert!( @@ -305,7 +298,7 @@ fn compute_consistent_value(op: BVOperator, t: BitVector, d: OperandSide) -> Bit BVOperator::Add | BVOperator::Sub | BVOperator::Equals => BitVector(random::()), BVOperator::Mul => BitVector({ if t == BitVector(0) { - 0 + random::() } else { let mut r; loop { @@ -323,17 +316,19 @@ fn compute_consistent_value(op: BVOperator, t: BitVector, d: OperandSide) -> Bit }), BVOperator::Divu => match d { OperandSide::Lhs => { - if (t == BitVector::ones()) || (t == BitVector(0)) { + if t == BitVector::ones() { + BitVector(random::()) + } else if t == BitVector(0) { BitVector(thread_rng().sample(Uniform::new_inclusive(0, u64::max_value() - 1))) } else { - let mut y = BitVector(0); - while !(y != BitVector(0)) && !(y.mulo(t)) { - y = BitVector( - thread_rng().sample(Uniform::new_inclusive(0, u64::max_value())), - ); - } + let y = BitVector( + thread_rng().sample(Uniform::new_inclusive(1, u64::max_value() / t.0)), + ); + + let range_start = (t * y).0; + let range_end = range_start.saturating_add(y.0 - 1); - y * t + BitVector(thread_rng().sample(Uniform::new_inclusive(range_start, range_end))) } } OperandSide::Rhs => { @@ -368,10 +363,16 @@ fn compute_consistent_value(op: BVOperator, t: BitVector, d: OperandSide) -> Bit }, BVOperator::Remu => match d { OperandSide::Lhs => { - if t == BitVector::ones() { - BitVector::ones() + if t == BitVector::ones() || t > BitVector::ones() - t { + t } else { - BitVector(thread_rng().sample(Uniform::new_inclusive(t.0, BitVector::ones().0))) + let r = thread_rng().sample(Uniform::new_inclusive(2 * t.0, u64::max_value())); + + if r == 2 * t.0 { + t + } else { + BitVector(r) + } } } OperandSide::Rhs => { @@ -440,15 +441,15 @@ fn value( fn is_essential( formula: &F, + n: SymbolId, this: SymbolId, on_side: OperandSide, - other: SymbolId, t: BitVector, ab: &[BitVector], ) -> bool { let ab_nx = ab[this]; - match &formula[other] { + match &formula[n] { Symbol::Operator(op) => !is_invertible(*op, ab_nx, t, on_side.other()), // TODO: not mentioned in paper => improvised. is that really true? Symbol::Constant(_) | Symbol::Input(_) => false, @@ -637,7 +638,9 @@ fn sat( n = nx; } - update_assignment(formula, &mut ab, n, t); + if formula.is_input(n) { + update_assignment(formula, &mut ab, n, t); + } } let assignment: Assignment = formula.symbol_ids().map(|i| (i, ab[i])).collect(); @@ -821,42 +824,64 @@ mod tests { // prove: Ey.(computed <> y == t) where <> is the binary bit vector operator // + // compute inverse value for other operand let inverse = match op { - BVOperator::Add => t - computed, + BVOperator::Add => { + assert!( + is_invertible(op, computed, t, d.other()), + "consistent value has an inverse" + ); + compute_inverse_value(op, computed, t, d.other()) + } BVOperator::Mul => { assert!( - is_invertible(op, computed, t, d), - "choose values which are invertible..." + is_invertible(op, computed, t, d.other()), + "consistent value has an inverse" ); - - compute_inverse_value(op, computed, t, d) + compute_inverse_value(op, computed, t, d.other()) + } + BVOperator::Sltu => { + assert!( + is_invertible(op, computed, t, d.other()), + "consistent value has an inverse" + ); + compute_inverse_value(op, computed, t, d.other()) } - BVOperator::Sltu => compute_inverse_value(op, computed, t, d), BVOperator::Divu => { - assert!(is_invertible(op, computed, t, d)); - compute_inverse_value(op, computed, t, d) + assert!( + is_invertible(op, computed, t, d.other()), + "consistent value has an inverse" + ); + compute_inverse_value(op, computed, t, d.other()) + } + BVOperator::Remu => { + assert!( + is_invertible(op, computed, t, d.other()), + "consistent value has an inverse" + ); + compute_inverse_value(op, computed, t, d.other()) } _ => unimplemented!(), }; if d == OperandSide::Lhs { assert_eq!( - f(inverse, computed), + f(computed, inverse), t, "{:?} {:?} {:?} == {:?}", - inverse, - op, computed, + op, + inverse, t ); } else { assert_eq!( - f(computed, inverse), + f(inverse, computed), t, "{:?} {:?} {:?} == {:?}", - computed, - op, inverse, + op, + computed, t ); } @@ -890,6 +915,15 @@ mod tests { let side = OperandSide::Lhs; test_invertibility(MUL, 0b1, 0b1, side, true, "trivial multiplication"); + test_invertibility( + MUL, + 0b0, + 0b0, + side, + true, + "trivial multiplication, t == s == 0", + ); + test_invertibility(MUL, 0b10, 0b0, side, true, "trivial multiplication, t == 0"); test_invertibility(MUL, 0b10, 0b1, side, false, "operand bigger than result"); test_invertibility( MUL, @@ -989,6 +1023,8 @@ mod tests { test_inverse_value_computation(MUL, 0b10, 0b10, side, f); test_inverse_value_computation(MUL, 0b100, 0b100, side, f); test_inverse_value_computation(MUL, 0b10, 0b1100, side, f); + test_inverse_value_computation(MUL, 0b0, 0b0, side, f); + test_inverse_value_computation(MUL, 0b10, 0b0, side, f); } #[test] @@ -1101,4 +1137,23 @@ mod tests { test_consistent_value_computation(SLTU, 0, side, f); test_consistent_value_computation(SLTU, 1, side, f); } + + #[test] + fn compute_consistent_values_for_remu() { + let mut side = OperandSide::Lhs; + + fn f(l: BitVector, r: BitVector) -> BitVector { + l % r + } + + // test only for values which actually have a consistent value + test_consistent_value_computation(REMU, u64::max_value(), side, f); + test_consistent_value_computation(REMU, u64::max_value() / 2 + 7, side, f); + test_consistent_value_computation(REMU, 0, side, f); + test_consistent_value_computation(REMU, 7, side, f); + + side = OperandSide::Rhs; + test_consistent_value_computation(REMU, u64::max_value(), side, f); + test_consistent_value_computation(REMU, 7, side, f); + } }