Skip to content

Commit

Permalink
Implement serialization/deserialization of new algorithm (#1051)
Browse files Browse the repository at this point in the history
This PR implements a new version of the proof hint, version 9, which
uses the newly created serialization and deserialization functions to
serialize terms during the proof hint generation. Not only does this new
serialization format take up less memory, it also is roughly 3x faster
to serialize.

Note: there is one known issue with this serializer/deserializer: it
incorrectly specifies the second sort parameter of the `inj` symbol when
deserializing. We choose to ignore this known issue due to the fact that
this deserializer is being consumed solely by code which transforms the
resulting terms into an order-sorted representation by removing the
injections entirely.
  • Loading branch information
Dwight Guth authored May 14, 2024
1 parent 38d8a5a commit 78cf581
Show file tree
Hide file tree
Showing 18 changed files with 671 additions and 158 deletions.
10 changes: 7 additions & 3 deletions bindings/python/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,16 @@ void bind_proof_trace(py::module_ &m) {
.def_property_readonly("trace", &llvm_rewrite_trace::get_trace)
.def_static(
"parse",
[](py::bytes const &bytes) {
proof_trace_parser parser(false, false);
[](py::bytes const &bytes, kore_header const &header) {
proof_trace_parser parser(false, false, header);
auto str = std::string(bytes);
return parser.parse_proof_trace(str);
},
py::arg("bytes"));
py::arg("bytes"), py::arg("header"));

py::class_<kore_header, std::shared_ptr<kore_header>>(
proof_trace, "kore_header")
.def(py::init(&kore_header::create), py::arg("path"));
}

PYBIND11_MODULE(_kllvm, m) {
Expand Down
27 changes: 9 additions & 18 deletions include/kllvm/binary/ProofTraceParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ class proof_trace_parser {
private:
bool verbose_;
bool expand_terms_;
[[maybe_unused]] kore_header const &header_;

// Caller needs to check that there are at least 8 bytes remaining in the
// stream before peeking
Expand Down Expand Up @@ -337,30 +338,19 @@ class proof_trace_parser {

template <typename It>
sptr<kore_pattern> parse_kore_term(It &ptr, It end, uint64_t &pattern_len) {
if (std::distance(ptr, end) < 11U) {
if (std::distance(ptr, end) < 9U) {
return nullptr;
}
It old_ptr = ptr;
if (detail::read<char>(ptr, end) != '\x7F'
|| detail::read<char>(ptr, end) != 'K'
|| detail::read<char>(ptr, end) != 'O'
|| detail::read<char>(ptr, end) != 'R'
|| detail::read<char>(ptr, end) != 'E') {
|| detail::read<char>(ptr, end) != '2') {
return nullptr;
}
auto version = detail::read_version(ptr, end);

if (!read_uint64(ptr, end, pattern_len)) {
return nullptr;
}

if (std::distance(ptr, end) < pattern_len) {
return nullptr;
}
if (pattern_len > 0 && std::distance(ptr, end) > pattern_len) {
end = std::next(ptr, pattern_len);
}

return detail::read(ptr, end, version);
auto result = detail::read_v2(ptr, end, header_);
pattern_len = ptr - old_ptr;
return result;
}

template <typename It>
Expand Down Expand Up @@ -736,7 +726,8 @@ class proof_trace_parser {
}

public:
proof_trace_parser(bool verbose, bool expand_terms);
proof_trace_parser(
bool verbose, bool expand_terms, kore_header const &header);

std::optional<llvm_rewrite_trace>
parse_proof_trace_from_file(std::string const &filename);
Expand Down
52 changes: 52 additions & 0 deletions include/kllvm/binary/deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,36 @@
#include <kllvm/binary/version.h>

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

#include <iostream>

namespace kllvm {

class kore_header {
private:
std::vector<uint8_t> arities_;
std::vector<ptr<kore_symbol>> symbols_;

public:
kore_header(FILE *in);
static std::unique_ptr<kore_header> create(std::string const &path) {
FILE *f = fopen(path.c_str(), "rb");
auto *result = new kore_header(f);
fclose(f);
return std::unique_ptr<kore_header>(result);
}

[[nodiscard]] uint8_t get_arity(uint32_t offset) const {
return arities_[offset];
};
[[nodiscard]] kore_symbol *get_symbol(uint32_t offset) const {
return symbols_[offset].get();
};
};

namespace detail {

template <typename It>
Expand Down Expand Up @@ -249,6 +272,35 @@ sptr<kore_pattern> read(It &ptr, It end, binary_version version) {
return term_stack[0];
}

template <typename It>
sptr<kore_pattern> read_v2(It &ptr, It end, kore_header const &header) {
switch (peek(ptr)) {
case 0: {
++ptr;
auto len = detail::read<uint64_t>(ptr, end);
auto str = std::string((char *)&*ptr, (char *)(&*ptr + len));
ptr += len + 1;
return kore_string_pattern::create(str);
}
case 1: {
++ptr;
auto offset = detail::read<uint32_t>(ptr, end);
auto arity = header.get_arity(offset);
// TODO: we need to check if this PR is an `inj` symbol and adjust the
// second sort parameter of the symbol to be equal to the sort of the
// current pattern.
auto symbol = header.get_symbol(offset);
auto new_pattern = kore_composite_pattern::create(symbol);
for (auto i = 0; i < arity; ++i) {
auto child = read_v2(ptr, end, header);
new_pattern->add_argument(child);
}
return new_pattern;
}
default: throw std::runtime_error("Bad term " + std::to_string(*ptr));
}
}

} // namespace detail

std::string file_contents(std::string const &fn, int max_bytes = -1);
Expand Down
9 changes: 0 additions & 9 deletions include/kllvm/codegen/ProofEvent.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,6 @@ class proof_event {
kore_composite_sort &sort, llvm::Value *output_file, llvm::Value *term,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call that will serialize `config` to the specified `outputFile` as
* binary KORE. This function does not require a sort, but the configuration
* passed must be a top-level configuration.
*/
llvm::CallInst *emit_serialize_configuration(
llvm::Value *output_file, llvm::Value *config,
llvm::BasicBlock *insert_at_end);

/*
* Emit a call that will serialize `value` to the specified `outputFile`.
*/
Expand Down
31 changes: 31 additions & 0 deletions include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,16 @@ void serialize_configurations(
void serialize_configuration(
block *subject, char const *sort, char **data_out, size_t *size_out,
bool emit_size, bool use_intern);
void serialize_configuration_v2(FILE *file, block *subject, uint32_t sort);
void serialize_configuration_to_file(
FILE *file, block *subject, bool emit_size, bool use_intern);
void serialize_configuration_to_file_v2(FILE *file, block *subject);
void write_uint64_to_file(FILE *file, uint64_t i);
void write_bool_to_file(FILE *file, bool b);
void serialize_term_to_file(
FILE *file, void *subject, char const *sort, bool use_intern,
bool k_item_inj = false);
void serialize_term_to_file_v2(FILE *file, void *subject, uint64_t, bool);
void serialize_raw_term_to_file(
FILE *file, void *subject, char const *sort, bool use_intern);
void print_variable_to_file(FILE *file, char const *varname);
Expand All @@ -360,6 +363,7 @@ bool hook_STRING_eq(SortString, SortString);
char const *get_symbol_name_for_tag(uint32_t tag);
char const *get_return_sort_for_tag(uint32_t tag);
char const **get_argument_sorts_for_tag(uint32_t tag);
uint32_t *get_argument_sorts_for_tag_v2(uint32_t tag);
char const *top_sort(void);

bool symbol_is_instantiation(uint32_t tag);
Expand All @@ -382,6 +386,19 @@ using visitor = struct {
writer *, rangemap *, char const *, char const *, char const *, void *);
};

using serialize_visitor = struct {
void (*visit_config)(writer *, block *, uint32_t, bool);
void (*visit_map)(writer *, map *, uint32_t, uint32_t, uint32_t);
void (*visit_list)(writer *, list *, uint32_t, uint32_t, uint32_t);
void (*visit_set)(writer *, set *, uint32_t, uint32_t, uint32_t);
void (*visit_int)(writer *, mpz_t, uint32_t);
void (*visit_float)(writer *, floating *, uint32_t);
void (*visit_bool)(writer *, bool, uint32_t);
void (*visit_string_buffer)(writer *, stringbuffer *, uint32_t);
void (*visit_m_int)(writer *, size_t *, size_t, uint32_t);
void (*visit_range_map)(writer *, rangemap *, uint32_t, uint32_t, uint32_t);
};

void print_map(
writer *, map *, char const *, char const *, char const *, void *);
void print_range_map(
Expand All @@ -392,6 +409,8 @@ void print_list(
writer *, list *, char const *, char const *, char const *, void *);
void visit_children(
block *subject, writer *file, visitor *printer, void *state);
void visit_children_for_serialize(
block *subject, writer *file, serialize_visitor *printer);

stringbuffer *hook_BUFFER_empty(void);
stringbuffer *hook_BUFFER_concat(stringbuffer *buf, string *s);
Expand Down Expand Up @@ -442,4 +461,16 @@ void sfprintf(writer *file, char const *fmt, Args &&...args) {
}
}

template <typename... Args>
void sfwrite(void const *ptr, size_t size, size_t nmemb, writer *file) {
if (file->file) {
fwrite(ptr, size, nmemb, file->file);
} else {
std::string output;
output.resize(size * nmemb);
memcpy(output.data(), ptr, size * nmemb);
hook_BUFFER_concat_raw(file->buffer, output.data(), output.size());
}
}

#endif // RUNTIME_HEADER_H
7 changes: 7 additions & 0 deletions lib/ast/definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,19 @@ std::string get_raw_symbol_name(sort_category cat) {

void kore_definition::insert_reserved_symbols() {
auto mod = kore_module::create("K-RAW-TERM");
// syntax KItem ::= rawTerm(KItem)
auto decl = kore_symbol_declaration::create("rawTerm", true);
// syntax KItem ::= rawKTerm(K)
auto k_decl = kore_symbol_declaration::create("rawKTerm", true);
auto kitem = kore_composite_sort::create("SortKItem");
auto k = kore_composite_sort::create("SortK");

decl->get_symbol()->add_sort(kitem);
decl->get_symbol()->add_argument(kitem);
k_decl->get_symbol()->add_sort(kitem);
k_decl->get_symbol()->add_argument(k);
mod->add_declaration(std::move(decl));
mod->add_declaration(std::move(k_decl));

for (auto const &cat : hooked_sorts_) {
switch (cat.first.cat) {
Expand Down
6 changes: 4 additions & 2 deletions lib/binary/ProofTraceParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,11 @@ void llvm_rewrite_trace::print(
}
}

proof_trace_parser::proof_trace_parser(bool verbose, bool expand_terms)
proof_trace_parser::proof_trace_parser(
bool verbose, bool expand_terms, kore_header const &header)
: verbose_(verbose)
, expand_terms_(expand_terms) { }
, expand_terms_(expand_terms)
, header_(header) { }

std::optional<llvm_rewrite_trace>
proof_trace_parser::parse_proof_trace(std::string const &data) {
Expand Down
81 changes: 81 additions & 0 deletions lib/binary/deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,85 @@ sptr<kore_pattern> deserialize_pattern(std::string const &filename) {
return deserialize_pattern(data.begin(), data.end());
}

// NOLINTNEXTLINE(*-cognitive-complexity)
kore_header::kore_header(FILE *in) {
// NOLINTNEXTLINE(misc-redundant-expression)
if (fgetc(in) != 0x7f || fgetc(in) != 'K' || fgetc(in) != 'R'
|| fgetc(in) != '2') {
throw std::runtime_error("invalid magic");
}
std::array<uint32_t, 4> num_entries{};
if (fread(num_entries.data(), sizeof(uint32_t), 4, in) != 4) {
throw std::runtime_error("invalid table header");
}
uint32_t version = num_entries[0];
uint32_t nstrings = num_entries[1];
uint32_t nsorts = num_entries[2];
uint32_t nsymbols = num_entries[3];

if (version != 1) {
throw std::runtime_error("invalid binary version");
}

std::vector<std::string> strings;
strings.reserve(nstrings);

for (uint32_t i = 0; i < nstrings; ++i) {
uint32_t len = 0;
if (fread(&len, sizeof(uint32_t), 1, in) != 1) {
throw std::runtime_error("invalid string table length");
}
std::string str;
str.resize(len);
if (fread(str.data(), 1, len, in) != len) {
throw std::runtime_error("invalid string table entry");
}
fgetc(in);
strings.push_back(str);
}

std::vector<sptr<kore_sort>> sorts;
sorts.reserve(nsorts);

for (uint32_t i = 0; i < nsorts; ++i) {
uint32_t offset = 0;
if (fread(&offset, sizeof(uint32_t), 1, in) != 1) {
throw std::runtime_error("invalid string table offset in sort table");
}
uint8_t nparams = fgetc(in);
auto sort = kore_composite_sort::create(strings[offset]);
for (uint8_t j = 0; j < nparams; j++) {
uint32_t param_offset = 0;
if (fread(&param_offset, sizeof(uint32_t), 1, in) != 1
|| param_offset >= i) {
throw std::runtime_error("invalid sort table offset in sort table");
}
sort->add_argument(sorts[param_offset]);
}
sorts.push_back(sort);
}

arities_.reserve(nsymbols);
symbols_.reserve(nsymbols);

for (uint32_t i = 0; i < nsymbols; ++i) {
uint32_t offset = 0;
if (fread(&offset, sizeof(uint32_t), 1, in) != 1) {
throw std::runtime_error("invalid string table offset in symbol table");
}
uint8_t nparams = fgetc(in);
uint8_t arity = fgetc(in);
auto symbol = kore_symbol::create(strings[offset]);
for (uint8_t j = 0; j < nparams; j++) {
uint32_t param_offset = 0;
if (fread(&param_offset, sizeof(uint32_t), 1, in) != 1) {
throw std::runtime_error("invalid sort table offset in symbol table");
}
symbol->add_formal_argument(sorts[param_offset]);
}
symbols_.push_back(std::move(symbol));
arities_.push_back(arity);
}
}

} // namespace kllvm
Loading

0 comments on commit 78cf581

Please sign in to comment.