Skip to content

Commit

Permalink
Proofs
Browse files Browse the repository at this point in the history
  • Loading branch information
LasseBlaauwbroek committed Apr 29, 2024
1 parent f4927b0 commit 8ae836d
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 89 deletions.
4 changes: 4 additions & 0 deletions pytact/data_reader.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,10 @@ cdef class OnlineDefinitionsReader:
"""
return Definition._group_by_clusters(self.definitions(full))

def node_by_id(self, nodeid: NodeId) -> Node:
"""Lookup a node inside of this reader by it's local node-id. This is a low-level function."""
return Node.init(self.graph_index.nodes.size() - 1, nodeid, &self.graph_index)

@contextmanager
def online_definitions_initialize(OnlineDefinitionsReader stack,
GlobalContextAddition_Reader init) -> Generator[OnlineDefinitionsReader, None, None]:
Expand Down
6 changes: 6 additions & 0 deletions pytact/fake_python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TacticPredictionGraph, TacticPredictionsGraph,
TacticPredictionText, TacticPredictionsText,
GlobalContextMessage, CheckAlignmentMessage, CheckAlignmentResponse)
from pytact.visualisation_webserver import wrap_visualization

async def text_prediction_loop(context : GlobalContextMessage):
tactics = [ 'idtac "is it working?"', 'idtac "yes it is working!"', 'auto' ]
Expand Down Expand Up @@ -69,8 +70,11 @@ async def graph_prediction_loop(context : GlobalContextMessage, level):
raise Exception(f"Capnp protocol error {msg}")



async def run_session(args, record_file, capnp_stream):
messages_generator = capnp_message_generator(capnp_stream, args.rpc, record_file)
if args.with_visualization:
messages_generator = await wrap_visualization(messages_generator)
if args.mode == 'text':
print('Python server running in text mode')
await text_prediction_loop(messages_generator)
Expand Down Expand Up @@ -102,6 +106,8 @@ async def server():
'replayed through "pytact-fake-coq"')
parser.add_argument('--rpc', action='store_true', default = False,
help='Communicate through Cap\'n Proto RPC.')
parser.add_argument('--with-visualization', action='store_true', default = False,
help='Launch a visualization webserver')
args = parser.parse_args()

if args.record_file is not None:
Expand Down
27 changes: 19 additions & 8 deletions pytact/graph_visualize_browse.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,14 @@ class GraphVisualizationData:
graphid2path: List[Path] = field(init=False)

def __post_init__(self):
self.trans_deps = transitive_closure({d.filename: list(d.dependencies)
for d in self.data.values()})
self.graphid2path = [d.filename for d in sorted(self.data.values(), key=lambda d: d.graph)]
if len(self.data.values()) == 0: return
if hasattr(list(self.data.values())[0], "dependencies"):
self.trans_deps = transitive_closure({d.filename: list(d.dependencies)
for d in self.data.values()})
self.graphid2path = [d.filename for d in sorted(self.data.values(), key=lambda d: d.graph)]
else:
self.trans_deps = { p : set() for p in self.data.keys()}
self.graphid2path = list(self.data.keys())

@dataclass
class GraphVisualizationOutput:
Expand Down Expand Up @@ -127,6 +132,12 @@ def render_proof_state_text(ps: ProofState):
'<br>----------------------<br>' + ps.conclusion_text +
'<br><br>Raw: ' + ps.text)

def mn(dataset):
if hasattr(dataset, "module_name"):
return dataset.module_name
else:
return ""

class GraphVisualizator:
def __init__(self, data: GraphVisualizationData, url_maker: UrlMaker, settings: Settings = Settings()):
self.data = data.data
Expand Down Expand Up @@ -191,7 +202,7 @@ def global_context(self, fname: Path):

dataset = self.data[fname]
representative = dataset.representative
module_name = dataset.module_name
module_name = mn(dataset)

def render_def(dot2, d: Definition):
label = make_label(module_name, d.name)
Expand Down Expand Up @@ -221,7 +232,7 @@ def render_def(dot2, d: Definition):
dot.edge(id, id2,
arrowtail="odot", dir="both", constraint="false", style="dashed")

for cluster in dataset.clustered_definitions():
for cluster in dataset.clustered_definitions(full=False):

start = str(cluster[0].node)
ltail = None
Expand Down Expand Up @@ -349,7 +360,7 @@ def definition(self, fname: Path, definition: int):
proof = [("Proof", self.url_maker.proof(fname, definition))]
ext_location = (
location +
[(make_label(self.data[fname].module_name, label),
[(make_label(mn(self.data[fname]), label),
self.url_maker.definition(fname, definition))] +
proof)
return GraphVisualizationOutput(dot.source, ext_location, len(location), text)
Expand Down Expand Up @@ -405,7 +416,7 @@ def proof(self, fname: Path, definition: int):
dot.edge(before_id, qedid)

location = (self.path2location(fname) +
[(make_label(self.data[fname].module_name, d.name),
[(make_label(mn(self.data[fname]), d.name),
self.url_maker.definition(fname, definition)),
("Proof", self.url_maker.proof(fname, definition))])
return GraphVisualizationOutput(dot.source, location, len(location) - 1)
Expand Down Expand Up @@ -508,7 +519,7 @@ def nlm(node: Node):
dot2.edge('artificial-root', id)

location = (self.path2location(fname) +
[(make_label(self.data[fname].module_name, d.name),
[(make_label(mn(self.data[fname]), d.name),
self.url_maker.definition(fname, definition)),
("Proof", self.url_maker.proof(fname, definition)),
(f"Step {stepi} outcome {outcomei}",
Expand Down
49 changes: 47 additions & 2 deletions pytact/visualisation_webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
except ImportError:
import importlib.resources as ilr

from pytact.data_reader import data_reader
from pytact.data_reader import data_reader, GlobalContextMessage, ProofState, CheckAlignmentMessage
from pytact.graph_visualize_browse import (
GraphVisualizationData, GraphVisualizator, UrlMaker, Settings, GraphVisualizationOutput)

Expand All @@ -36,7 +36,10 @@ def create_app(dataset_path: Path) -> Sanic:
context_manager = ExitStack()
template_path = ilr.files('pytact') / 'templates/'
app.config.TEMPLATING_PATH_TO_TEMPLATES = context_manager.enter_context(ilr.as_file(template_path))
app.ctx.gvd = GraphVisualizationData(context_manager.enter_context(data_reader(dataset_path)))
if isinstance(dataset_path, Path):
app.ctx.gvd = GraphVisualizationData(context_manager.enter_context(data_reader(dataset_path)))
else:
app.ctx.gvd = dataset_path

@app.after_server_stop
async def teardown(app):
Expand Down Expand Up @@ -110,6 +113,48 @@ async def root_folder(request, query: Settings):

return app


async def wrap_visualization(context : GlobalContextMessage) -> GlobalContextMessage:
app = create_app(GraphVisualizationData(dict()))

server = await app.create_server(
port=8000, host="0.0.0.0", return_asyncio_server=True
)

await server.startup()
await server.start_serving()

async def wrapper(context, stack):
data = { Path(f"Slice{i}.bin") : d for i, d in enumerate(stack)}
app.ctx.gvd = GraphVisualizationData(data)
prediction_requests = context.prediction_requests
async for msg in prediction_requests:
# Redirect any exceptions to Coq. Additionally, deal with CancellationError
# thrown when a request from Coq is cancelled
async with context.redirect_exceptions(Exception):
if isinstance(msg, ProofState):
resp = yield msg
yield
await prediction_requests.asend(resp)
elif isinstance(msg, CheckAlignmentMessage):
resp = yield msg
yield
await prediction_requests.asend(resp)
elif isinstance(msg, GlobalContextMessage):
yield GlobalContextMessage(msg.definitions,
msg.tactics,
msg.log_annotation,
wrapper(msg, stack + [msg.definitions]),
msg.redirect_exceptions)
else:
raise Exception(f"Capnp protocol error {msg}")

return GlobalContextMessage(context.definitions,
context.tactics,
context.log_annotation,
wrapper(context, []),
context.redirect_exceptions)

def main():

parser = argparse.ArgumentParser(
Expand Down
Loading

0 comments on commit 8ae836d

Please sign in to comment.