Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SMTLIB field-update support to AIR #1319

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion source/air/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub enum Relation {
PiecewiseLinearOrder,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinaryOp {
Implies,
Eq,
Expand Down Expand Up @@ -92,6 +92,7 @@ pub enum BinaryOp {
LShr,
Shl,
BitConcat,
FieldUpdate(Ident),
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
Expand Down
9 changes: 5 additions & 4 deletions source/air/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ pub(crate) struct ClosureTermX {
// The function declarations live in scope outside the expression scope, so
// we need to insert them into the typing's outer scope:
fn insert_fun_typing(ctxt: &mut Context, x: &Ident, typs: &Typs, typ: &Typ) {
let fun = DeclaredX::Fun(typs.clone(), typ.clone());
let fun = DeclaredX::Fun { params: typs.clone(), ret: typ.clone(), field_accessor: false };

// the maps that aren't ctxt.typing.decls (e.g. apply_map) are still in the outer scope,
// so use one of them as the outer scope index:
Expand Down Expand Up @@ -521,7 +521,7 @@ fn simplify_expr(ctxt: &mut Context, state: &mut State, expr: &Expr) -> (Typ, Ex
}
ExprX::Apply(x, args) => {
let typ = match ctxt.typing.get(x) {
Some(DeclaredX::Fun(_, typ)) => typ.clone(),
Some(DeclaredX::Fun { params: _, ret, field_accessor: _ }) => ret.clone(),
_ => panic!("internal error: missing function {}", x),
};
let (es, ts) = simplify_exprs(ctxt, state, &**args);
Expand Down Expand Up @@ -584,9 +584,10 @@ fn simplify_expr(ctxt: &mut Context, state: &mut State, expr: &Expr) -> (Typ, Ex
(TypX::BitVec(n1), TypX::BitVec(n2)) => Arc::new(TypX::BitVec(n1 + n2)),
_ => panic!("internal error during processing concat"),
},
BinaryOp::FieldUpdate(_) => ts[0].0.clone(),
};
let (es, t) = enclose(state, App::Binary(*op), es, ts);
(typ, Arc::new(ExprX::Binary(*op, es[0].clone(), es[1].clone())), t)
let (es, t) = enclose(state, App::Binary(op.clone()), es, ts);
(typ, Arc::new(ExprX::Binary(op.clone(), es[0].clone(), es[1].clone())), t)
}
ExprX::Multi(op, es) => {
let typ = match op {
Expand Down
16 changes: 16 additions & 0 deletions source/air/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ fn relation_binary_op(n1: &Node, n2: &Node) -> Option<BinaryOp> {
}
}

fn field_update_binary_op(n1: &Node, n2: &Node) -> Option<BinaryOp> {
match (n1, n2) {
(Node::Atom(s1), Node::Atom(s2)) if s1.as_str() == "update-field" => {
Some(BinaryOp::FieldUpdate(Arc::new(s2.clone())))
}
_ => None,
}
}

fn map_nodes_to_vec<A, F>(nodes: &[Node], f: &F) -> Result<Arc<Vec<A>>, String>
where
F: Fn(&Node) -> Result<A, String>,
Expand Down Expand Up @@ -301,6 +310,13 @@ impl Parser {
{
relation_binary_op(&nodes[1], &nodes[2])
}
Node::List(nodes)
if nodes.len() == 3
&& nodes[0] == Node::Atom("_".to_string())
&& field_update_binary_op(&nodes[1], &nodes[2]).is_some() =>
{
field_update_binary_op(&nodes[1], &nodes[2])
}
_ => None,
};
let lop = match &nodes[0] {
Expand Down
63 changes: 34 additions & 29 deletions source/air/src/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,37 +194,42 @@ impl Printer {
}
ExprX::Binary(op, lhs, rhs) => {
let sop = match op {
BinaryOp::Implies => "=>",
BinaryOp::Eq => "=",
BinaryOp::Le => "<=",
BinaryOp::Ge => ">=",
BinaryOp::Lt => "<",
BinaryOp::Gt => ">",
BinaryOp::EuclideanDiv => "div",
BinaryOp::EuclideanMod => "mod",
BinaryOp::Implies => str_to_node("=>"),
BinaryOp::Eq => str_to_node("="),
BinaryOp::Le => str_to_node("<="),
BinaryOp::Ge => str_to_node(">="),
BinaryOp::Lt => str_to_node("<"),
BinaryOp::Gt => str_to_node(">"),
BinaryOp::EuclideanDiv => str_to_node("div"),
BinaryOp::EuclideanMod => str_to_node("mod"),
BinaryOp::Relation(..) => unreachable!(),
BinaryOp::BitXor => "bvxor",
BinaryOp::BitAnd => "bvand",
BinaryOp::BitOr => "bvor",
BinaryOp::BitAdd => "bvadd",
BinaryOp::BitSub => "bvsub",
BinaryOp::BitMul => "bvmul",
BinaryOp::BitUDiv => "bvudiv",
BinaryOp::BitUMod => "bvurem",
BinaryOp::BitULt => "bvult",
BinaryOp::BitUGt => "bvugt",
BinaryOp::BitULe => "bvule",
BinaryOp::BitUGe => "bvuge",
BinaryOp::BitSLt => "bvslt",
BinaryOp::BitSGt => "bvsgt",
BinaryOp::BitSLe => "bvsle",
BinaryOp::BitSGe => "bvsge",
BinaryOp::LShr => "bvlshr",
BinaryOp::AShr => "bvashr",
BinaryOp::Shl => "bvshl",
BinaryOp::BitConcat => "concat",
BinaryOp::BitXor => str_to_node("bvxor"),
BinaryOp::BitAnd => str_to_node("bvand"),
BinaryOp::BitOr => str_to_node("bvor"),
BinaryOp::BitAdd => str_to_node("bvadd"),
BinaryOp::BitSub => str_to_node("bvsub"),
BinaryOp::BitMul => str_to_node("bvmul"),
BinaryOp::BitUDiv => str_to_node("bvudiv"),
BinaryOp::BitUMod => str_to_node("bvurem"),
BinaryOp::BitULt => str_to_node("bvult"),
BinaryOp::BitUGt => str_to_node("bvugt"),
BinaryOp::BitULe => str_to_node("bvule"),
BinaryOp::BitUGe => str_to_node("bvuge"),
BinaryOp::BitSLt => str_to_node("bvslt"),
BinaryOp::BitSGt => str_to_node("bvsgt"),
BinaryOp::BitSLe => str_to_node("bvsle"),
BinaryOp::BitSGe => str_to_node("bvsge"),
BinaryOp::LShr => str_to_node("bvlshr"),
BinaryOp::AShr => str_to_node("bvashr"),
BinaryOp::Shl => str_to_node("bvshl"),
BinaryOp::BitConcat => str_to_node("concat"),
BinaryOp::FieldUpdate(field_ident) => Node::List(vec![
str_to_node("_"),
str_to_node("update-field"),
str_to_node(&**field_ident),
]),
};
Node::List(vec![str_to_node(sop), self.expr_to_node(lhs), self.expr_to_node(rhs)])
Node::List(vec![sop, self.expr_to_node(lhs), self.expr_to_node(rhs)])
}
ExprX::Multi(op, exprs) => {
let sop = match op {
Expand Down
2 changes: 1 addition & 1 deletion source/air/src/smt_verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn label_asserts<'ctx>(
// asserts are on rhs of =>
// (slight hack to also allow rhs of == for quantified function definitions)
Arc::new(ExprX::Binary(
*op,
op.clone(),
lhs.clone(),
label_asserts(context, infos, axiom_infos, rhs),
))
Expand Down
147 changes: 147 additions & 0 deletions source/air/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2102,3 +2102,150 @@ fn no_partial_order() {
)
)
}

#[test]
fn datatype_field_update_pass() {
yes!(
(declare-datatypes ((A 0)) (((A_A (A_A_u Int)))))
(check-valid
(declare-var a A)
(block
(assign a ((_ update-field A_A_u) a 3))
(assert (= (A_A_u a) 3))
)
)
)
}

#[test]
fn datatype_field_update_ill_typed() {
untyped!(
(declare-datatypes ((X 0)) (((X_X (X_X_u Int)))))
(declare-datatypes ((A 0)) (((A_A (A_A_u Int)))))
(check-valid
(declare-var a A)
(declare-const x X)
(block
(assign a ((_ update-field A_A_u) a x))
(assert (= (A_A_u a) 3))
)
)
)
}

#[test]
fn datatype_field_update2() {
no!(
(declare-datatypes ((A 0)) (((A_A (A_A_u Int)))))
(check-valid
(declare-var a A)
(block
(assign a ((_ update-field A_A_u) a 3))
(assert (= (A_A_u a) 4))
)
)
)
}

#[test]
fn datatype_field_update3() {
yes!(
(declare-datatypes ((A 0)) (((A_A (A_A_u Int) (A_A_v Int)))))
(check-valid
(declare-var a A)
(block
(assign a ((_ update-field A_A_u) a 3))
(assert (= (A_A_u a) 3))
)
)
)
}

#[test]
fn datatype_field_update4() {
no!(
(declare-datatypes ((A 0)) (((A_A (A_A_u Int) (A_A_v Int)))))
(check-valid
(declare-var a A)
(block
(assign a ((_ update-field A_A_u) a 3))
(assert (= (A_A_u a) 4))
)
)
)
}

#[test]
fn nested_datatype_field_update_pass() {
yes!(
(declare-datatypes ((A 0)) (((A_A (A_A_u Int)))))
(declare-datatypes ((B 0)) (((B_B (B_B_a A)))))
(check-valid
(declare-var b B)
(block
(assign b ((_ update-field B_B_a) b ((_ update-field A_A_u) (B_B_a b) 3)))
(assert (= (A_A_u (B_B_a b)) 3))
)
)
)
}

#[test]
fn nested_datatype_field_update_pass2() {
yes!(
(declare-datatypes ((A 0)) (((A_A (A_A_u Int) (A_A_v Int)))))
(declare-datatypes ((B 0)) (((B_B (B_B_a1 A) (B_B_a2 A)))))
(check-valid
(declare-var b B)
(block
(assign b ((_ update-field B_B_a1) b ((_ update-field A_A_u) (B_B_a1 b) 3)))
(assert (= (A_A_u (B_B_a1 b)) 3))
)
)
)
}

#[test]
fn nested_datatype_field_update_fail() {
no!(
(declare-datatypes ((A 0)) (((A_A (A_A_u Int)))))
(declare-datatypes ((B 0)) (((B_B (B_B_a A)))))
(check-valid
(declare-var b B)
(block
(assign b ((_ update-field B_B_a) b ((_ update-field A_A_u) (B_B_a b) 3)))
(assert (= (A_A_u (B_B_a b)) 4))
)
)
)
}

#[test]
fn accessor_identifying_1() {
untyped!(
(declare-datatypes ((A 0)) (((A_A (A_A_u Int)))))
(declare-fun f (A) Int )
(check-valid
(declare-var a A)
(block
(assign a ((_ update-field f) a 3))
(assert (= (A_A_u a) 4))
)
)
)
}

#[test]
fn accessor_identifying_2() {
untyped!(
(declare-datatypes ((A 0)) (((A_A (A_A_u Int)))))
(declare-datatypes ((B 0)) (((B_B (B_B_u Int)))))
(check-valid
(declare-var a A)
(block
(assign a ((_ update-field B_B_u) a 3))
(assert (= (A_A_u a) 4))
)
)
)
}
Loading