diff --git a/bindings/python/ast.cpp b/bindings/python/ast.cpp index ff0850f33..a4f3f3b18 100644 --- a/bindings/python/ast.cpp +++ b/bindings/python/ast.cpp @@ -16,6 +16,23 @@ namespace py = pybind11; using namespace kllvm; using namespace kllvm::parser; +namespace PYBIND11_NAMESPACE { +namespace detail { +template +struct type_caster> + : variant_caster> { }; + +// Specifies the function used to visit the variant -- `apply_visitor` instead of `visit` +template <> +struct visit_helper { + template + static auto call(Args &&...args) -> decltype(std::visit(args...)) { + return std::visit(args...); + } +}; +} // namespace detail +} // namespace PYBIND11_NAMESPACE + // Metaprogramming support for the adapter function between AST print() methods // and Python's __repr__. namespace detail { @@ -197,28 +214,25 @@ void bind_ast(py::module_ &m) { .def_property_readonly("attributes", &KOREDefinition::getAttributes); /* Data Types */ - - py::enum_(ast, "SortCategory") - .value("Uncomputed", SortCategory::Uncomputed) - .value("Map", SortCategory::Map) - .value("RangeMap", SortCategory::RangeMap) - .value("List", SortCategory::List) - .value("Set", SortCategory::Set) - .value("Int", SortCategory::Int) - .value("Float", SortCategory::Float) - .value("StringBuffer", SortCategory::StringBuffer) - .value("Bool", SortCategory::Bool) - .value("Symbol", SortCategory::Symbol) - .value("Variable", SortCategory::Variable) - .value("MInt", SortCategory::MInt); - - py::class_(ast, "ValueType") - .def(py::init([](SortCategory cat) { - return ValueType{cat, 0}; - })) - .def(py::init([](SortCategory cat, uint64_t bits) { - return ValueType{cat, bits}; - })); + py::class_(ast, "ValueType").def(py::init<>()); + py::class_(ast, "UncomputedVT").def(py::init<>()); + py::class_(ast, "MapVT").def(py::init<>()); + py::class_(ast, "RangeMapVT").def(py::init<>()); + py::class_(ast, "ListVT").def(py::init<>()); + py::class_(ast, "SetVT").def(py::init<>()); + py::class_(ast, "IntVT").def(py::init<>()); + py::class_(ast, "FloatVT").def(py::init<>()); + py::class_(ast, "StringBufferVT").def(py::init<>()); + py::class_(ast, "BoolVT").def(py::init<>()); + py::class_(ast, "SymbolVT") + .def(py::init()) + .def_property_readonly( + "is_bytes", [](ValueTypes::Symbol &cat) { return cat.isBytes; }); + py::class_(ast, "VariableVT").def(py::init<>()); + py::class_(ast, "MIntVT") + .def(py::init()) + .def_property_readonly( + "bits", [](ValueTypes::MInt &cat) { return cat.bits; }); /* Sorts */ @@ -242,7 +256,7 @@ void bind_ast(py::module_ &m) { ast, "CompositeSort", sort_base) .def( py::init(&KORECompositeSort::Create), py::arg("name"), - py::arg("cat") = ValueType{SortCategory::Uncomputed, 0}) + py::arg("cat") = ValueTypes::Uncomputed()) .def_property_readonly("name", &KORECompositeSort::getName) .def("add_argument", &KORECompositeSort::addArgument) .def_property_readonly("arguments", &KORECompositeSort::getArguments); diff --git a/include/kllvm/ast/AST.h b/include/kllvm/ast/AST.h index e0eb87d6b..f9092dec7 100644 --- a/include/kllvm/ast/AST.h +++ b/include/kllvm/ast/AST.h @@ -111,34 +111,100 @@ class KORESortVariable : public KORESort { : name(Name) { } }; -enum class SortCategory { - Uncomputed, - Map, - List, - Set, - Int, - Float, - StringBuffer, - Bool, - Symbol, - Variable, - MInt, - RangeMap +// The various syntactic categories of an LLVM backend term at runtime +namespace ValueTypes { +struct Uncomputed { + bool operator<(Uncomputed) const { return false; } }; +struct Map { + bool operator<(Map) const { return false; } +}; +struct List { + bool operator<(List) const { return false; } +}; +struct Set { + bool operator<(Set) const { return false; } +}; +struct Int { + bool operator<(Int) const { return false; } +}; +struct Float { + bool operator<(Float) const { return false; } +}; +struct StringBuffer { + bool operator<(StringBuffer) const { return false; } +}; +struct Bool { + bool operator<(Bool) const { return false; } +}; +struct Symbol { + Symbol(bool isBytes) + : isBytes(isBytes) { } + bool isBytes; -// represents the syntactic category of an LLVM backend term at runtime -struct ValueType { - // fundamental category of the term - SortCategory cat; - // if this is an MInt, the number of bits in the MInt + bool operator<(Symbol other) const { return isBytes < other.isBytes; } +}; +struct Variable { + bool operator<(Variable) const { return false; } +}; +struct MInt { + MInt(uint64_t bits) + : bits(bits){}; uint64_t bits; - bool operator<(const ValueType &that) const { - return std::make_tuple(this->cat, this->bits) - < std::make_tuple(that.cat, that.bits); + bool operator<(MInt other) const { return bits < other.bits; } +}; +struct RangeMap { + bool operator<(RangeMap) const { return false; } +}; +} // namespace ValueTypes + +// Each class' index in the ValueType variant +// +// This is used to provide switching over ValueType through +// +// switch (valType.cat())) { +// case SortCategory::Uncomputed: +// ... +// } +// +// This is useful because: +// - switching directly on ValueType::index() is unreadable and won't warn for inexhaustiveness +// - std::visit (or std::holds_alternative) is cumbersome to use and has unneeded overhead +// given that almost all of the SortCategories don't hold data. +// +enum class SortCategory : size_t { + Uncomputed = 0, + Map = 1, + List = 2, + Set = 3, + Int = 4, + Float = 5, + StringBuffer = 6, + Bool = 7, + Symbol = 8, + Variable = 9, + MInt = 10, + RangeMap = 11 +}; + +template +class SortCategoryVariant : public std::variant { +public: + using std::variant::variant; + + __attribute__((always_inline)) SortCategory cat() const { + return static_cast(this->index()); } }; +// WARNING: Do not modify this type without updating SortCategory above! +using ValueType = SortCategoryVariant< + ValueTypes::Uncomputed, ValueTypes::Map, ValueTypes::List, ValueTypes::Set, + ValueTypes::Int, ValueTypes::Float, ValueTypes::StringBuffer, + ValueTypes::Bool, ValueTypes::Symbol, ValueTypes::Variable, + ValueTypes::MInt, ValueTypes::RangeMap>; + class KOREDefinition; class KORECompositeSort : public KORESort { @@ -148,8 +214,8 @@ class KORECompositeSort : public KORESort { ValueType category; public: - static sptr Create( - const std::string &Name, ValueType Cat = {SortCategory::Uncomputed, 0}) { + static sptr + Create(const std::string &Name, ValueType Cat = ValueTypes::Uncomputed()) { return sptr(new KORECompositeSort(Name, Cat)); } @@ -884,8 +950,7 @@ class KOREDefinition { using KORECompositeSortDeclarationMapType = std::map; - using KORECompositeSortMapType - = std::map>; + using KORECompositeSortMapType = std::map>; using KORESymbolDeclarationMapType = std::map; diff --git a/include/kllvm/codegen/DecisionParser.h b/include/kllvm/codegen/DecisionParser.h index 82c41b880..fc345856e 100644 --- a/include/kllvm/codegen/DecisionParser.h +++ b/include/kllvm/codegen/DecisionParser.h @@ -26,15 +26,15 @@ struct PartialStep { DecisionNode *parseYamlDecisionTreeFromString( llvm::Module *, std::string yaml, const std::map &syms, - const std::map> &hookedSorts); + const std::map> &hookedSorts); DecisionNode *parseYamlDecisionTree( llvm::Module *, std::string filename, const std::map &syms, - const std::map> &hookedSorts); + const std::map> &hookedSorts); PartialStep parseYamlSpecialDecisionTree( llvm::Module *, std::string filename, const std::map &syms, - const std::map> &hookedSorts); + const std::map> &hookedSorts); } // namespace kllvm diff --git a/lib/ast/AST.cpp b/lib/ast/AST.cpp index b78797aad..09f6aeb8e 100644 --- a/lib/ast/AST.cpp +++ b/lib/ast/AST.cpp @@ -151,7 +151,7 @@ sptr KORECompositeSort::substitute(const substitution &subst) { } ValueType KORECompositeSort::getCategory(KOREDefinition *definition) { - if (category.cat != SortCategory::Uncomputed) + if (category.cat() != SortCategory::Uncomputed) return category; std::string name = getHook(definition); if (name == "MINT.MInt") { @@ -187,36 +187,33 @@ std::string KORECompositeSort::getHook(KOREDefinition *definition) { } ValueType KORECompositeSort::getCategory(std::string name) { - SortCategory category; - uint64_t bits = 0; if (name == "MAP.Map") - category = SortCategory::Map; - else if (name == "RANGEMAP.RangeMap") - category = SortCategory::RangeMap; - else if (name == "LIST.List") - category = SortCategory::List; - else if (name == "SET.Set") - category = SortCategory::Set; - else if (name == "ARRAY.Array") - category = SortCategory::Symbol; // ARRAY is implemented in K - else if (name == "INT.Int") - category = SortCategory::Int; - else if (name == "FLOAT.Float") - category = SortCategory::Float; - else if (name == "BUFFER.StringBuffer") - category = SortCategory::StringBuffer; - else if (name == "BOOL.Bool") - category = SortCategory::Bool; - else if (name == "KVAR.KVar") - category = SortCategory::Variable; + return ValueTypes::Map(); + if (name == "RANGEMAP.RangeMap") + return ValueTypes::RangeMap(); + if (name == "LIST.List") + return ValueTypes::List(); + if (name == "SET.Set") + return ValueTypes::Set(); + if (name == "ARRAY.Array") + return ValueTypes::Symbol(false); // ARRAY is implemented in K + if (name == "INT.Int") + return ValueTypes::Int(); + if (name == "FLOAT.Float") + return ValueTypes::Float(); + if (name == "BUFFER.StringBuffer") + return ValueTypes::StringBuffer(); + if (name == "BOOL.Bool") + return ValueTypes::Bool(); + if (name == "KVAR.KVar") + return ValueTypes::Variable(); + if (name == "BYTES.Bytes") + return ValueTypes::Symbol(true); // we expect the "hook" of a MInt to be of the form "MINT.MInt N" for some // bitwidth N - else if (name.substr(0, 10) == "MINT.MInt ") { - category = SortCategory::MInt; - bits = std::stoi(name.substr(10)); - } else - category = SortCategory::Symbol; - return {category, bits}; + if (name.substr(0, 10) == "MINT.MInt ") + return ValueTypes::MInt(std::stoi(name.substr(10))); + return ValueTypes::Symbol(false); } void KORESymbol::addArgument(sptr Argument) { @@ -247,7 +244,7 @@ std::string KORESymbol::layoutString(KOREDefinition *definition) const { for (auto arg : arguments) { auto sort = dynamic_cast(arg.get()); ValueType cat = sort->getCategory(definition); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: result.push_back('1'); break; case SortCategory::RangeMap: result.push_back('b'); break; case SortCategory::List: result.push_back('2'); break; @@ -258,7 +255,9 @@ std::string KORESymbol::layoutString(KOREDefinition *definition) const { case SortCategory::Bool: result.push_back('7'); break; case SortCategory::Variable: result.push_back('8'); break; case SortCategory::MInt: - result.append("_" + std::to_string(cat.bits) + "_"); + result.append( + "_" + std::to_string(std::get_if(&cat)->bits) + + "_"); case SortCategory::Symbol: result.push_back('0'); break; case SortCategory::Uncomputed: abort(); } @@ -1859,13 +1858,13 @@ void KOREDefinition::preprocess() { for (auto &sort : symbol->getArguments()) { if (sort->isConcrete()) { hookedSorts[dynamic_cast(sort.get()) - ->getHook(this)] + ->getCategory(this)] = std::dynamic_pointer_cast(sort); } } if (symbol->getSort()->isConcrete()) { hookedSorts[dynamic_cast(symbol->getSort().get()) - ->getHook(this)] + ->getCategory(this)] = std::dynamic_pointer_cast(symbol->getSort()); } if (!symbol->isConcrete()) { diff --git a/lib/codegen/CreateStaticTerm.cpp b/lib/codegen/CreateStaticTerm.cpp index ee7f10f15..da8e3ca6e 100644 --- a/lib/codegen/CreateStaticTerm.cpp +++ b/lib/codegen/CreateStaticTerm.cpp @@ -119,7 +119,7 @@ CreateStaticTerm::operator()(KOREPattern *pattern) { if (symbolDecl->getAttributes().count("sortInjection") && dynamic_cast(symbol->getArguments()[0].get()) ->getCategory(Definition) - .cat + .cat() == SortCategory::Symbol) { std::pair val = (*this)(constructor->getArguments()[0].get()); @@ -146,7 +146,7 @@ CreateStaticTerm::operator()(KOREPattern *pattern) { llvm::Constant * CreateStaticTerm::createToken(KORECompositeSort *sort, std::string contents) { ValueType cat = sort->getCategory(Definition); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -339,6 +339,7 @@ CreateStaticTerm::createToken(KORECompositeSort *sort, std::string contents) { } case SortCategory::Uncomputed: abort(); } + abort(); } } // namespace kllvm diff --git a/lib/codegen/CreateTerm.cpp b/lib/codegen/CreateTerm.cpp index bcb82aacc..e9ea3da3d 100644 --- a/lib/codegen/CreateTerm.cpp +++ b/lib/codegen/CreateTerm.cpp @@ -145,7 +145,7 @@ std::string BLOCKHEADER_STRUCT = "blockheader"; llvm::Type *getParamType(ValueType sort, llvm::Module *Module) { llvm::Type *type = getValueType(sort, Module); - switch (sort.cat) { + switch (sort.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -160,7 +160,7 @@ llvm::StructType *getBlockType(llvm::Module *Module) { } llvm::Type *getValueType(ValueType sort, llvm::Module *Module) { - switch (sort.cat) { + switch (sort.cat()) { case SortCategory::Map: return getTypeByName(Module, MAP_STRUCT); case SortCategory::RangeMap: return getTypeByName(Module, RANGEMAP_STRUCT); case SortCategory::List: return getTypeByName(Module, LIST_STRUCT); @@ -173,7 +173,8 @@ llvm::Type *getValueType(ValueType sort, llvm::Module *Module) { return llvm::PointerType::getUnqual(getTypeByName(Module, BUFFER_STRUCT)); case SortCategory::Bool: return llvm::Type::getInt1Ty(Module->getContext()); case SortCategory::MInt: - return llvm::IntegerType::get(Module->getContext(), sort.bits); + return llvm::IntegerType::get( + Module->getContext(), std::get_if(&sort)->bits); case SortCategory::Symbol: case SortCategory::Variable: return llvm::PointerType::getUnqual(getTypeByName(Module, BLOCK_STRUCT)); @@ -423,8 +424,10 @@ llvm::Value *CreateTerm::createHook( auto sort = dynamic_cast( pattern->getConstructor()->getArguments()[0].get()); ValueType cat = sort->getCategory(Definition); + assert(cat.cat() == SortCategory::MInt); auto Type = getValueType(cat, Module); - size_t nwords = (cat.bits + 63) / 64; + uint64_t bits = std::get_if(&cat)->bits; + size_t nwords = (bits + 63) / 64; if (nwords == 0) { auto staticTerm = new CreateStaticTerm(Definition, Module); return staticTerm->createToken(sort, "0"); @@ -435,7 +438,7 @@ llvm::Value *CreateTerm::createHook( CurrentBlock, "koreAllocAlwaysGC"); if (nwords == 1) { llvm::Value *Word; - if (cat.bits == 64) { + if (bits == 64) { Word = mint; } else { Word = new llvm::ZExtInst( @@ -460,11 +463,10 @@ llvm::Value *CreateTerm::createHook( } auto result = llvm::CallInst::Create( getOrInsertFunction( - Module, "hook_MINT_import", - getValueType({SortCategory::Int, 0}, Module), + Module, "hook_MINT_import", getValueType(ValueTypes::Int(), Module), llvm::Type::getInt64PtrTy(Ctx), llvm::Type::getInt64Ty(Ctx), llvm::Type::getInt1Ty(Ctx)), - {Ptr, llvm::ConstantInt::get(llvm::Type::getInt64Ty(Ctx), cat.bits), + {Ptr, llvm::ConstantInt::get(llvm::Type::getInt64Ty(Ctx), bits), llvm::ConstantInt::getFalse(Ctx)}, "hook_MINT_uvalue", CurrentBlock); setDebugLoc(result); @@ -475,8 +477,10 @@ llvm::Value *CreateTerm::createHook( auto sort = dynamic_cast( pattern->getConstructor()->getArguments()[0].get()); ValueType cat = sort->getCategory(Definition); + assert(cat.cat() == SortCategory::MInt); auto Type = getValueType(cat, Module); - size_t nwords = (cat.bits + 63) / 64; + uint64_t bits = std::get_if(&cat)->bits; + size_t nwords = (bits + 63) / 64; if (nwords == 0) { auto staticTerm = new CreateStaticTerm(Definition, Module); return staticTerm->createToken(sort, "0"); @@ -487,7 +491,7 @@ llvm::Value *CreateTerm::createHook( CurrentBlock, "koreAllocAlwaysGC"); if (nwords == 1) { llvm::Value *Word; - if (cat.bits == 64) { + if (bits == 64) { Word = mint; } else { Word = new llvm::SExtInst( @@ -512,11 +516,10 @@ llvm::Value *CreateTerm::createHook( } auto result = llvm::CallInst::Create( getOrInsertFunction( - Module, "hook_MINT_import", - getValueType({SortCategory::Int, 0}, Module), + Module, "hook_MINT_import", getValueType(ValueTypes::Int(), Module), llvm::Type::getInt64PtrTy(Ctx), llvm::Type::getInt64Ty(Ctx), llvm::Type::getInt1Ty(Ctx)), - {Ptr, llvm::ConstantInt::get(llvm::Type::getInt64Ty(Ctx), cat.bits), + {Ptr, llvm::ConstantInt::get(llvm::Type::getInt64Ty(Ctx), bits), llvm::ConstantInt::getTrue(Ctx)}, "hook_MINT_svalue", CurrentBlock); setDebugLoc(result); @@ -527,22 +530,24 @@ llvm::Value *CreateTerm::createHook( pattern->getConstructor()->getSort().get()) ->getCategory(Definition); auto Type = getValueType(cat, Module); + uint64_t bits = std::get_if(&cat)->bits; + assert(cat.cat() == SortCategory::MInt); llvm::Instruction *Ptr = llvm::CallInst::Create( getOrInsertFunction( Module, "hook_MINT_export", llvm::Type::getInt64PtrTy(Ctx), - getValueType({SortCategory::Int, 0}, Module), + getValueType(ValueTypes::Int(), Module), llvm::Type::getInt64Ty(Ctx)), - {mpz, llvm::ConstantInt::get(llvm::Type::getInt64Ty(Ctx), cat.bits)}, - "ptr", CurrentBlock); + {mpz, llvm::ConstantInt::get(llvm::Type::getInt64Ty(Ctx), bits)}, "ptr", + CurrentBlock); setDebugLoc(Ptr); - size_t nwords = (cat.bits + 63) / 64; + size_t nwords = (bits + 63) / 64; llvm::Value *result = llvm::ConstantInt::get(Type, 0); if (nwords == 0) { return result; } else if (nwords == 1) { auto Word = new llvm::LoadInst( llvm::Type::getInt64Ty(Ctx), Ptr, "word", CurrentBlock); - if (cat.bits == 64) { + if (bits == 64) { return Word; } else { return new llvm::TruncInst( @@ -764,7 +769,7 @@ llvm::Value *CreateTerm::createFunctionCall( auto concreteSort = dynamic_cast(sort.get()); llvm::Value *arg = createAllocation(pattern->getArguments()[i++].get()).first; - switch (concreteSort->getCategory(Definition).cat) { + switch (concreteSort->getCategory(Definition).cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -791,7 +796,7 @@ llvm::Value *CreateTerm::createFunctionCall( llvm::Type *returnType = getValueType(returnCat, Module); std::vector types; bool collection = false; - switch (returnCat.cat) { + switch (returnCat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -975,7 +980,7 @@ CreateTerm::createAllocation(KOREPattern *pattern) { symbolDecl->getAttributes().count("sortInjection") && dynamic_cast(symbol->getArguments()[0].get()) ->getCategory(Definition) - .cat + .cat() == SortCategory::Symbol) { std::pair val = createAllocation(constructor->getArguments()[0].get()); @@ -983,7 +988,7 @@ CreateTerm::createAllocation(KOREPattern *pattern) { llvm::Instruction *Tag = llvm::CallInst::Create( getOrInsertFunction( Module, "getTag", llvm::Type::getInt32Ty(Ctx), - getValueType({SortCategory::Symbol, 0}, Module)), + getValueType(ValueTypes::Symbol(false), Module)), val.first, "tag", CurrentBlock); setDebugLoc(Tag); auto inj = Definition->getInjSymbol(); @@ -1087,7 +1092,7 @@ bool makeFunction( sort->print(Out); llvm::Type *paramType = getValueType(cat, Module); debugArgs.push_back(getDebugType(cat, Out.str())); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -1103,7 +1108,7 @@ bool makeFunction( } ValueType returnCat = termType(pattern, params, definition); auto returnType = getValueType(returnCat, Module); - switch (returnCat.cat) { + switch (returnCat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -1184,14 +1189,14 @@ bool makeFunction( llvm::Type::getInt8PtrTy(Module->getContext()), llvm::Type::getInt8PtrTy(Module->getContext())), {outputFile, varname}); - if (cat.cat == SortCategory::Symbol - || cat.cat == SortCategory::Variable) { + if (cat.cat() == SortCategory::Symbol + || cat.cat() == SortCategory::Variable) { ir->CreateCall( getOrInsertFunction( Module, "serializeTermToFile", llvm::Type::getVoidTy(Module->getContext()), llvm::Type::getInt8PtrTy(Module->getContext()), - getValueType({SortCategory::Symbol, 0}, Module), + getValueType(ValueTypes::Symbol(false), Module), llvm::Type::getInt8PtrTy(Module->getContext())), {outputFile, val, sortptr}); } else if (val->getType()->isIntegerTy()) { @@ -1226,7 +1231,7 @@ bool makeFunction( Module, "serializeConfigurationToFile", llvm::Type::getVoidTy(Module->getContext()), llvm::Type::getInt8PtrTy(Module->getContext()), - getValueType({SortCategory::Symbol, 0}, Module)), + getValueType(ValueTypes::Symbol(false), Module)), {outputFile, retval}); writeUInt64(outputFile, Module, 0xcccccccccccccccc, TrueBlock); @@ -1235,7 +1240,7 @@ bool makeFunction( } if (bigStep) { - llvm::Type *blockType = getValueType({SortCategory::Symbol, 0}, Module); + llvm::Type *blockType = getValueType(ValueTypes::Symbol(false), Module); llvm::Function *step = getOrInsertFunction( Module, "k_step", llvm::FunctionType::get(blockType, {blockType}, false)); @@ -1291,7 +1296,7 @@ std::string makeApplyRuleFunction( sort->print(Out); llvm::Type *paramType = getValueType(cat, Module); debugArgs.push_back(getDebugType(cat, Out.str())); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -1306,7 +1311,7 @@ std::string makeApplyRuleFunction( paramNames.push_back(entry.first); } llvm::FunctionType *funcType = llvm::FunctionType::get( - getValueType({SortCategory::Symbol, 0}, Module), paramTypes, false); + getValueType(ValueTypes::Symbol(false), Module), paramTypes, false); std::string name = "apply_rule_" + std::to_string(axiom->getOrdinal()); makeFunction( @@ -1318,7 +1323,7 @@ std::string makeApplyRuleFunction( initDebugFunction( name, name, getDebugFunctionType( - getDebugType({SortCategory::Symbol, 0}, "SortGeneratedTopCell{}"), + getDebugType(ValueTypes::Symbol(false), "SortGeneratedTopCell{}"), debugArgs), definition, applyRule); applyRule->setCallingConv(llvm::CallingConv::Tail); @@ -1343,7 +1348,7 @@ std::string makeApplyRuleFunction( auto sort = dynamic_cast(residual.pattern->getSort().get()); auto cat = sort->getCategory(definition); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -1360,7 +1365,7 @@ std::string makeApplyRuleFunction( args.push_back(arg); types.push_back(arg->getType()); } - llvm::Type *blockType = getValueType({SortCategory::Symbol, 0}, Module); + llvm::Type *blockType = getValueType(ValueTypes::Symbol(false), Module); llvm::Function *step = getOrInsertFunction( Module, "step_" + std::to_string(axiom->getOrdinal()), llvm::FunctionType::get(blockType, types, false)); @@ -1390,7 +1395,7 @@ std::string makeSideConditionFunction( } llvm::Type *getArgType(ValueType cat, llvm::Module *mod) { - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Bool: case SortCategory::MInt: case SortCategory::Map: @@ -1415,7 +1420,7 @@ llvm::Type *getArgType(ValueType cat, llvm::Module *mod) { } bool isCollectionSort(ValueType cat) { - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: diff --git a/lib/codegen/Debug.cpp b/lib/codegen/Debug.cpp index d3fba88cb..21b140a6c 100644 --- a/lib/codegen/Debug.cpp +++ b/lib/codegen/Debug.cpp @@ -169,7 +169,7 @@ llvm::DIType *getDebugType(ValueType type, std::string typeName) { if (types[typeName]) { return types[typeName]; } - switch (type.cat) { + switch (type.cat()) { case SortCategory::Map: map = getPointerDebugType(getForwardDecl(MAP_STRUCT), typeName); types[typeName] = map; @@ -204,7 +204,8 @@ llvm::DIType *getDebugType(ValueType type, std::string typeName) { return boolean; case SortCategory::MInt: mint = Dbg->createBasicType( - typeName, type.bits, llvm::dwarf::DW_ATE_unsigned); + typeName, std::get(type).bits, + llvm::dwarf::DW_ATE_unsigned); types[typeName] = mint; return mint; case SortCategory::Symbol: @@ -214,6 +215,7 @@ llvm::DIType *getDebugType(ValueType type, std::string typeName) { return symbol; case SortCategory::Uncomputed: abort(); } + abort(); } llvm::DIType *getIntDebugType(void) { diff --git a/lib/codegen/Decision.cpp b/lib/codegen/Decision.cpp index 1d35d314a..5793138d6 100644 --- a/lib/codegen/Decision.cpp +++ b/lib/codegen/Decision.cpp @@ -47,7 +47,7 @@ void Decision::operator()(DecisionNode *entry) { if (entry == FailNode::get()) { if (FailPattern) { llvm::Value *val = load(std::make_pair( - "_1", getValueType({SortCategory::Symbol, 0}, Module))); + "_1", getValueType(ValueTypes::Symbol(false), Module))); FailSubject->addIncoming( new llvm::BitCastInst( val, llvm::Type::getInt8PtrTy(Ctx), "", CurrentBlock), @@ -255,7 +255,7 @@ void SwitchNode::codegen(Decision *d) { _case.getConstructor()->getArguments()[offset].get()) ->getCategory(d->Definition); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -493,7 +493,7 @@ void IterNextNode::codegen(Decision *d) { llvm::Value *arg = d->load(std::make_pair(iterator, iteratorType)); llvm::FunctionType *funcType = llvm::FunctionType::get( - getValueType({SortCategory::Symbol, 0}, d->Module), {arg->getType()}, + getValueType(ValueTypes::Symbol(false), d->Module), {arg->getType()}, false); llvm::Function *func = getOrInsertFunction(d->Module, hookName, funcType); auto Call = llvm::CallInst::Create( @@ -559,7 +559,7 @@ llvm::Value *Decision::getTag(llvm::Value *val) { auto res = llvm::CallInst::Create( getOrInsertFunction( Module, "getTag", llvm::Type::getInt32Ty(Ctx), - getValueType({SortCategory::Symbol, 0}, Module)), + getValueType(ValueTypes::Symbol(false), Module)), val, "tag", CurrentBlock); setDebugLoc(res); return res; @@ -679,7 +679,7 @@ void makeEvalOrAnywhereFunction( std::ostringstream Out; sort->print(Out); debugArgs.push_back(getDebugType(cat, Out.str())); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -751,7 +751,7 @@ void abortWhenStuck( llvm::ConstantInt::get( llvm::Type::getInt64Ty(Ctx), ((uint64_t)symbol->getTag() << 32 | 1)), - getValueType({SortCategory::Symbol, 0}, Module)); + getValueType(ValueTypes::Symbol(false), Module)); } else { llvm::Value *BlockHeader = getBlockHeader(Module, d, symbol, BlockType); llvm::Value *Block = allocateTerm(BlockType, CurrentBlock); @@ -864,7 +864,7 @@ std::pair, llvm::BasicBlock *> stepFunctionHeader( std::vector ptrTypes; std::vector roots; for (auto type : types) { - switch (type.cat) { + switch (type.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -909,7 +909,7 @@ std::pair, llvm::BasicBlock *> stepFunctionHeader( std::vector elements; i = 0; for (auto cat : types) { - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -925,7 +925,7 @@ std::pair, llvm::BasicBlock *> stepFunctionHeader( llvm::Type::getInt64Ty(module->getContext()), i++ * 8), llvm::ConstantInt::get( llvm::Type::getInt16Ty(module->getContext()), - (int)cat.cat + cat.bits))); + static_cast>(cat.cat())))); break; case SortCategory::Bool: case SortCategory::MInt: break; @@ -973,7 +973,7 @@ std::pair, llvm::BasicBlock *> stepFunctionHeader( unsigned rootIdx = 0; std::vector results; for (auto type : types) { - switch (type.cat) { + switch (type.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -993,9 +993,9 @@ std::pair, llvm::BasicBlock *> stepFunctionHeader( void makeStepFunction( KOREDefinition *definition, llvm::Module *module, DecisionNode *dt, bool search) { - auto blockType = getValueType({SortCategory::Symbol, 0}, module); + auto blockType = getValueType(ValueTypes::Symbol(false), module); auto debugType - = getDebugType({SortCategory::Symbol, 0}, "SortGeneratedTopCell{}"); + = getDebugType(ValueTypes::Symbol(false), "SortGeneratedTopCell{}"); llvm::FunctionType *funcType; std::string name; if (search) { @@ -1044,16 +1044,16 @@ void makeStepFunction( block); } initDebugParam( - matchFunc, 0, "subject", {SortCategory::Symbol, 0}, + matchFunc, 0, "subject", ValueTypes::Symbol(false), "SortGeneratedTopCell{}"); llvm::BranchInst::Create(stuck, pre_stuck); auto result = stepFunctionHeader( - 0, module, definition, block, stuck, {val}, {{SortCategory::Symbol, 0}}); + 0, module, definition, block, stuck, {val}, {ValueTypes::Symbol(false)}); auto collectedVal = result.first[0]; collectedVal->setName("_1"); Decision codegen( definition, result.second, fail, jump, choiceBuffer, choiceDepth, module, - {SortCategory::Symbol, 0}, nullptr, nullptr, nullptr, HasSearchResults); + ValueTypes::Symbol(false), nullptr, nullptr, nullptr, HasSearchResults); codegen.store( std::make_pair(collectedVal->getName().str(), collectedVal->getType()), collectedVal); @@ -1073,7 +1073,7 @@ void makeStepFunction( void makeMatchReasonFunctionWrapper( KOREDefinition *definition, llvm::Module *module, KOREAxiomDeclaration *axiom, std::string name) { - auto blockType = getValueType({SortCategory::Symbol, 0}, module); + auto blockType = getValueType(ValueTypes::Symbol(false), module); llvm::FunctionType *funcType = llvm::FunctionType::get( llvm::Type::getVoidTy(module->getContext()), {blockType}, false); std::string wrapperName = "match_" + std::to_string(axiom->getOrdinal()); @@ -1084,7 +1084,7 @@ void makeMatchReasonFunctionWrapper( debugName = axiom->getStringAttribute("label") + "_tailcc_" + ".match"; } auto debugType - = getDebugType({SortCategory::Symbol, 0}, "SortGeneratedTopCell{}"); + = getDebugType(ValueTypes::Symbol(false), "SortGeneratedTopCell{}"); resetDebugLoc(); initDebugFunction( debugName, debugName, @@ -1104,7 +1104,7 @@ void makeMatchReasonFunctionWrapper( void makeMatchReasonFunction( KOREDefinition *definition, llvm::Module *module, KOREAxiomDeclaration *axiom, DecisionNode *dt) { - auto blockType = getValueType({SortCategory::Symbol, 0}, module); + auto blockType = getValueType(ValueTypes::Symbol(false), module); llvm::FunctionType *funcType = llvm::FunctionType::get( llvm::Type::getVoidTy(module->getContext()), {blockType}, false); std::string name = "intern_match_" + std::to_string(axiom->getOrdinal()); @@ -1114,7 +1114,7 @@ void makeMatchReasonFunction( debugName = axiom->getStringAttribute("label") + ".match"; } auto debugType - = getDebugType({SortCategory::Symbol, 0}, "SortGeneratedTopCell{}"); + = getDebugType(ValueTypes::Symbol(false), "SortGeneratedTopCell{}"); resetDebugLoc(); initDebugFunction( debugName, debugName, @@ -1152,13 +1152,13 @@ void makeMatchReasonFunction( dt, module, block, pre_stuck, fail, &choiceBuffer, &choiceDepth, &jump); initDebugParam( - matchFunc, 0, "subject", {SortCategory::Symbol, 0}, + matchFunc, 0, "subject", ValueTypes::Symbol(false), "SortGeneratedTopCell{}"); llvm::BranchInst::Create(stuck, pre_stuck); val->setName("_1"); Decision codegen( definition, block, fail, jump, choiceBuffer, choiceDepth, module, - {SortCategory::Symbol, 0}, FailSubject, FailPattern, FailSort, nullptr); + ValueTypes::Symbol(false), FailSubject, FailPattern, FailSort, nullptr); codegen.store(std::make_pair(val->getName().str(), val->getType()), val); llvm::ReturnInst::Create(module->getContext(), stuck); @@ -1200,7 +1200,7 @@ KOREPattern *makePartialTerm( void makeStepFunction( KOREAxiomDeclaration *axiom, KOREDefinition *definition, llvm::Module *module, PartialStep res) { - auto blockType = getValueType({SortCategory::Symbol, 0}, module); + auto blockType = getValueType(ValueTypes::Symbol(false), module); std::vector argTypes; std::vector debugTypes; for (auto res : res.residuals) { @@ -1210,7 +1210,7 @@ void makeStepFunction( std::ostringstream Out; argSort->print(Out); debugTypes.push_back(getDebugType(cat, Out.str())); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -1222,7 +1222,7 @@ void makeStepFunction( } } auto blockDebugType - = getDebugType({SortCategory::Symbol, 0}, "SortGeneratedTopCell{}"); + = getDebugType(ValueTypes::Symbol(false), "SortGeneratedTopCell{}"); llvm::FunctionType *funcType = llvm::FunctionType::get(blockType, argTypes, false); std::string name = "step_" + std::to_string(axiom->getOrdinal()); @@ -1274,7 +1274,7 @@ void makeStepFunction( i = 0; Decision codegen( definition, header.second, fail, jump, choiceBuffer, choiceDepth, module, - {SortCategory::Symbol, 0}, nullptr, nullptr, nullptr, nullptr); + ValueTypes::Symbol(false), nullptr, nullptr, nullptr, nullptr); for (auto val : header.first) { val->setName(res.residuals[i].occurrence.substr(0, max_name_length)); codegen.store(std::make_pair(val->getName().str(), val->getType()), val); diff --git a/lib/codegen/DecisionParser.cpp b/lib/codegen/DecisionParser.cpp index 10e948059..8be7ace41 100644 --- a/lib/codegen/DecisionParser.cpp +++ b/lib/codegen/DecisionParser.cpp @@ -26,7 +26,7 @@ class DTPreprocessor { private: std::map uniqueNodes; const std::map &syms; - const std::map> &hookedSorts; + const std::map> &hookedSorts; KORESymbol *dv; yaml_document_t *doc; llvm::Module *mod; @@ -99,7 +99,7 @@ class DTPreprocessor { DTPreprocessor( const std::map &syms, - const std::map> &hookedSorts, + const std::map> &hookedSorts, llvm::Module *mod, yaml_document_t *doc) : syms(syms) , hookedSorts(hookedSorts) @@ -159,14 +159,13 @@ class DTPreprocessor { } else { name = str(o); } - std::string hook = str(get(node, "hook")); - uses.push_back(std::make_pair( - name, getParamType(KORECompositeSort::getCategory(hook), mod))); + ValueType hook = KORECompositeSort::getCategory(str(get(node, "hook"))); + uses.push_back(std::make_pair(name, getParamType(hook, mod))); return KOREVariablePattern::Create(name, hookedSorts.at(hook)); } else if (get(node, "literal")) { auto sym = KORESymbol::Create("\\dv"); auto hook = str(get(node, "hook")); - auto sort = hookedSorts.at(hook); + auto sort = hookedSorts.at(KORECompositeSort::getCategory(hook)); auto lit = get(node, "literal"); auto val = str(lit); if (hook == "BOOL.Bool") { @@ -362,7 +361,7 @@ class DTPreprocessor { DecisionNode *parseYamlDecisionTreeFromString( llvm::Module *mod, std::string yaml, const std::map &syms, - const std::map> &hookedSorts) { + const std::map> &hookedSorts) { yaml_parser_t parser; yaml_document_t doc; yaml_parser_initialize(&parser); @@ -383,7 +382,7 @@ DecisionNode *parseYamlDecisionTreeFromString( DecisionNode *parseYamlDecisionTree( llvm::Module *mod, std::string filename, const std::map &syms, - const std::map> &hookedSorts) { + const std::map> &hookedSorts) { yaml_parser_t parser; yaml_document_t doc; yaml_parser_initialize(&parser); @@ -405,7 +404,7 @@ DecisionNode *parseYamlDecisionTree( PartialStep parseYamlSpecialDecisionTree( llvm::Module *mod, std::string filename, const std::map &syms, - const std::map> &hookedSorts) { + const std::map> &hookedSorts) { yaml_parser_t parser; yaml_document_t doc; yaml_parser_initialize(&parser); diff --git a/lib/codegen/EmitConfigParser.cpp b/lib/codegen/EmitConfigParser.cpp index 0ef1e4357..fbf83fa2c 100644 --- a/lib/codegen/EmitConfigParser.cpp +++ b/lib/codegen/EmitConfigParser.cpp @@ -320,7 +320,7 @@ static llvm::Value *getArgValue( CaseBlock); llvm::Value *arg = new llvm::LoadInst(i8_ptr_ty, addr, "", CaseBlock); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Bool: case SortCategory::MInt: { auto val_ty = getValueType(cat, mod); @@ -374,7 +374,7 @@ static std::pair getEval( llvm::Value *retval; ValueType cat = dynamic_cast(symbol->getSort().get()) ->getCategory(def); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Int: case SortCategory::Float: case SortCategory::StringBuffer: @@ -558,9 +558,9 @@ static void emitGetToken(KOREDefinition *definition, llvm::Module *module) { auto sort = KORECompositeSort::Create(name); ValueType cat = sort->getCategory(definition); - if ((cat.cat == SortCategory::Symbol - && sort->getHook(definition) != "BYTES.Bytes") - || cat.cat == SortCategory::Variable) { + if ((cat.cat() == SortCategory::Symbol + && !std::get_if(&cat)->isBytes) + || cat.cat() == SortCategory::Variable) { continue; } CurrentBlock->insertInto(func); @@ -583,7 +583,7 @@ static void emitGetToken(KOREDefinition *definition, llvm::Module *module) { auto FalseBlock = llvm::BasicBlock::Create(Ctx, ""); auto CaseBlock = llvm::BasicBlock::Create(Ctx, name, func); llvm::BranchInst::Create(CaseBlock, FalseBlock, icmp, CurrentBlock); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Map: case SortCategory::RangeMap: case SortCategory::List: @@ -746,31 +746,30 @@ makePackedVisitorStructureType(llvm::LLVMContext &Ctx, llvm::Module *module) { if (types.find(&Ctx) == types.end()) { auto elementTypes = std::vector{ {makeVisitorType( - Ctx, file, getValueType({SortCategory::Symbol, 0}, module), 1, 1), + Ctx, file, getValueType(ValueTypes::Symbol(false), module), 1, 1), makeVisitorType( Ctx, file, llvm::PointerType::getUnqual( - getValueType({SortCategory::Map, 0}, module)), + getValueType(ValueTypes::Map(), module)), 3, 0), makeVisitorType( Ctx, file, llvm::PointerType::getUnqual( - getValueType({SortCategory::List, 0}, module)), + getValueType(ValueTypes::List(), module)), 3, 0), makeVisitorType( Ctx, file, llvm::PointerType::getUnqual( - getValueType({SortCategory::Set, 0}, module)), + getValueType(ValueTypes::Set(), module)), 3, 0), makeVisitorType( - Ctx, file, getValueType({SortCategory::Int, 0}, module), 1, 0), + Ctx, file, getValueType(ValueTypes::Int(), module), 1, 0), makeVisitorType( - Ctx, file, getValueType({SortCategory::Float, 0}, module), 1, 0), + Ctx, file, getValueType(ValueTypes::Float(), module), 1, 0), makeVisitorType( - Ctx, file, getValueType({SortCategory::Bool, 0}, module), 1, 0), + Ctx, file, getValueType(ValueTypes::Bool(), module), 1, 0), makeVisitorType( - Ctx, file, getValueType({SortCategory::StringBuffer, 0}, module), - 1, 0), + Ctx, file, getValueType(ValueTypes::StringBuffer(), module), 1, 0), llvm::PointerType::getUnqual(llvm::FunctionType::get( llvm::Type::getVoidTy(Ctx), {file, llvm::Type::getInt64PtrTy(Ctx), llvm::Type::getInt64Ty(Ctx), @@ -782,7 +781,7 @@ makePackedVisitorStructureType(llvm::LLVMContext &Ctx, llvm::Module *module) { makeVisitorType( Ctx, file, llvm::PointerType::getUnqual( - getValueType({SortCategory::RangeMap, 0}, module)), + getValueType(ValueTypes::RangeMap(), module)), 3, 0)}}; auto structTy = llvm::StructType::create(Ctx, elementTypes, name); @@ -995,7 +994,7 @@ static void getVisitor( } llvm::Constant *CharPtr = llvm::ConstantExpr::getInBoundsGetElementPtr( Str->getType(), global, indices); - switch (cat.cat) { + switch (cat.cat()) { case SortCategory::Variable: case SortCategory::Symbol: llvm::CallInst::Create( @@ -1007,7 +1006,7 @@ static void getVisitor( callbacks.at(0), {func->arg_begin() + 1, Child, CharPtr, llvm::ConstantInt::get( - llvm::Type::getInt1Ty(Ctx), cat.cat == SortCategory::Variable), + llvm::Type::getInt1Ty(Ctx), cat.cat() == SortCategory::Variable), state_ptr}, "", CaseBlock); break; @@ -1054,9 +1053,9 @@ static void getVisitor( case SortCategory::MInt: { llvm::Value *mint = new llvm::LoadInst( getArgType(cat, module), ChildPtr, "mint", CaseBlock); - size_t nwords = (cat.bits + 63) / 64; - auto nbits - = llvm::ConstantInt::get(llvm::Type::getInt64Ty(Ctx), cat.bits); + uint64_t bits = std::get_if(&cat)->bits; + size_t nwords = (bits + 63) / 64; + auto nbits = llvm::ConstantInt::get(llvm::Type::getInt64Ty(Ctx), bits); auto fnType = llvm::FunctionType::get( llvm::Type::getVoidTy(Ctx), {file, llvm::Type::getInt64PtrTy(Ctx), llvm::Type::getInt64Ty(Ctx), @@ -1076,7 +1075,7 @@ static void getVisitor( CaseBlock, "koreAllocAlwaysGC"); if (nwords == 1) { llvm::Value *Word; - if (cat.bits == 64) { + if (bits == 64) { Word = mint; } else { Word = new llvm::ZExtInst( @@ -1153,10 +1152,14 @@ static llvm::Constant *getLayoutData( ValueType cat = dynamic_cast(sort.get())->getCategory(def); auto offset = llvm::ConstantExpr::getOffsetOf(BlockType, i++); + uint64_t bits + = static_cast>(cat.cat()); + if (auto *mint = std::get_if(&cat)) { + bits += mint->bits; + } elements.push_back(llvm::ConstantStruct::get( getTypeByName(module, LAYOUTITEM_STRUCT), offset, - llvm::ConstantInt::get( - llvm::Type::getInt16Ty(Ctx), (int)cat.cat + cat.bits))); + llvm::ConstantInt::get(llvm::Type::getInt16Ty(Ctx), bits))); } auto Arr = llvm::ConstantArray::get( llvm::ArrayType::get(getTypeByName(module, LAYOUTITEM_STRUCT), len),