Skip to content

Commit

Permalink
Allow partial function evaluation to be caught by bindings (#967)
Browse files Browse the repository at this point in the history
This PR extends the backend's ability to recover from errors when
partial functions are evaluated in the context of a bindings library.
Rather than unconditionally crashing the entire host process, we throw
an exception that the bindings code can catch and translate to a C-ABI
error structure for the backend to deal with.

Previously, we implemented this behaviour for hooks
(#955); this PR
does the same for general partial function evaluation. The changes are
as follows:
* Refactor (as discussed previously) the implementation of
`finish_rewriting` from LLVM IR into C++.
* Add a new code-generation flag to enable the new error behaviour.
* Modify `llvm-kompile` to pass this flag when compiling a bindings
library.
* Read the flag to change behaviour in the appropriate places in the
runtime library.
* Add a test that the C bindings can safely try to evaluate undefined
partial functions and catch an error when the evaluation fails.

Fixes #925
  • Loading branch information
Baltoli authored Feb 8, 2024
1 parent 3b065dd commit 01b4196
Show file tree
Hide file tree
Showing 12 changed files with 2,251 additions and 82 deletions.
7 changes: 7 additions & 0 deletions bin/llvm-kompile
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@ if [[ "$compile" = "default" ]]; then
dt_dir="${positional_args[1]}"
main="${positional_args[2]}"

# Bindings libraries (C and Python) are loaded into a host process, so ask
# code generation for the behaviour that doesn't crash that process: pass
# --safe-partial for those targets only.
case "$main" in
  c | python) codegen_flags+=("--safe-partial") ;;
esac

for arg in "${clang_args[@]}"; do
case "$arg" in
-g)
Expand Down
2 changes: 2 additions & 0 deletions include/kllvm/codegen/Metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ void addKompiledDirSymbol(

// Records the mutable-bytes setting in `mod` as a boolean global named
// `enable_mutable_bytes`, initialized to `enabled` unless the global already
// has an initializer. When `debug` is set, debug info is also attached to the
// global.
void addMutableBytesFlag(llvm::Module &mod, bool enabled, bool debug);

// Records the safe-partial setting in `mod` as a boolean global named
// `safe_partial`, initialized to `enabled` unless the global already has an
// initializer. When `debug` is set, debug info is also attached to the
// global.
void addSafePartialFlag(llvm::Module &mod, bool enabled, bool debug);

} // namespace kllvm

#endif
9 changes: 0 additions & 9 deletions lib/codegen/EmitConfigParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,6 @@ static llvm::Function *getStrcmp(llvm::Module *module) {
return getOrInsertFunction(module, "strcmp", type);
}

// Looks up `puts` in `module` through getOrInsertFunction, using the
// signature `i32 (i8*)` (not variadic).
static llvm::Function *getPuts(llvm::Module *module) {
  llvm::LLVMContext &context = module->getContext();

  llvm::Type *result_ty = llvm::Type::getInt32Ty(context);
  llvm::Type *string_ty = llvm::Type::getInt8PtrTy(context);
  auto *signature = llvm::FunctionType::get(result_ty, {string_ty}, false);

  return getOrInsertFunction(module, "puts", signature);
}

static void
emitGetTagForSymbolName(KOREDefinition *definition, llvm::Module *module) {
llvm::LLVMContext &Ctx = module->getContext();
Expand Down Expand Up @@ -107,8 +100,6 @@ emitGetTagForSymbolName(KOREDefinition *definition, llvm::Module *module) {
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), tag), CurrentBlock);
CurrentBlock = FalseBlock;
}
llvm::Function *Puts = getPuts(module);
llvm::CallInst::Create(Puts, {func->arg_begin()}, "", CurrentBlock);
Phi->addIncoming(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), ERROR_TAG),
CurrentBlock);
Expand Down
42 changes: 28 additions & 14 deletions lib/codegen/Metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,57 @@

namespace kllvm {

static std::string KOMPILED_DIR = "kompiled_directory";
static std::string STRICT_BYTES = "enable_mutable_bytes";
namespace {

void addKompiledDirSymbol(
llvm::Module &mod, std::string const &dir, bool debug) {
std::string KOMPILED_DIR = "kompiled_directory";
std::string STRICT_BYTES = "enable_mutable_bytes";
std::string SAFE_PARTIAL = "safe_partial";

void addBooleanFlag(
llvm::Module &mod, std::string const &name, bool enabled, bool debug) {
auto &ctx = mod.getContext();

auto *str = llvm::ConstantDataArray::getString(ctx, dir, true);
auto *i1_ty = llvm::Type::getInt1Ty(ctx);
auto *enabled_cst = llvm::ConstantInt::getBool(ctx, enabled);

auto *global = mod.getOrInsertGlobal(KOMPILED_DIR, str->getType());
auto *global = mod.getOrInsertGlobal(name, i1_ty);
auto *global_var = llvm::cast<llvm::GlobalVariable>(global);

if (!global_var->hasInitializer()) {
global_var->setInitializer(str);
global_var->setInitializer(enabled_cst);
}

if (debug) {
initDebugGlobal(KOMPILED_DIR, getCharDebugType(), global_var);
initDebugGlobal(STRICT_BYTES, getBoolDebugType(), global_var);
}
}

void addMutableBytesFlag(llvm::Module &mod, bool enabled, bool debug) {
} // namespace

/**
 * Emit the kompiled directory path into `mod` as a string constant.
 *
 * `dir` is stored as a NUL-terminated character array in the global named by
 * KOMPILED_DIR. The initializer is only set if the global does not already
 * have one.
 *
 * @param mod   module that receives the global
 * @param dir   directory path to embed
 * @param debug when true, also register debug info for the global
 */
void addKompiledDirSymbol(
    llvm::Module &mod, std::string const &dir, bool debug) {
  auto &ctx = mod.getContext();

  auto *str = llvm::ConstantDataArray::getString(ctx, dir, true);

  auto *global = mod.getOrInsertGlobal(KOMPILED_DIR, str->getType());
  auto *global_var = llvm::cast<llvm::GlobalVariable>(global);

  if (!global_var->hasInitializer()) {
    global_var->setInitializer(str);
  }

  if (debug) {
    initDebugGlobal(KOMPILED_DIR, getCharDebugType(), global_var);
  }
}

// Emits the boolean global named by STRICT_BYTES ("enable_mutable_bytes")
// into `mod` with value `enabled`; all of the work is done by addBooleanFlag.
void addMutableBytesFlag(llvm::Module &mod, bool enabled, bool debug) {
addBooleanFlag(mod, STRICT_BYTES, enabled, debug);
}

// Emits the boolean global named by SAFE_PARTIAL ("safe_partial") into `mod`
// with value `enabled`; all of the work is done by addBooleanFlag.
void addSafePartialFlag(llvm::Module &mod, bool enabled, bool debug) {
addBooleanFlag(mod, SAFE_PARTIAL, enabled, debug);
}

} // namespace kllvm
58 changes: 2 additions & 56 deletions runtime/finish_rewriting.ll
Original file line number Diff line number Diff line change
Expand Up @@ -5,72 +5,18 @@ target triple = "@BACKEND_TARGET_TRIPLE@"
%block = type { %blockheader, [0 x i64 *] } ; 16-bit layout, 8-bit length, 32-bit tag, children
%mpz = type { i32, i32, i64* }

declare void @printStatistics(i8*, i64)
declare void @printConfiguration(i8*, %block*)
declare void @serializeConfigurationToFile(i8*, %block*)
declare void @exit(i32) #0
declare void @abort() #0
declare void @fclose(i8*)
declare i64 @__gmpz_get_ui(%mpz*)

declare i8* @getStderr()

@stderr = external global i8*

@exit_int_0 = global %mpz { i32 0, i32 0, i64* getelementptr inbounds ([0 x i64], [0 x i64]* @exit_int_0_limbs, i32 0, i32 0) }
@exit_int_0_limbs = global [0 x i64] zeroinitializer

; Fallback evaluator for getExitCode: ignores its argument and returns the
; static @exit_int_0 value. Declared `weak`, so another definition of the
; same symbol takes precedence at link time.
define weak tailcc %mpz* @"eval_LblgetExitCode{SortGeneratedTopCell{}}"(%block*) {
ret %mpz* @exit_int_0
}

@output_file = global i8* zeroinitializer
@statistics = global i1 zeroinitializer
@binary_output = global i1 zeroinitializer
@proof_output = global i1 zeroinitializer
@steps = external thread_local global i64

; Computes the exit code for the configuration %subject: evaluates the
; getExitCode function on it, converts the resulting %mpz to an unsigned
; 64-bit value with __gmpz_get_ui, and truncates that to 32 bits.
;
; finish_rewriting itself is implemented in C++
; (runtime/util/finish_rewriting.cpp) and calls this function.
define i32 @get_exit_code(%block* %subject) {
  %exit_z = call tailcc %mpz* @"eval_LblgetExitCode{SortGeneratedTopCell{}}"(%block* %subject)
  %exit_ul = call i64 @__gmpz_get_ui(%mpz* %exit_z)
  %exit_trunc = trunc i64 %exit_ul to i32
  ret i32 %exit_trunc
}

attributes #0 = { noreturn }
1 change: 1 addition & 0 deletions runtime/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_library(util STATIC
ConfigurationParser.cpp
ConfigurationPrinter.cpp
ConfigurationSerializer.cpp
finish_rewriting.cpp
match_log.cpp
search.cpp
util.cpp
Expand Down
15 changes: 12 additions & 3 deletions runtime/util/ConfigurationParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "kllvm/parser/KOREScanner.h"
#include "runtime/alloc.h"

#include <fmt/format.h>

#include <gmp.h>
#include <map>
#include <variant>
Expand All @@ -17,6 +19,7 @@ using Cache = std::map<std::string, uint32_t>;
static thread_local Cache cache;

extern "C" {

uint32_t getTagForSymbolNameInternal(char const *);

void init_float(floating *result, char const *c_str) {
Expand All @@ -32,12 +35,18 @@ uint32_t getTagForSymbolName(char const *name) {
if (lb != cache.end() && !(cache.key_comp()(s, lb->first))) {
return lb->second;
}

uint32_t const tag = getTagForSymbolNameInternal(s.c_str());

if (tag == ERROR_TAG) {
std::cerr << "No tag found for symbol " << name << ". Maybe attempted to "
<< "evaluate a symbol with no rules?\n";
abort();
auto error_message = fmt::format(
"No tag found for symbol {}. Maybe attempted to evaluate a symbol with "
"no rules?\n",
name);

throw std::runtime_error(error_message);
}

cache.insert(lb, Cache::value_type{s, tag});
return tag;
}
Expand Down
52 changes: 52 additions & 0 deletions runtime/util/finish_rewriting.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include <runtime/header.h>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <stdexcept>

extern "C" {

// Stream that finish_rewriting prints its results to; null until something
// else assigns it.
FILE *output_file = nullptr;
// When true, finish_rewriting calls printStatistics before printing the
// configuration.
bool statistics = false;
// When true, the final configuration is written with
// serializeConfigurationToFile instead of printConfiguration.
bool binary_output = false;
// When true, finish_rewriting does not print the final configuration at all.
bool proof_output = false;

// Step counter, defined elsewhere; passed to printStatistics.
// NOTE(review): the LLVM IR that this file replaces declared @steps as
// thread_local. Confirm this declaration matches the definition.
extern int64_t steps;
// Flag defined elsewhere; code generation emits a boolean global with this
// name. When true, an error in finish_rewriting throws instead of exiting.
extern bool safe_partial;

// Defined in runtime/finish_rewriting.ll: evaluates getExitCode on the
// configuration and truncates the result to 32 bits.
int32_t get_exit_code(block *);

/**
 * Report the result of rewriting and leave the rewriter; never returns
 * normally.
 *
 * If `error` is set and `safe_partial` is enabled, nothing is printed and a
 * std::runtime_error is thrown so that an embedding host can recover.
 * Otherwise the statistics (if enabled) and the final configuration `subject`
 * are written to `output_file`, and the process exits with code 113 on error
 * or with the result of get_exit_code(subject) on success.
 *
 * @param subject final configuration to print
 * @param error   true when rewriting stopped because of an error
 * @throws std::runtime_error if `error && safe_partial`, or if no output file
 *         has been set
 */
[[noreturn]] void finish_rewriting(block *subject, bool error) {
  // Exit status used when rewriting stops because of an error.
  constexpr int32_t error_exit_code = 113;

  // If we leave by throwing, close the output file and clear the global so
  // that a host which catches the exception and calls back in never sees a
  // dangling handle (which would otherwise be closed a second time).
  //
  // std::exit does not unwind the stack, so this guard does not run on the
  // normal exit path; std::exit itself flushes and closes open C streams.
  auto close_output = [](FILE *file) {
    fclose(file);
    output_file = nullptr;
  };
  [[maybe_unused]] auto closer
      = std::unique_ptr<FILE, decltype(close_output)>(
          output_file, close_output);

  if (error && safe_partial) {
    throw std::runtime_error(
        "Attempted to evaluate partial function at an undefined input");
  }

  if (!output_file) {
    throw std::runtime_error(
        "Called finish_rewriting with no output file specified");
  }

  if (statistics) {
    printStatistics(output_file, steps);
  }

  // With proof output enabled the final configuration is not printed here.
  if (!proof_output) {
    if (binary_output) {
      serializeConfigurationToFile(output_file, subject, true);
    } else {
      printConfiguration(output_file, subject);
    }
  }

  auto exit_code = error ? error_exit_code : get_exit_code(subject);
  std::exit(exit_code);
}
}
36 changes: 36 additions & 0 deletions test/c/Inputs/safe_partial.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "api.h"

#include <assert.h>
#include <stdio.h>

/*
 * Parses `pattern`, asks the bindings to simplify it at sort SortInt, and
 * asserts that the call reports an error instead of succeeding.
 */
void test_safe_eval(struct kllvm_c_api *api, char const *pattern) {
  kore_pattern *term = api->kore_pattern_parse(pattern);
  kore_sort *int_sort = api->kore_composite_sort_new("SortInt");

  kore_error *err = api->kore_error_new();

  /* Output slots for kore_simplify; this function never reads them. */
  char *result_data;
  size_t result_size;
  api->kore_simplify(err, term, int_sort, &result_data, &result_size);

  assert(
      !api->kore_error_is_success(err)
      && "Shouldn't be able to evaluate pattern");

  api->kore_pattern_free(term);
  api->kore_sort_free(int_sort);
  api->kore_error_free(err);
}

int main(int argc, char **argv) {
if (argc <= 1) {
return 1;
}

struct kllvm_c_api api = load_c_api(argv[1]);

api.kllvm_init();

test_safe_eval(&api, "Lblfoo{}(\\dv{SortInt{}}(\"1\"))");
test_safe_eval(&api, "Lblbar{}()");
}
7 changes: 7 additions & 0 deletions test/c/k-files/safe-partial.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Definition with deliberately partial functions, used to exercise
// evaluation of a function at an input it has no rule for.
module SAFE-PARTIAL
imports INT

// Two Int-valued functions; bar has no rules at all.
syntax Int ::= foo(Int) [function, klabel(foo), symbol]
| bar(Int) [function, klabel(bar), symbol]
// foo is defined only at 0, so any other argument is undefined.
rule foo(0) => 0
endmodule
Loading

0 comments on commit 01b4196

Please sign in to comment.