Skip to content

Commit

Permalink
Introduce RewriteRuleAppliedData and RewriteGroupApplicationData
Browse files Browse the repository at this point in the history
  • Loading branch information
geo2a committed Oct 24, 2024
1 parent 8c9db24 commit 3c85b19
Showing 1 changed file with 104 additions and 63 deletions.
167 changes: 104 additions & 63 deletions booster/library/Booster/Pattern/Rewrite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -197,64 +197,73 @@ rewriteStep cutLabels terminalLabels pat = do
processGroups rest >>= \case
RewriteStuck{} -> pure $ RewriteTrivial pat
other -> pure other
AppliedRules ([], _remainder) ->
AppliedRules (RewriteGroupApplicationData{ruleApplicationData = []}) ->
-- TODO check that remainder is trivial, abort otherwise
processGroups rest
AppliedRules ([(rule, newPat, _subst, _rulePred)], remainder)
| not (Set.null remainder) && not (any isFalse remainder) -> do
-- a non-trivial remainder with a single applicable rule is
-- an indication if semantics incompleteness: abort
-- TODO refactor remainder check into a function and reuse below
solver <- getSolver
satRes <- SMT.isSat solver (Set.toList $ pat.constraints <> remainder) pat.substitution
throw $
RewriteRemainderPredicate [rule] satRes . coerce . foldl1 AndTerm $
map coerce . Set.toList $
remainder
-- a single rule applies, see if it's special and return an appropriate result
| labelOf rule `elem` cutLabels ->
pure $ RewriteCutPoint (labelOf rule) (uniqueId rule) pat newPat
| labelOf rule `elem` terminalLabels ->
pure $ RewriteTerminal (labelOf rule) (uniqueId rule) newPat
| otherwise ->
pure $ RewriteFinished (Just $ ruleLabelOrLocT rule) (Just $ uniqueId rule) newPat
AppliedRules (xs, remainder) -> do
-- multiple rules apply, analyse brunching and remainders
if any isFalse remainder
then do
logRemainder (map (\(r, _, _, _) -> r) xs) SMT.IsUnsat remainder
-- the remainder predicate is trivially false, return the branching result
pure $ mkBranch pat xs
else do
-- otherwise, we need to check the remainder predicate with the SMT solver
-- and construct an additional remainder branch if needed
AppliedRules
( RewriteGroupApplicationData
{ ruleApplicationData = [(rule, applied@RewriteRuleAppliedData{})]
, remainderPrediate = groupRemainderPrediate
}
)
| not (Set.null groupRemainderPrediate) && not (any isFalse groupRemainderPrediate) -> do
-- a non-trivial remainder with a single applicable rule is
-- an indication if semantics incompleteness: abort
-- TODO refactor remainder check into a function and reuse below
solver <- getSolver
SMT.isSat solver (Set.toList $ pat.constraints <> remainder) pat.substitution >>= \case
SMT.IsUnsat -> do
-- the remainder condition is unsatisfiable: no need to consider the remainder branch.
logRemainder (map (\(r, _, _, _) -> r) xs) SMT.IsUnsat remainder
pure $ mkBranch pat xs
satRes@(SMT.IsSat{}) -> do
-- the remainder condition is satisfiable.
-- TODO construct the remainder branch and consider it
-- To construct the "remainder pattern",
-- we add the remainder condition to the predicates of the @pattr@
throwRemainder (map (\(r, _p, _subst, _) -> r) xs) satRes remainder
satRes@SMT.IsUnknown{} -> do
-- solver cannot solve the remainder
-- TODO descend into the remainder branch anyway
throwRemainder (map (\(r, _p, _subst, _) -> r) xs) satRes remainder
satRes <- SMT.isSat solver (Set.toList $ pat.constraints <> groupRemainderPrediate) pat.substitution
throw $
RewriteRemainderPredicate [rule] satRes . coerce . foldl1 AndTerm $
map coerce . Set.toList $
groupRemainderPrediate
-- a single rule applies, see if it's special and return an appropriate result
| labelOf rule `elem` cutLabels ->
pure $ RewriteCutPoint (labelOf rule) (uniqueId rule) pat applied.rewritten
| labelOf rule `elem` terminalLabels ->
pure $ RewriteTerminal (labelOf rule) (uniqueId rule) applied.rewritten
| otherwise ->
pure $ RewriteFinished (Just $ ruleLabelOrLocT rule) (Just $ uniqueId rule) applied.rewritten
AppliedRules
(RewriteGroupApplicationData{ruleApplicationData = xs, remainderPrediate = groupRemainderPrediate}) -> do
-- multiple rules apply, analyse brunching and remainders
if any isFalse groupRemainderPrediate
then do
logRemainder (map fst xs) SMT.IsUnsat groupRemainderPrediate
-- the remainder predicate is trivially false, return the branching result
pure $ mkBranch pat xs
else do
-- otherwise, we need to check the remainder predicate with the SMT solver
-- and construct an additional remainder branch if needed
solver <- getSolver
SMT.isSat solver (Set.toList $ pat.constraints <> groupRemainderPrediate) pat.substitution >>= \case
SMT.IsUnsat -> do
-- the remainder condition is unsatisfiable: no need to consider the remainder branch.
logRemainder (map fst xs) SMT.IsUnsat groupRemainderPrediate
pure $ mkBranch pat xs
satRes@(SMT.IsSat{}) -> do
-- the remainder condition is satisfiable.
-- TODO construct the remainder branch and consider it
-- To construct the "remainder pattern",
-- we add the remainder condition to the predicates of the @pattr@
throwRemainder (map fst xs) satRes groupRemainderPrediate
satRes@SMT.IsUnknown{} -> do
-- solver cannot solve the remainder
-- TODO descend into the remainder branch anyway
throwRemainder (map fst xs) satRes groupRemainderPrediate

mkBranch ::
Pattern ->
[(RewriteRule "Rewrite", Pattern, Substitution, Maybe Predicate)] ->
[(RewriteRule "Rewrite", RewriteRuleAppliedData)] ->
RewriteResult Pattern
mkBranch base leafs =
let ruleLabelOrLocT = renderOneLineText . ruleLabelOrLoc
uniqueId = (.uniqueId) . (.attributes)
in RewriteBranch base $
NE.fromList $
map (\(r, p, subst, rulePred) -> (ruleLabelOrLocT r, uniqueId r, p, rulePred, subst)) leafs
map
( \(rule, RewriteRuleAppliedData{rewritten, rulePredicate, ruleSubstitution}) -> (ruleLabelOrLocT rule, uniqueId rule, rewritten, rulePredicate, ruleSubstitution)
)
leafs

-- abort rewriting by throwing a remainder predicate as an exception, to be caught and processed in @performRewrite@
throwRemainder ::
Expand Down Expand Up @@ -299,12 +308,25 @@ runRewriteRuleAppT action =
Left RewriteRuleTrivial -> pure Trivial
Right result -> pure (Applied result)

{- | Rewrite rule application result.
Note that we only really every need the payload to be @'RewriteRuleAppliedData'@,
but we make this type parametric so that it can be the result of @'runRewriteRuleAppT'@
-}
data RewriteRuleAppResult a
= Applied a
| NotApplied
| Trivial
deriving (Show, Eq, Functor)

data RewriteRuleAppliedData = RewriteRuleAppliedData
{ rewritten :: Pattern
, remainderPredicate :: Maybe Predicate
, ruleSubstitution :: Substitution
, rulePredicate :: Maybe Predicate
}
deriving (Show, Eq)

returnTrivial, returnNotApplied :: Monad m => RewriteRuleAppT m a
returnTrivial = throwE RewriteRuleTrivial
returnNotApplied = throwE RewriteRuleNotApplied
Expand All @@ -325,7 +347,7 @@ applyRule ::
LoggerMIO io =>
Pattern ->
RewriteRule "Rewrite" ->
RewriteT io (RewriteRuleAppResult (Pattern, Predicate, Substitution, Maybe Predicate))
RewriteT io (RewriteRuleAppResult RewriteRuleAppliedData)
applyRule pat@Pattern{ceilConditions} rule =
withRuleContext rule $
runRewriteRuleAppT $
Expand Down Expand Up @@ -439,7 +461,13 @@ applyRule pat@Pattern{ceilConditions} rule =
case unclearRequiresAfterSmt of
[] ->
withPatternContext rewritten $
pure (rewritten, Predicate FalseBool, modifiedPatternSubst `compose` ruleSubstitution, Nothing)
pure
RewriteRuleAppliedData
{ rewritten
, remainderPredicate = Nothing
, ruleSubstitution = modifiedPatternSubst `compose` ruleSubstitution
, rulePredicate = Nothing
}
_ -> do
rulePredicate <- mkSimplifiedRulePredicate (modifiedPatternSubst `compose` ruleSubstitution)
-- the requires clause was unclear:
Expand All @@ -449,11 +477,12 @@ applyRule pat@Pattern{ceilConditions} rule =
let rewritten' = rewritten{constraints = rewritten.constraints <> Set.fromList unclearRequiresAfterSmt}
in withPatternContext rewritten' $
pure
( rewritten'
, Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt
, modifiedPatternSubst `compose` ruleSubstitution
, Just rulePredicate
)
RewriteRuleAppliedData
{ rewritten = rewritten'
, remainderPredicate = Just . Predicate . NotBool . coerce $ collapseAndBools unclearRequiresAfterSmt
, ruleSubstitution = modifiedPatternSubst `compose` ruleSubstitution
, rulePredicate = Just rulePredicate
}
where
filterOutKnownConstraints :: Set.Set Predicate -> [Predicate] -> RewriteT io [Predicate]
filterOutKnownConstraints priorKnowledge constraitns = do
Expand Down Expand Up @@ -609,34 +638,46 @@ applyRule pat@Pattern{ceilConditions} rule =

data RuleGroupApplication a = OnlyTrivial | AppliedRules a

data RewriteGroupApplicationData = RewriteGroupApplicationData
{ ruleApplicationData :: [(RewriteRule "Rewrite", RewriteRuleAppliedData)]
, remainderPrediate :: Set.Set Predicate
}

ruleGroupPriority :: [RewriteRule a] -> Maybe Priority
ruleGroupPriority = \case
[] -> Nothing
(rule : _) -> Just rule.attributes.priority

{- | Given a list of rule application attempts, i.e. a result of applying a priority group of rules in parallel,
post process them:
post-process them:
- filter-out trivial and failed applications
- extract (possibly trivial) remainder predicates of every rule
and return them as a set relating to the whole group
and return them as a set relating to the whole group.
-}
postProcessRuleAttempts ::
[(RewriteRule "Rewrite", RewriteRuleAppResult (Pattern, Predicate, Substitution, Maybe Predicate))] ->
RuleGroupApplication
([(RewriteRule "Rewrite", Pattern, Substitution, Maybe Predicate)], Set.Set Predicate)
[(RewriteRule "Rewrite", RewriteRuleAppResult RewriteRuleAppliedData)] ->
RuleGroupApplication RewriteGroupApplicationData
postProcessRuleAttempts = \case
[] -> AppliedRules ([], mempty)
[] -> AppliedRules (RewriteGroupApplicationData{ruleApplicationData = [], remainderPrediate = mempty})
apps -> case filter ((/= NotApplied) . snd) apps of
[] -> AppliedRules ([], mempty)
[] -> AppliedRules (RewriteGroupApplicationData{ruleApplicationData = [], remainderPrediate = mempty})
xs
| all ((== Trivial) . snd) xs -> OnlyTrivial
| otherwise -> go ([], mempty) xs
| otherwise ->
let (ruleApplicationData, remainderPrediate) = go ([], mempty) xs
in AppliedRules (RewriteGroupApplicationData{ruleApplicationData, remainderPrediate})
where
go acc@(accPatterns, accRemainders) = \case
[] -> AppliedRules (reverse accPatterns, accRemainders)
[] -> (reverse accPatterns, accRemainders)
((rule, appRes) : xs) ->
case appRes of
Applied (pat, remainder, subst, rulePred) -> go ((rule, pat, subst, rulePred) : accPatterns, Set.singleton remainder <> accRemainders) xs
Applied
appliedData ->
go
( (rule, appliedData{remainderPredicate = Nothing}) : accPatterns
, (Set.singleton $ fromMaybe (Predicate FalseBool) appliedData.remainderPredicate) <> accRemainders
)
xs
NotApplied -> go acc xs
Trivial -> go acc xs

Expand Down

0 comments on commit 3c85b19

Please sign in to comment.