diff --git a/include/kllvm/ast/AST.h b/include/kllvm/ast/AST.h index f3ad3d59d..5094e03b1 100644 --- a/include/kllvm/ast/AST.h +++ b/include/kllvm/ast/AST.h @@ -428,6 +428,21 @@ class KOREPattern : public std::enable_shared_from_this { */ virtual sptr desugarAssociative() = 0; + /** + * Abstracts the common pattern of checking whether a composite pattern has a + * particular top-level constructor and arity. For example: + * + * c(A, B) -> matchesShape(c, 2) == true + * c() -> matchesShape(c, 2) == false + * c() -> matchesShape(b, 0) == false + * + * For instances of KORECompositePattern, return this if the constructor and + * arity and match. For all other subclasses, and when they do not match, + * return nullptr; + */ + virtual sptr + matchesShape(std::string const &constructor, size_t arity) = 0; + friend KORECompositePattern; private: @@ -500,6 +515,11 @@ class KOREVariablePattern : public KOREPattern { return shared_from_this(); } + sptr + matchesShape(std::string const &constructor, size_t arity) override { + return nullptr; + } + virtual bool matches( substitution &subst, SubsortMap const &, SymbolMap const &, sptr subject) override; @@ -584,6 +604,9 @@ class KORECompositePattern : public KOREPattern { substitution &, SubsortMap const &, SymbolMap const &, sptr) override; + sptr + matchesShape(std::string const &constructor, size_t arity) override; + private: virtual sptr expandMacros( SubsortMap const &, SymbolMap const &, @@ -654,6 +677,11 @@ class KOREStringPattern : public KOREPattern { substitution &, SubsortMap const &, SymbolMap const &, sptr subject) override; + sptr + matchesShape(std::string const &constructor, size_t arity) override { + return nullptr; + } + private: virtual sptr expandMacros( SubsortMap const &, SymbolMap const &, diff --git a/lib/ast/AST.cpp b/lib/ast/AST.cpp index 22ad2cd4d..cbddef87d 100644 --- a/lib/ast/AST.cpp +++ b/lib/ast/AST.cpp @@ -406,6 +406,16 @@ sptr KORECompositePattern::expandAliases(KOREDefinition *def) { return ptr; } +sptr KORECompositePattern::matchesShape( + std::string const &constructor, size_t arity) { + if (getConstructor()->getName() == constructor + && getArguments().size() == arity) { + return std::dynamic_pointer_cast(shared_from_this()); + } + + return nullptr; +} + static int indent = 0; static bool atNewLine = true; @@ -1375,120 +1385,87 @@ getPatterns(KOREPattern *pat, std::vector &result) { * lhs(\implies(\equals(_, _), \equals(_(Xs), _))) = Xs */ std::vector KOREAxiomDeclaration::getLeftHandSide() const { - if (auto top = dynamic_cast(pattern.get())) { - if (top->getConstructor()->getName() == "\\rewrites" - && top->getArguments().size() == 2) { - if (auto andPattern = dynamic_cast( - top->getArguments()[0].get())) { - if (andPattern->getConstructor()->getName() == "\\and" - && andPattern->getArguments().size() == 2) { - if (auto firstChild = dynamic_cast( - andPattern->getArguments()[0].get())) { - if (firstChild->getConstructor()->getName() == "\\equals" - && firstChild->getArguments().size() == 2) { - return {andPattern->getArguments()[1].get()}; - } else if ( - firstChild->getConstructor()->getName() == "\\top" - && firstChild->getArguments().size() == 0) { - return {andPattern->getArguments()[1].get()}; - } else { - if (auto secondChild = dynamic_cast( - andPattern->getArguments()[1].get())) { - if (secondChild->getConstructor()->getName() == "\\equals" - && secondChild->getArguments().size() == 2) { - return {firstChild}; - } else if ( - secondChild->getConstructor()->getName() == "\\top" - && secondChild->getArguments().size() == 0) { - return {firstChild}; - } else { - if (firstChild->getConstructor()->getName() == "\\not" - && firstChild->getArguments().size() == 1 - && secondChild->getConstructor()->getName() == "\\and" - && secondChild->getArguments().size() == 2) { - if (auto inner = dynamic_cast( - secondChild->getArguments()[0].get())) { - if (inner->getConstructor()->getName() == "\\equals" - && inner->getArguments().size() == 2) { - return {secondChild->getArguments()[1].get()}; - } else if ( - inner->getConstructor()->getName() == "\\top" - && inner->getArguments().size() == 0) { - return {secondChild->getArguments()[1].get()}; - } - } - } - } - } - } - } + if (auto top = pattern->matchesShape("\\rewrites", 2)) { + if (auto andPattern = top->getArguments()[0]->matchesShape("\\and", 2)) { + auto &firstChild = andPattern->getArguments()[0]; + auto &secondChild = andPattern->getArguments()[1]; + + if (firstChild->matchesShape("\\equals", 2) + || firstChild->matchesShape("\\top", 0)) { + return {secondChild.get()}; + } + + if (secondChild->matchesShape("\\equals", 2) + || secondChild->matchesShape("\\top", 0)) { + return {firstChild.get()}; + } + + if (auto secondAnd = secondChild->matchesShape("\\and", 2); + secondAnd && firstChild->matchesShape("\\not", 1)) { + auto &inner = secondAnd->getArguments()[0]; + + if (inner->matchesShape("\\equals", 2) + || inner->matchesShape("\\top", 0)) { + return {secondAnd->getArguments()[1].get()}; } } - } else if ( - top->getConstructor()->getName() == "\\equals" - && top->getArguments().size() == 2) { - if (auto firstChild = dynamic_cast( - top->getArguments()[0].get())) { - std::vector result; - for (auto &sptr : firstChild->getArguments()) { - result.push_back(sptr.get()); + } + } + + if (auto top = pattern->matchesShape("\\equals", 2)) { + if (auto firstChild = std::dynamic_pointer_cast( + top->getArguments()[0])) { + auto result = std::vector{}; + + for (auto &sptr : firstChild->getArguments()) { + result.push_back(sptr.get()); + } + + return result; + } + } + + if (auto top = pattern->matchesShape("\\implies", 2)) { + auto &impliesLhs = top->getArguments()[0]; + auto &impliesRhs = top->getArguments()[1]; + + if (auto firstChild = impliesLhs->matchesShape("\\and", 2)) { + auto lhsAnd = firstChild; + + if (auto innerFirst + = firstChild->getArguments()[0]->matchesShape("\\not", 1)) { + if (auto innerSecond = std::dynamic_pointer_cast( + firstChild->getArguments()[1])) { + lhsAnd = innerSecond; } - return result; } - } else if ( - top->getConstructor()->getName() == "\\implies" - && top->getArguments().size() == 2) { - if (auto firstChild = dynamic_cast( - top->getArguments()[0].get())) { - if (firstChild->getConstructor()->getName() == "\\and" - && firstChild->getArguments().size() == 2) { - auto lhsAnd = firstChild; - if (auto innerFirst = dynamic_cast( - firstChild->getArguments()[0].get())) { - if (innerFirst->getConstructor()->getName() == "\\not" - && innerFirst->getArguments().size() == 1) { - if (auto innerSecond = dynamic_cast( - firstChild->getArguments()[1].get())) { - lhsAnd = innerSecond; - } - } - } - if (auto sideCondition = dynamic_cast( - lhsAnd->getArguments()[0].get())) { - if (sideCondition->getConstructor()->getName() == "\\equals" - && sideCondition->getArguments().size() == 2) { - std::vector result; - return getPatterns(lhsAnd->getArguments()[1].get(), result); - } else if ( - sideCondition->getConstructor()->getName() == "\\top" - && sideCondition->getArguments().size() == 0) { - std::vector result; - return getPatterns(lhsAnd->getArguments()[1].get(), result); - } - } - } else if ( - (firstChild->getConstructor()->getName() == "\\top" - && firstChild->getArguments().size() == 0) - || (firstChild->getConstructor()->getName() == "\\equals" - && firstChild->getArguments().size() == 2)) { - if (auto secondChild = dynamic_cast( - top->getArguments()[1].get())) { - if (secondChild->getConstructor()->getName() == "\\equals" - && secondChild->getArguments().size() == 2) { - if (auto lhs = dynamic_cast( - secondChild->getArguments()[0].get())) { - std::vector result; - for (auto &sptr : lhs->getArguments()) { - result.push_back(sptr.get()); - } - return result; - } - } + + auto &sideCondition = lhsAnd->getArguments()[0]; + if (sideCondition->matchesShape("\\equals", 2) + || sideCondition->matchesShape("\\top", 0)) { + + std::vector result; + return getPatterns(lhsAnd->getArguments()[1].get(), result); + } + } + + if (impliesLhs->matchesShape("\\top", 0) + || impliesLhs->matchesShape("\\equals", 2)) { + if (auto secondChild = impliesRhs->matchesShape("\\equals", 2)) { + if (auto lhs = std::dynamic_pointer_cast( + secondChild->getArguments()[0])) { + + auto result = std::vector{}; + for (auto &sptr : lhs->getArguments()) { + result.push_back(sptr.get()); } + + return result; } } } } + assert(false && "could not compute left hand side of axiom"); abort(); }