Skip to content

Commit

Permalink
Merge branch 'main' into sam/performance-mx
Browse files Browse the repository at this point in the history
  • Loading branch information
goodlyrottenapple authored Apr 4, 2024
2 parents 87aa323 + 34305b1 commit 20650ab
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 38 deletions.
55 changes: 27 additions & 28 deletions library/Booster/Pattern/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import Data.Data (Data)
import Data.Functor.Foldable
import Data.Hashable (Hashable)
import Data.Hashable qualified as Hashable
import Data.List as List (foldl1', sort)
import Data.List as List (foldl', foldl1', sort)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Text qualified as Text
Expand Down Expand Up @@ -166,6 +166,10 @@ type instance Base Term = TermF
instance Recursive Term where
project (Term _ t) = t

-- | Sort and de duplicate a list
sortAndDeduplicate :: Ord a => [a] -> [a]
sortAndDeduplicate = Set.toAscList . Set.fromList

getAttributes :: Term -> TermAttributes
getAttributes (Term a _) = a

Expand Down Expand Up @@ -584,29 +588,24 @@ pattern KMap def keyVals rest <- Term _ (KMapF def keyVals rest)
([], Nothing) -> mempty
([], Just s) -> getAttributes s
(_ : _, Nothing) -> foldl1' (<>) $ concatMap (\(k, v) -> [getAttributes k, getAttributes v]) keyVals
(_ : _, Just r) -> foldr (<>) (getAttributes r) $ concatMap (\(k, v) -> [getAttributes k, getAttributes v]) keyVals
(_ : _, Just r) ->
foldl' (<>) (getAttributes r) $ concatMap (\(k, v) -> [getAttributes k, getAttributes v]) keyVals
(keyVals', rest') = case rest of
Just (KMap def' kvs r) | def' == def -> (kvs, r)
r -> ([], r)
newKeyVals = sortAndDeduplicate $ keyVals ++ keyVals'
newRest = rest'
in Term
argAttributes
{ isEvaluated =
-- Constructors and injections are evaluated if their arguments are.
-- Function calls are not evaluated.
argAttributes.isEvaluated
, hash =
{ hash =
Hashable.hash
( "KMap" :: ByteString
, def
, map (\(k, v) -> (hash $ getAttributes k, hash $ getAttributes v)) keyVals
, hash . getAttributes <$> rest
, map (\(k, v) -> (hash $ getAttributes k, hash $ getAttributes v)) newKeyVals
, hash . getAttributes <$> newRest
)
, isConstructorLike =
argAttributes.isConstructorLike
, canBeEvaluated =
argAttributes.canBeEvaluated
}
$ KMapF def (Set.toList $ Set.fromList $ keyVals ++ keyVals') rest'
$ KMapF def newKeyVals newRest

pattern KList :: KListDefinition -> [Term] -> Maybe (Term, [Term]) -> Term
pattern KList def heads rest <- Term _ (KListF def heads rest)
Expand All @@ -619,7 +618,7 @@ pattern KList def heads rest <- Term _ (KListF def heads rest)
(nonEmpty, Nothing) ->
foldl1' (<>) $ map getAttributes nonEmpty
(_, Just (m, tails)) ->
foldr ((<>) . getAttributes) (getAttributes m) $ heads <> tails
foldl' (<>) (getAttributes m) . map getAttributes $ heads <> tails
(newHeads, newRest) = case rest of
Just (KList def' heads' rest', tails)
| def' /= def ->
Expand All @@ -636,9 +635,9 @@ pattern KList def heads rest <- Term _ (KListF def heads rest)
Hashable.hash
( "KList" :: ByteString
, def
, map (hash . getAttributes) heads
, fmap (hash . getAttributes . fst) rest
, fmap (map (hash . getAttributes) . snd) rest
, map (hash . getAttributes) newHeads
, fmap (hash . getAttributes . fst) newRest
, fmap (map (hash . getAttributes) . snd) newRest
)
}
$ KListF def newHeads newRest
Expand All @@ -654,22 +653,22 @@ pattern KSet def elements rest <- Term _ (KSetF def elements rest)
| Nothing <- rest =
foldl1' (<>) $ map getAttributes elements
| Just r <- rest =
foldr ((<>) . getAttributes) (getAttributes r) elements
(newElements, newRest) = case rest of
Just (KSet def' elements' rest')
| def /= def' ->
error $ "Inconsistent set definition " <> show (def, def')
| otherwise ->
(Set.toList . Set.fromList $ elements <> elements', rest')
other -> (elements, other)
foldl' (<>) (getAttributes r) . map getAttributes $ elements
(elements', rest') = case rest of
Just (KSet def' es r)
| def /= def' -> error $ "Inconsistent set definition " <> show (def, def')
| otherwise -> (es, r)
other -> ([], other)
newElements = sortAndDeduplicate $ elements <> elements'
newRest = rest'
in Term
argAttributes
{ hash =
Hashable.hash
( "KSet" :: ByteString
, def
, map (hash . getAttributes) elements
, fmap (hash . getAttributes) rest
, map (hash . getAttributes) newElements
, fmap (hash . getAttributes) newRest
)
}
$ KSetF def newElements newRest
Expand Down
89 changes: 83 additions & 6 deletions test/llvm-integration/LLVM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ module Main (
displayTestDef,
) where

import Control.Monad (unless, when)
import Control.Monad (forM_, unless, when)
import Control.Monad.Trans.Except (runExcept)
import Data.ByteString.Char8 (ByteString)
import Data.ByteString.Char8 qualified as BS
import Data.Char (toLower)
import Data.Int (Int64)
import Data.List (isInfixOf)
import Data.List (foldl1', isInfixOf, nub)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
Expand Down Expand Up @@ -89,6 +89,10 @@ llvmSpec =
it "should leave literal byte arrays as they are" $
hedgehog . propertyTest . byteArrayProp

describe "LLVM INT simplification" $ do
it "should leave naked domain values as they are" $
hedgehog . propertyTest . intProp

describe "LLVM String handling" $
it "should work with latin-1strings" $
hedgehog . propertyTest . latin1Prop
Expand All @@ -97,6 +101,10 @@ llvmSpec =
it "should correct sort injections in non KItem maps" $
hedgehog . propertyTest . mapKItemInjProp

describe "internalised set tests" $
it "should leave concrete sets unchanged" $
hedgehog . propertyTest . setProp

--------------------------------------------------
-- individual hedgehog property tests and helpers

Expand Down Expand Up @@ -132,6 +140,12 @@ byteArrayProp api = property $ do
res' <- LLVM.simplifyTerm api testDef (bytesTerm ba') bytesSort
res' === Right (bytesTerm ba')

intProp :: LLVM.API -> Property
intProp api = property $ do
i <- forAll $ Gen.int (Range.linear 0 1024)
res <- LLVM.simplifyTerm api testDef (intTerm i) intSort
res === Right (intTerm i)

-- Round-trip test passing syntactic strings through the simplifier
-- and back. latin-1 characters should be left as they are (treated as
-- bytes internally). UTF-8 code points beyond latin-1 are forbidden.
Expand Down Expand Up @@ -195,6 +209,48 @@ mapKItemInjProp api = property $ do
[intTerm i]
]

setProp :: LLVM.API -> Property
setProp api = property $ do
forM_ [1 .. 10] $ \n -> do
xs <-
forAll $
Gen.filter (\xs -> xs == nub xs) $
Gen.list (Range.singleton n) $
Gen.int (Range.linear 0 1024)
let setTerm = makeKSetNoRest xs
res <- LLVM.simplifyTerm api testDef setTerm (SortApp "SortSet" [])
res === Right (setAsConcat . map wrapIntTerm $ xs)
where
makeKSetNoRest :: [Int] -> Term
makeKSetNoRest xs =
KSet
sortSetKSet
(map wrapIntTerm xs)
Nothing

singletonSet v =
SymbolApplication
(defSymbols Map.! sortSetKSet.symbolNames.elementSymbolName)
[]
[v]

setAsConcat =
foldl1'
( \x y ->
SymbolApplication
(defSymbols Map.! sortSetKSet.symbolNames.concatSymbolName)
[]
[x, y]
)
. map singletonSet

wrapIntTerm :: Int -> Term
wrapIntTerm i =
SymbolApplication
(defSymbols Map.! "inj")
[intSort, kItemSort]
[intTerm i]

------------------------------------------------------------

runKompile :: IO ()
Expand All @@ -217,11 +273,12 @@ loadAPI = Internal.withDLib dlPath Internal.mkAPI
------------------------------------------------------------
-- term construction

boolSort, intSort, bytesSort, stringSort :: Sort
boolSort, intSort, bytesSort, stringSort, kItemSort :: Sort
boolSort = SortApp "SortBool" []
intSort = SortApp "SortInt" []
bytesSort = SortApp "SortBytes" []
stringSort = SortApp "SortString" []
kItemSort = SortApp "SortKItem" []

boolTerm :: Bool -> Term
boolTerm = DomainValue boolSort . BS.pack . map toLower . show
Expand Down Expand Up @@ -315,6 +372,19 @@ sortMapKmap =
, mapSortName = "SortMap"
}

sortSetKSet :: KSetDefinition
sortSetKSet =
KListDefinition
{ symbolNames =
KCollectionSymbolNames
{ unitSymbolName = "Lbl'Stop'Set"
, elementSymbolName = "LblSetItem"
, concatSymbolName = "Lbl'Unds'Set'Unds'"
}
, elementSortName = "SortKItem"
, listSortName = "SortSet"
}

sortListKList :: KListDefinition
sortListKList =
KListDefinition
Expand Down Expand Up @@ -430,6 +500,13 @@ defSorts =
, Set.fromList ["SortList"]
)
)
,
( "SortSet"
,
( SortAttributes{collectionAttributes = Just (sortSetKSet.symbolNames, KSetTag), argCount = 0}
, Set.fromList ["SortSet"]
)
)
,
( "SortMap"
,
Expand Down Expand Up @@ -635,7 +712,7 @@ defSymbols =
, resultSort = SortApp "SortSet" []
, attributes =
SymbolAttributes
{ collectionMetadata = Nothing
{ collectionMetadata = Just $ KSetMeta sortSetKSet
, symbolType = TotalFunction
, isIdem = IsNotIdem
, isAssoc = IsNotAssoc
Expand Down Expand Up @@ -755,7 +832,7 @@ defSymbols =
, resultSort = SortApp "SortSet" []
, attributes =
SymbolAttributes
{ collectionMetadata = Nothing
{ collectionMetadata = Just $ KSetMeta sortSetKSet
, symbolType = PartialFunction
, isIdem = IsIdem
, isAssoc = IsAssoc
Expand Down Expand Up @@ -1938,7 +2015,7 @@ defSymbols =
, resultSort = SortApp "SortSet" []
, attributes =
SymbolAttributes
{ collectionMetadata = Nothing
{ collectionMetadata = Just $ KSetMeta sortSetKSet
, symbolType = TotalFunction
, isIdem = IsNotIdem
, isAssoc = IsNotAssoc
Expand Down
Loading

0 comments on commit 20650ab

Please sign in to comment.