Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds proof support for the pythonic API #99

Merged
merged 5 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions cvc5_pythonic_api/cvc5_pythonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6202,6 +6202,42 @@ def model(self):
"""
return ModelRef(self)

def proof(self):
"""Return a proof for the last `check()`.

This function raises an exception if
a proof is not available (e.g., last `check()` does not return unsat).

>>> s = Solver()
>>> s.set('produce-proofs','true')
>>> a = Int('a')
>>> s.add(a + 2 == 0)
>>> s.check()
sat
>>> try:
... s.proof()
... except RuntimeError:
... print("failed to get proof (last `check()` must have returned unsat)")
failed to get proof (last `check()` must have returned unsat)
>>> s.add(a == 0)
>>> s.check()
unsat
>>> s.proof()
(SCOPE: Not(And(a + 2 == 0, a == 0)),
(SCOPE: Not(And(a + 2 == 0, a == 0)),
[a + 2 == 0, a == 0],
(EQ_RESOLVE: False,
(ASSUME: a == 0, [a == 0]),
(MACRO_SR_EQ_INTRO: (a == 0) == False,
[a == 0, 7, 12],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2),
[a + 2 == 0, 7, 12]))))))
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test for the exception, like here.

p = self.solver.getProof()[0]
return ProofRef(self, p)

def assertions(self):
"""Return an AST vector containing all added constraints.

Expand Down Expand Up @@ -6751,6 +6787,98 @@ def evaluate(t):
return m[t]


class ProofRef:
"""A proof tree where every proof reference corresponds to the
root step of a proof. The branches of the root step are the
premises of the step."""

def __init__(self, solver, proof):
self.proof = proof
self.solver = solver

def __del__(self):
if self.solver is not None:
self.solver = None

def __repr__(self):
return obj_to_string(self)

def getRule(self):
"""Returns the proof rule used by the root step of the proof.

>>> s = Solver()
>>> s.set('produce-proofs','true')
>>> a = Int('a')
>>> s.add(a + 2 == 0, a == 0)
>>> s.check()
unsat
>>> p = s.proof()
>>> p.getRule()
<ProofRule.SCOPE: 1>
"""
return self.proof.getRule()

def getResult(self):
"""Returns the conclusion of the root step of the proof.

>>> s = Solver()
>>> s.set('produce-proofs','true')
>>> a = Int('a')
>>> s.add(a + 2 == 0, a == 0)
>>> s.check()
unsat
>>> p = s.proof()
>>> p.getResult()
Not(And(a + 2 == 0, a == 0))
"""
return _to_expr_ref(self.proof.getResult(), Context(self.solver))

def getChildren(self):
"""Returns the premises, i.e., proofs themselvels, of the root step of
the proof.

>>> s = Solver()
>>> s.set('produce-proofs','true')
>>> a = Int('a')
>>> s.add(a + 2 == 0, a == 0)
>>> s.check()
unsat
>>> p = s.proof()
>>> p = p.getChildren()[0].getChildren()[0]
>>> p
(EQ_RESOLVE: False,
(ASSUME: a == 0, [a == 0]),
(MACRO_SR_EQ_INTRO: (a == 0) == False,
[a == 0, 7, 12],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2),
[a + 2 == 0, 7, 12]))))
"""
children = self.proof.getChildren()
return [ProofRef(self.solver, cp) for cp in children]

def getArguments(self):
"""Returns the arguments of the root step of the proof as a list of
expressions.

>>> s = Solver()
>>> s.set('produce-proofs','true')
>>> a = Int('a')
>>> s.add(a + 2 == 0, a == 0)
>>> s.check()
unsat
>>> p = s.proof()
>>> p.getArguments()
[]
>>> p = p.getChildren()[0]
>>> p.getArguments()
[a + 2 == 0, a == 0]
"""
args = self.proof.getArguments()
return [_to_expr_ref(a, Context(self.solver)) for a in args]


def simplify(a):
"""Simplify the expression `a`.

Expand Down
23 changes: 23 additions & 0 deletions cvc5_pythonic_api/cvc5_pythonic_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,27 @@ def pp_model(self, m):
break
return seq3(r, "[", "]")

def pp_proof(self, p, d):
if d > self.max_depth:
return self.pp_ellipses()
r = []
rule = str(p.getRule())[10:]
result = p.getResult()
childrenProofs = p.getChildren()
args = p.getArguments()
result_pp = self.pp_expr(result, 0, [])
r.append(
compose(to_format("{}: ".format(rule)), indent(_len(rule) + 2, result_pp))
)
if args:
r_args = []
for arg in args:
r_args.append(self.pp_expr(arg, 0, []))
r.append(seq3(r_args, "[", "]"))
for cPf in childrenProofs:
r.append(self.pp_proof(cPf, d + 1))
return seq3(r)

def pp_func_entry(self, e):
num = e.num_args()
if num > 1:
Expand Down Expand Up @@ -1360,6 +1381,8 @@ def main(self, a):
return self.pp_seq(a.assertions(), 0, [])
elif isinstance(a, cvc.ModelRef):
return self.pp_model(a)
elif isinstance(a, cvc.ProofRef):
return self.pp_proof(a, 0)
elif isinstance(a, list) or isinstance(a, tuple):
return self.pp_list(a)
else:
Expand Down
1 change: 1 addition & 0 deletions test/pgm_outputs/proof.py.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
unsat
34 changes: 34 additions & 0 deletions test/pgms/proof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from cvc5 import ProofRule
from cvc5_pythonic_api import *

def collect_initial_assumptions(proof):
# the initial assumptions are all the arguments of the initial
# SCOPE applications in the proof
proof_assumptions = []
while (proof.getRule() == ProofRule.SCOPE):
proof_assumptions += proof.getArguments()
proof = proof.getChildren()[0]
return proof_assumptions

def validate_proof_assumptions(assertions, proof_assumptions):
# checks that the assumptions in the produced proof match the
# assertions in the problem
return sum([c in assertions for c in proof_assumptions]) == len(proof_assumptions)


p1, p2, p3 = Bools('p1 p2 p3')
x, y = Ints('x y')
s = Solver()
s.set('produce-proofs','true')
assertions = [p1, p2, p3, Implies(p1, x > 0), Implies(p2, y > x), Implies(p2, y < 1), Implies(p3, y > -3)]

for a in assertions:
s.add(a)

print(s.check())

proof = s.proof()

proof_assumptions = collect_initial_assumptions(proof)

assert validate_proof_assumptions(assertions, proof_assumptions)
Loading