Skip to content

Commit

Permalink
Abstract proof hints writer (#1118)
Browse files Browse the repository at this point in the history
This PR refactors the process of outputting proof hints to file so that
it is handled by an abstract `proof_trace_writer`. We then implement a
subclass of the `proof_trace_writer` class, namely
`proof_trace_file_writer`, that outputs the proof hints to a binary
file.

We do that so that we can easily add other methods of outputting proof
hints in the future, such as a shared memory / ringbuffer based writer
that will work in tandem with the shared memory proof hints parser that
was recently merged.

The PR can be easily reviewed commit by commit.
- the first 3 commits do some preparatory refactoring to normalize
naming conventions that will make the big refactoring more
straightforward.
- the 4th commit replaces the `writer *` argument of proof trace output
functions with a `FILE *` argument, since this is the only part of the
`writer` class that is used when it comes to hint generation.
- the 5th commit refactors all `FILE *` pointers related to hint
generation into pointers to a placeholder class which simply wraps said
`FILE *`. It also does some function renaming to better reflect that
these functions participate in the hint generation output process.
- the 6th commit replaces the placeholder class with the abstract
`proof_trace_writer` class and adds a concrete subclass for writing into
a file, namely `proof_trace_file_writer`.
  • Loading branch information
theo25 authored Jul 29, 2024
1 parent 8e55383 commit 84d791a
Show file tree
Hide file tree
Showing 12 changed files with 316 additions and 223 deletions.
12 changes: 6 additions & 6 deletions config/llvm_header.inc
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,11 @@ define i1 @string_equal(ptr %str1, ptr %str2, i64 %len1, i64 %len2) {
declare tailcc ptr @k_step(ptr)
declare tailcc void @step_all(ptr)
declare void @serialize_configuration_to_file_v2(ptr, ptr)
declare void @write_uint64_to_file(ptr, i64)
declare void @serialize_configuration_to_proof_writer(ptr, ptr)
declare void @write_uint64_to_proof_trace(ptr, i64)
@proof_output = external global i1
@output_file = external global ptr
@proof_writer = external global ptr
@depth = thread_local global i64 zeroinitializer
@steps = thread_local global i64 zeroinitializer
@current_interval = thread_local global i64 0
Expand Down Expand Up @@ -236,9 +236,9 @@ define ptr @take_steps(i64 %depth, ptr %subject) {
%proof_output = load i1, ptr @proof_output
br i1 %proof_output, label %if, label %merge
if:
%output_file = load ptr, ptr @output_file
call void @write_uint64_to_file(ptr %output_file, i64 18446744073709551615)
call void @serialize_configuration_to_file_v2(ptr %output_file, ptr %subject)
%proof_writer = load ptr, ptr @proof_writer
call void @write_uint64_to_proof_trace(ptr %proof_writer, i64 18446744073709551615)
call void @serialize_configuration_to_proof_writer(ptr %proof_writer, ptr %subject)
br label %merge
merge:
store i64 %depth, ptr @depth
Expand Down
36 changes: 36 additions & 0 deletions include/kllvm/binary/serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,42 @@ void serializer::emit(T val) {

void emit_kore_rich_header(std::ostream &os, kore_definition *definition);

class proof_trace_writer {
public:
virtual ~proof_trace_writer() = default;
virtual void write(void const *ptr, size_t len) = 0;

virtual void write_string(char const *str, size_t len) = 0;

// Note: This method will not write a 0 at the end of string.
// The passed string should be 0 terminated.
virtual void write_string(char const *str) = 0;

// Note: this method will write a 0 at the end of the string.
// The passed string should be 0 terminated.
void write_null_terminated_string(char const *str) {
write_string(str);
char n = 0;
write(&n, 1);
}

void write_bool(bool b) { write(&b, sizeof(bool)); }
void write_uint32(uint32_t i) { write(&i, sizeof(uint32_t)); }
void write_uint64(uint64_t i) { write(&i, sizeof(uint64_t)); }
};

class proof_trace_file_writer : public proof_trace_writer {
private:
FILE *file_;

public:
proof_trace_file_writer(FILE *file)
: file_(file) { }
void write(void const *ptr, size_t len) override;
void write_string(char const *str, size_t len) override;
void write_string(char const *str) override;
};

} // namespace kllvm

#endif
29 changes: 15 additions & 14 deletions include/kllvm/codegen/ProofEvent.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,40 +37,41 @@ class proof_event {
* proof output and continuation, then loading the output filename from its
* global.
*
* Returns a triple [proof enabled, merge, output_file]; see `proofBranch` and
* `emitGetOutputFileName`.
* Returns a triple [proof enabled, merge, proof_writer]; see `proofBranch`
* and `emitGetOutputFileName`.
*/
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
event_prelude(std::string const &label, llvm::BasicBlock *insert_at_end);

/*
* Emit a call that will serialize `term` to the specified `outputFile` as
* Emit a call that will serialize `term` to the specified `proof_writer` as
* binary KORE. This function can be called on any term, but the sort of that
* term must be known.
*/
llvm::CallInst *emit_serialize_term(
kore_composite_sort &sort, llvm::Value *output_file, llvm::Value *term,
kore_composite_sort &sort, llvm::Value *proof_writer, llvm::Value *term,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call that will serialize `value` to the specified `outputFile`.
* Emit a call that will serialize `value` to the specified `proof_writer`.
*/
llvm::CallInst *emit_write_uint64(
llvm::Value *output_file, uint64_t value,
llvm::Value *proof_writer, uint64_t value,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call that will serialize a boolean value to the specified `output_file`.
* Emit a call that will serialize a boolean value to the specified
* `proof_writer`.
*/
llvm::CallInst *emit_bool_term(
llvm::Value *output_file, llvm::Value *term,
llvm::CallInst *emit_write_bool(
llvm::Value *proof_writer, llvm::Value *term,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call that will serialize `str` to the specified `outputFile`.
* Emit a call that will serialize `str` to the specified `proof_writer`.
*/
llvm::CallInst *emit_write_string(
llvm::Value *output_file, std::string const &str,
llvm::Value *proof_writer, std::string const &str,
llvm::BasicBlock *insert_at_end);

/*
Expand All @@ -85,10 +86,10 @@ class proof_event {
llvm::BinaryOperator *emit_no_op(llvm::BasicBlock *insert_at_end);

/*
* Emit instructions to load the path of the interpreter's current output
* file; used here for binary proof trace data.
* Emit instructions to get a pointer to the interpreter's proof_trace_writer;
* the data structure that outputs proof trace data.
*/
llvm::LoadInst *emit_get_output_file_name(llvm::BasicBlock *insert_at_end);
llvm::LoadInst *emit_get_proof_trace_writer(llvm::BasicBlock *insert_at_end);

public:
[[nodiscard]] llvm::BasicBlock *hook_event_pre(
Expand Down
49 changes: 29 additions & 20 deletions include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <immer/map.hpp>
#include <immer/set.hpp>
#include <kllvm/ast/AST.h>
#include <kllvm/binary/serializer.h>
#include <runtime/collections/rangemap.h>
#include <unordered_set>

Expand Down Expand Up @@ -336,19 +337,25 @@ void serialize_configurations(
void serialize_configuration(
block *subject, char const *sort, char **data_out, size_t *size_out,
bool emit_size, bool use_intern);
void serialize_configuration_v2(FILE *file, block *subject, uint32_t sort);
void serialize_configuration_to_file(
FILE *file, block *subject, bool emit_size, bool use_intern);
void serialize_configuration_to_file_v2(FILE *file, block *subject);
void write_uint64_to_file(FILE *file, uint64_t i);
void write_bool_to_file(FILE *file, bool b);
void serialize_term_to_file(
FILE *file, void *subject, char const *sort, bool use_intern,
bool k_item_inj = false);
void serialize_term_to_file_v2(FILE *file, void *subject, uint64_t, bool);
void serialize_raw_term_to_file(
FILE *file, void *subject, char const *sort, bool use_intern);
void print_variable_to_file(FILE *file, char const *varname);

// The following functions are called by the generated code and runtime code to
// ouput the proof trace data.
void serialize_configuration_to_proof_trace(
void *proof_writer, block *subject, uint32_t sort);
void serialize_configuration_to_proof_writer(
void *proof_writer, block *subject);
void write_uint64_to_proof_trace(void *proof_writer, uint64_t i);
void write_bool_to_proof_trace(void *proof_writer, bool b);
void write_string_to_proof_trace(void *proof_writer, char const *str);
void serialize_term_to_proof_trace(
void *proof_writer, void *subject, uint64_t, bool);

// The following functions have to be generated at kompile time
// and linked with the interpreter.
Expand All @@ -368,7 +375,8 @@ bool hook_STRING_eq(SortString, SortString);
char const *get_symbol_name_for_tag(uint32_t tag);
char const *get_return_sort_for_tag(uint32_t tag);
char const **get_argument_sorts_for_tag(uint32_t tag);
uint32_t *get_argument_sorts_for_tag_v2(uint32_t tag);
uint32_t *
get_argument_sorts_for_tag_with_proof_trace_serialization(uint32_t tag);
char const *top_sort(void);

bool symbol_is_instantiation(uint32_t tag);
Expand All @@ -391,17 +399,17 @@ using visitor = struct {
writer *, rangemap *, char const *, char const *, char const *, void *);
};

using serialize_visitor = struct {
void (*visit_config)(writer *, block *, uint32_t, bool);
void (*visit_map)(writer *, map *, uint32_t, uint32_t, uint32_t);
void (*visit_list)(writer *, list *, uint32_t, uint32_t, uint32_t);
void (*visit_set)(writer *, set *, uint32_t, uint32_t, uint32_t);
void (*visit_int)(writer *, mpz_t, uint32_t);
void (*visit_float)(writer *, floating *, uint32_t);
void (*visit_bool)(writer *, bool, uint32_t);
void (*visit_string_buffer)(writer *, stringbuffer *, uint32_t);
void (*visit_m_int)(writer *, size_t *, size_t, uint32_t);
void (*visit_range_map)(writer *, rangemap *, uint32_t, uint32_t, uint32_t);
using serialize_to_proof_trace_visitor = struct {
void (*visit_config)(void *, block *, uint32_t, bool);
void (*visit_map)(void *, map *, uint32_t, uint32_t, uint32_t);
void (*visit_list)(void *, list *, uint32_t, uint32_t, uint32_t);
void (*visit_set)(void *, set *, uint32_t, uint32_t, uint32_t);
void (*visit_int)(void *, mpz_t, uint32_t);
void (*visit_float)(void *, floating *, uint32_t);
void (*visit_bool)(void *, bool, uint32_t);
void (*visit_string_buffer)(void *, stringbuffer *, uint32_t);
void (*visit_m_int)(void *, size_t *, size_t, uint32_t);
void (*visit_range_map)(void *, rangemap *, uint32_t, uint32_t, uint32_t);
};

void print_map(
Expand All @@ -414,8 +422,9 @@ void print_list(
writer *, list *, char const *, char const *, char const *, void *);
void visit_children(
block *subject, writer *file, visitor *printer, void *state);
void visit_children_for_serialize(
block *subject, writer *file, serialize_visitor *printer);
void visit_children_for_serialize_to_proof_trace(
block *subject, void *proof_writer,
serialize_to_proof_trace_visitor *printer);

stringbuffer *hook_BUFFER_empty(void);
stringbuffer *hook_BUFFER_concat(stringbuffer *buf, string *s);
Expand Down
12 changes: 12 additions & 0 deletions lib/binary/serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,16 @@ void emit_kore_rich_header(std::ostream &os, kore_definition *definition) {
}
}

void proof_trace_file_writer::write(void const *ptr, size_t len) {
fwrite(ptr, len, 1, file_);
}

void proof_trace_file_writer::write_string(char const *str, size_t len) {
fwrite(str, 1, len, file_);
}

void proof_trace_file_writer::write_string(char const *str) {
fputs(str, file_);
}

} // namespace kllvm
23 changes: 13 additions & 10 deletions lib/codegen/EmitConfigParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,8 @@ static void emit_get_token(kore_definition *definition, llvm::Module *module) {

static llvm::StructType *make_packed_visitor_structure_type(
llvm::LLVMContext &ctx, llvm::Module *module, bool is_serialize) {
std::string const name = is_serialize ? "serialize_visitor" : "visitor";
std::string const name
= is_serialize ? "serialize_to_proof_trace_visitor" : "visitor";
static auto types = std::map<llvm::LLVMContext *, llvm::StructType *>{};

auto *ptr_ty = llvm::PointerType::getUnqual(ctx);
Expand Down Expand Up @@ -1077,7 +1078,7 @@ static void get_visitor(
}

// NOLINTNEXTLINE(*-cognitive-complexity)
static void get_serialize_visitor(
static void get_serialize_to_proof_trace_visitor(
kore_definition *definition, llvm::Module *module, kore_symbol *symbol,
llvm::BasicBlock *case_block, std::vector<llvm::Value *> const &callbacks) {
get_visitor(definition, module, symbol, case_block, callbacks, false);
Expand Down Expand Up @@ -1179,11 +1180,11 @@ static void emit_visit_children(kore_definition *def, llvm::Module *mod) {
emit_traversal("visit_children", def, mod, true, false, get_visitor);
}

static void
emit_visit_children_for_serialize(kore_definition *def, llvm::Module *mod) {
static void emit_visit_children_for_serialize_to_proof_trace(
kore_definition *def, llvm::Module *mod) {
emit_traversal(
"visit_children_for_serialize", def, mod, true, true,
get_serialize_visitor);
"visit_children_for_serialize_to_proof_trace", def, mod, true, true,
get_serialize_to_proof_trace_visitor);
}

static void emit_inj_tags(kore_definition *def, llvm::Module *mod) {
Expand All @@ -1205,7 +1206,8 @@ static void emit_inj_tags(kore_definition *def, llvm::Module *mod) {
}
}

static void emit_sort_table_v2(kore_definition *def, llvm::Module *mod) {
static void emit_sort_table_for_proof_trace_serialization(
kore_definition *def, llvm::Module *mod) {
auto getter = [](kore_definition *definition, llvm::Module *module,
kore_symbol *symbol) -> llvm::Constant * {
auto &ctx = module->getContext();
Expand Down Expand Up @@ -1245,7 +1247,8 @@ static void emit_sort_table_v2(kore_definition *def, llvm::Module *mod) {
auto *debug_ty = get_pointer_debug_type(get_int_debug_type(), "int *");

emit_data_table_for_symbol(
"get_argument_sorts_for_tag_v2", entry_ty, debug_ty, def, mod, getter);
"get_argument_sorts_for_tag_with_proof_trace_serialization", entry_ty,
debug_ty, def, mod, getter);
}

static void emit_sort_table(kore_definition *def, llvm::Module *mod) {
Expand Down Expand Up @@ -1373,14 +1376,14 @@ void emit_config_parser_functions(

emit_get_symbol_name_for_tag(definition, module);
emit_visit_children(definition, module);
emit_visit_children_for_serialize(definition, module);
emit_visit_children_for_serialize_to_proof_trace(definition, module);

emit_layouts(definition, module);

emit_inj_tags(definition, module);

emit_sort_table(definition, module);
emit_sort_table_v2(definition, module);
emit_sort_table_for_proof_trace_serialization(definition, module);
emit_return_sort_table(definition, module);
emit_symbol_is_instantiation(definition, module);
}
Expand Down
Loading

0 comments on commit 84d791a

Please sign in to comment.