Skip to content

Commit

Permalink
Make ValueType a sum-type, storing in Symbol whether it is Bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
Scott-Guest committed Aug 23, 2023
1 parent bcd1512 commit 6d64c73
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 181 deletions.
60 changes: 37 additions & 23 deletions bindings/python/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,23 @@ namespace py = pybind11;
using namespace kllvm;
using namespace kllvm::parser;

namespace PYBIND11_NAMESPACE {
namespace detail {
template <typename... Ts>
struct type_caster<SortCategoryVariant<Ts...>>
: variant_caster<SortCategoryVariant<Ts...>> { };

// Specifies the function used to visit the variant -- `apply_visitor` instead of `visit`
template <>
struct visit_helper<SortCategoryVariant> {
template <typename... Args>
static auto call(Args &&...args) -> decltype(std::visit(args...)) {
return std::visit(args...);
}
};
} // namespace detail
} // namespace PYBIND11_NAMESPACE

// Metaprogramming support for the adapter function between AST print() methods
// and Python's __repr__.
namespace detail {
Expand Down Expand Up @@ -197,28 +214,25 @@ void bind_ast(py::module_ &m) {
.def_property_readonly("attributes", &KOREDefinition::getAttributes);

/* Data Types */

py::enum_<SortCategory>(ast, "SortCategory")
.value("Uncomputed", SortCategory::Uncomputed)
.value("Map", SortCategory::Map)
.value("RangeMap", SortCategory::RangeMap)
.value("List", SortCategory::List)
.value("Set", SortCategory::Set)
.value("Int", SortCategory::Int)
.value("Float", SortCategory::Float)
.value("StringBuffer", SortCategory::StringBuffer)
.value("Bool", SortCategory::Bool)
.value("Symbol", SortCategory::Symbol)
.value("Variable", SortCategory::Variable)
.value("MInt", SortCategory::MInt);

py::class_<ValueType>(ast, "ValueType")
.def(py::init([](SortCategory cat) {
return ValueType{cat, 0};
}))
.def(py::init([](SortCategory cat, uint64_t bits) {
return ValueType{cat, bits};
}));
py::class_<ValueType>(ast, "ValueType").def(py::init<>());
py::class_<ValueTypes::Uncomputed>(ast, "UncomputedVT").def(py::init<>());
py::class_<ValueTypes::Map>(ast, "MapVT").def(py::init<>());
py::class_<ValueTypes::RangeMap>(ast, "RangeMapVT").def(py::init<>());
py::class_<ValueTypes::List>(ast, "ListVT").def(py::init<>());
py::class_<ValueTypes::Set>(ast, "SetVT").def(py::init<>());
py::class_<ValueTypes::Int>(ast, "IntVT").def(py::init<>());
py::class_<ValueTypes::Float>(ast, "FloatVT").def(py::init<>());
py::class_<ValueTypes::StringBuffer>(ast, "StringBufferVT").def(py::init<>());
py::class_<ValueTypes::Bool>(ast, "BoolVT").def(py::init<>());
py::class_<ValueTypes::Symbol>(ast, "SymbolVT")
.def(py::init<bool>())
.def_property_readonly(
"is_bytes", [](ValueTypes::Symbol &cat) { return cat.isBytes; });
py::class_<ValueTypes::Variable>(ast, "VariableVT").def(py::init<>());
py::class_<ValueTypes::MInt>(ast, "MIntVT")
.def(py::init<uint64_t>())
.def_property_readonly(
"bits", [](ValueTypes::MInt &cat) { return cat.bits; });

/* Sorts */

Expand All @@ -242,7 +256,7 @@ void bind_ast(py::module_ &m) {
ast, "CompositeSort", sort_base)
.def(
py::init(&KORECompositeSort::Create), py::arg("name"),
py::arg("cat") = ValueType{SortCategory::Uncomputed, 0})
py::arg("cat") = ValueTypes::Uncomputed())
.def_property_readonly("name", &KORECompositeSort::getName)
.def("add_argument", &KORECompositeSort::addArgument)
.def_property_readonly("arguments", &KORECompositeSort::getArguments);
Expand Down
115 changes: 90 additions & 25 deletions include/kllvm/ast/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,34 +111,100 @@ class KORESortVariable : public KORESort {
: name(Name) { }
};

enum class SortCategory {
Uncomputed,
Map,
List,
Set,
Int,
Float,
StringBuffer,
Bool,
Symbol,
Variable,
MInt,
RangeMap
// The various syntactic categories of an LLVM backend term at runtime
namespace ValueTypes {
struct Uncomputed {
bool operator<(Uncomputed) const { return false; }
};
struct Map {
bool operator<(Map) const { return false; }
};
struct List {
bool operator<(List) const { return false; }
};
struct Set {
bool operator<(Set) const { return false; }
};
struct Int {
bool operator<(Int) const { return false; }
};
struct Float {
bool operator<(Float) const { return false; }
};
struct StringBuffer {
bool operator<(StringBuffer) const { return false; }
};
struct Bool {
bool operator<(Bool) const { return false; }
};
struct Symbol {
Symbol(bool isBytes)
: isBytes(isBytes) { }
bool isBytes;

// represents the syntactic category of an LLVM backend term at runtime
struct ValueType {
// fundamental category of the term
SortCategory cat;
// if this is an MInt, the number of bits in the MInt
bool operator<(Symbol other) const { return isBytes < other.isBytes; }
};
struct Variable {
bool operator<(Variable) const { return false; }
};
struct MInt {
MInt(uint64_t bits)
: bits(bits){};
uint64_t bits;

bool operator<(const ValueType &that) const {
return std::make_tuple(this->cat, this->bits)
< std::make_tuple(that.cat, that.bits);
bool operator<(MInt other) const { return bits < other.bits; }
};
struct RangeMap {
bool operator<(RangeMap) const { return false; }
};
} // namespace ValueTypes

// Each class' index in the ValueType variant
//
// This is used to provide switching over ValueType through
//
// switch (valType.cat())) {
// case SortCategory::Uncomputed:
// ...
// }
//
// This is useful because:
// - switching directly on ValueType::index() is unreadable and won't warn for inexhaustiveness
// - std::visit (or std::holds_alternative) is cumbersome to use and has unneeded overhead
// given that almost all of the SortCategories don't hold data.
//
enum class SortCategory : size_t {
Uncomputed = 0,
Map = 1,
List = 2,
Set = 3,
Int = 4,
Float = 5,
StringBuffer = 6,
Bool = 7,
Symbol = 8,
Variable = 9,
MInt = 10,
RangeMap = 11
};

template <typename... Ts>
class SortCategoryVariant : public std::variant<Ts...> {
public:
using std::variant<Ts...>::variant;

__attribute__((always_inline)) SortCategory cat() const {
return static_cast<SortCategory>(this->index());
}
};

// WARNING: Do not modify this type without updating SortCategory above!
using ValueType = SortCategoryVariant<
ValueTypes::Uncomputed, ValueTypes::Map, ValueTypes::List, ValueTypes::Set,
ValueTypes::Int, ValueTypes::Float, ValueTypes::StringBuffer,
ValueTypes::Bool, ValueTypes::Symbol, ValueTypes::Variable,
ValueTypes::MInt, ValueTypes::RangeMap>;

class KOREDefinition;

class KORECompositeSort : public KORESort {
Expand All @@ -148,8 +214,8 @@ class KORECompositeSort : public KORESort {
ValueType category;

public:
static sptr<KORECompositeSort> Create(
const std::string &Name, ValueType Cat = {SortCategory::Uncomputed, 0}) {
static sptr<KORECompositeSort>
Create(const std::string &Name, ValueType Cat = ValueTypes::Uncomputed()) {
return sptr<KORECompositeSort>(new KORECompositeSort(Name, Cat));
}

Expand Down Expand Up @@ -884,8 +950,7 @@ class KOREDefinition {

using KORECompositeSortDeclarationMapType
= std::map<std::string, KORECompositeSortDeclaration *>;
using KORECompositeSortMapType
= std::map<std::string, sptr<KORECompositeSort>>;
using KORECompositeSortMapType = std::map<ValueType, sptr<KORECompositeSort>>;

using KORESymbolDeclarationMapType
= std::map<std::string, KORESymbolDeclaration *>;
Expand Down
6 changes: 3 additions & 3 deletions include/kllvm/codegen/DecisionParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ struct PartialStep {
DecisionNode *parseYamlDecisionTreeFromString(
llvm::Module *, std::string yaml,
const std::map<std::string, KORESymbol *> &syms,
const std::map<std::string, sptr<KORECompositeSort>> &hookedSorts);
const std::map<ValueType, sptr<KORECompositeSort>> &hookedSorts);
DecisionNode *parseYamlDecisionTree(
llvm::Module *, std::string filename,
const std::map<std::string, KORESymbol *> &syms,
const std::map<std::string, sptr<KORECompositeSort>> &hookedSorts);
const std::map<ValueType, sptr<KORECompositeSort>> &hookedSorts);
PartialStep parseYamlSpecialDecisionTree(
llvm::Module *, std::string filename,
const std::map<std::string, KORESymbol *> &syms,
const std::map<std::string, sptr<KORECompositeSort>> &hookedSorts);
const std::map<ValueType, sptr<KORECompositeSort>> &hookedSorts);

} // namespace kllvm

Expand Down
63 changes: 31 additions & 32 deletions lib/ast/AST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ sptr<KORESort> KORECompositeSort::substitute(const substitution &subst) {
}

ValueType KORECompositeSort::getCategory(KOREDefinition *definition) {
if (category.cat != SortCategory::Uncomputed)
if (category.cat() != SortCategory::Uncomputed)
return category;
std::string name = getHook(definition);
if (name == "MINT.MInt") {
Expand Down Expand Up @@ -187,36 +187,33 @@ std::string KORECompositeSort::getHook(KOREDefinition *definition) {
}

ValueType KORECompositeSort::getCategory(std::string name) {
SortCategory category;
uint64_t bits = 0;
if (name == "MAP.Map")
category = SortCategory::Map;
else if (name == "RANGEMAP.RangeMap")
category = SortCategory::RangeMap;
else if (name == "LIST.List")
category = SortCategory::List;
else if (name == "SET.Set")
category = SortCategory::Set;
else if (name == "ARRAY.Array")
category = SortCategory::Symbol; // ARRAY is implemented in K
else if (name == "INT.Int")
category = SortCategory::Int;
else if (name == "FLOAT.Float")
category = SortCategory::Float;
else if (name == "BUFFER.StringBuffer")
category = SortCategory::StringBuffer;
else if (name == "BOOL.Bool")
category = SortCategory::Bool;
else if (name == "KVAR.KVar")
category = SortCategory::Variable;
return ValueTypes::Map();
if (name == "RANGEMAP.RangeMap")
return ValueTypes::RangeMap();
if (name == "LIST.List")
return ValueTypes::List();
if (name == "SET.Set")
return ValueTypes::Set();
if (name == "ARRAY.Array")
return ValueTypes::Symbol(false); // ARRAY is implemented in K
if (name == "INT.Int")
return ValueTypes::Int();
if (name == "FLOAT.Float")
return ValueTypes::Float();
if (name == "BUFFER.StringBuffer")
return ValueTypes::StringBuffer();
if (name == "BOOL.Bool")
return ValueTypes::Bool();
if (name == "KVAR.KVar")
return ValueTypes::Variable();
if (name == "BYTES.Bytes")
return ValueTypes::Symbol(true);
// we expect the "hook" of a MInt to be of the form "MINT.MInt N" for some
// bitwidth N
else if (name.substr(0, 10) == "MINT.MInt ") {
category = SortCategory::MInt;
bits = std::stoi(name.substr(10));
} else
category = SortCategory::Symbol;
return {category, bits};
if (name.substr(0, 10) == "MINT.MInt ")
return ValueTypes::MInt(std::stoi(name.substr(10)));
return ValueTypes::Symbol(false);
}

void KORESymbol::addArgument(sptr<KORESort> Argument) {
Expand Down Expand Up @@ -247,7 +244,7 @@ std::string KORESymbol::layoutString(KOREDefinition *definition) const {
for (auto arg : arguments) {
auto sort = dynamic_cast<KORECompositeSort *>(arg.get());
ValueType cat = sort->getCategory(definition);
switch (cat.cat) {
switch (cat.cat()) {
case SortCategory::Map: result.push_back('1'); break;
case SortCategory::RangeMap: result.push_back('b'); break;
case SortCategory::List: result.push_back('2'); break;
Expand All @@ -258,7 +255,9 @@ std::string KORESymbol::layoutString(KOREDefinition *definition) const {
case SortCategory::Bool: result.push_back('7'); break;
case SortCategory::Variable: result.push_back('8'); break;
case SortCategory::MInt:
result.append("_" + std::to_string(cat.bits) + "_");
result.append(
"_" + std::to_string(std::get_if<ValueTypes::MInt>(&cat)->bits)
+ "_");
case SortCategory::Symbol: result.push_back('0'); break;
case SortCategory::Uncomputed: abort();
}
Expand Down Expand Up @@ -1859,13 +1858,13 @@ void KOREDefinition::preprocess() {
for (auto &sort : symbol->getArguments()) {
if (sort->isConcrete()) {
hookedSorts[dynamic_cast<KORECompositeSort *>(sort.get())
->getHook(this)]
->getCategory(this)]
= std::dynamic_pointer_cast<KORECompositeSort>(sort);
}
}
if (symbol->getSort()->isConcrete()) {
hookedSorts[dynamic_cast<KORECompositeSort *>(symbol->getSort().get())
->getHook(this)]
->getCategory(this)]
= std::dynamic_pointer_cast<KORECompositeSort>(symbol->getSort());
}
if (!symbol->isConcrete()) {
Expand Down
5 changes: 3 additions & 2 deletions lib/codegen/CreateStaticTerm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ CreateStaticTerm::operator()(KOREPattern *pattern) {
if (symbolDecl->getAttributes().count("sortInjection")
&& dynamic_cast<KORECompositeSort *>(symbol->getArguments()[0].get())
->getCategory(Definition)
.cat
.cat()
== SortCategory::Symbol) {
std::pair<llvm::Constant *, bool> val
= (*this)(constructor->getArguments()[0].get());
Expand All @@ -146,7 +146,7 @@ CreateStaticTerm::operator()(KOREPattern *pattern) {
llvm::Constant *
CreateStaticTerm::createToken(KORECompositeSort *sort, std::string contents) {
ValueType cat = sort->getCategory(Definition);
switch (cat.cat) {
switch (cat.cat()) {
case SortCategory::Map:
case SortCategory::RangeMap:
case SortCategory::List:
Expand Down Expand Up @@ -339,6 +339,7 @@ CreateStaticTerm::createToken(KORECompositeSort *sort, std::string contents) {
}
case SortCategory::Uncomputed: abort();
}
abort();
}

} // namespace kllvm
Loading

0 comments on commit 6d64c73

Please sign in to comment.