From 3c85b1905b2be7ec1bdb73bdc85675e682e3424c Mon Sep 17 00:00:00 2001 From: Georgy Lukyanov Date: Thu, 24 Oct 2024 15:01:37 +0200 Subject: [PATCH] Introduce RewriteRuleAppliedData and RewriteGroupApplicationData --- booster/library/Booster/Pattern/Rewrite.hs | 167 +++++++++++++-------- 1 file changed, 104 insertions(+), 63 deletions(-) diff --git a/booster/library/Booster/Pattern/Rewrite.hs b/booster/library/Booster/Pattern/Rewrite.hs index 30f3df2632..ead6f9e92e 100644 --- a/booster/library/Booster/Pattern/Rewrite.hs +++ b/booster/library/Booster/Pattern/Rewrite.hs @@ -197,64 +197,73 @@ rewriteStep cutLabels terminalLabels pat = do processGroups rest >>= \case RewriteStuck{} -> pure $ RewriteTrivial pat other -> pure other - AppliedRules ([], _remainder) -> + AppliedRules (RewriteGroupApplicationData{ruleApplicationData = []}) -> -- TODO check that remainder is trivial, abort otherwise processGroups rest - AppliedRules ([(rule, newPat, _subst, _rulePred)], remainder) - | not (Set.null remainder) && not (any isFalse remainder) -> do - -- a non-trivial remainder with a single applicable rule is - -- an indication if semantics incompleteness: abort - -- TODO refactor remainder check into a function and reuse below - solver <- getSolver - satRes <- SMT.isSat solver (Set.toList $ pat.constraints <> remainder) pat.substitution - throw $ - RewriteRemainderPredicate [rule] satRes . coerce . foldl1 AndTerm $ - map coerce . Set.toList $ - remainder - -- a single rule applies, see if it's special and return an appropriate result - | labelOf rule `elem` cutLabels -> - pure $ RewriteCutPoint (labelOf rule) (uniqueId rule) pat newPat - | labelOf rule `elem` terminalLabels -> - pure $ RewriteTerminal (labelOf rule) (uniqueId rule) newPat - | otherwise -> - pure $ RewriteFinished (Just $ ruleLabelOrLocT rule) (Just $ uniqueId rule) newPat - AppliedRules (xs, remainder) -> do - -- multiple rules apply, analyse brunching and remainders - if any isFalse remainder - then do - logRemainder (map (\(r, _, _, _) -> r) xs) SMT.IsUnsat remainder - -- the remainder predicate is trivially false, return the branching result - pure $ mkBranch pat xs - else do - -- otherwise, we need to check the remainder predicate with the SMT solver - -- and construct an additional remainder branch if needed + AppliedRules + ( RewriteGroupApplicationData + { ruleApplicationData = [(rule, applied@RewriteRuleAppliedData{})] + , remainderPrediate = groupRemainderPrediate + } + ) + | not (Set.null groupRemainderPrediate) && not (any isFalse groupRemainderPrediate) -> do + -- a non-trivial remainder with a single applicable rule is + -- an indication if semantics incompleteness: abort + -- TODO refactor remainder check into a function and reuse below solver <- getSolver - SMT.isSat solver (Set.toList $ pat.constraints <> remainder) pat.substitution >>= \case - SMT.IsUnsat -> do - -- the remainder condition is unsatisfiable: no need to consider the remainder branch. - logRemainder (map (\(r, _, _, _) -> r) xs) SMT.IsUnsat remainder - pure $ mkBranch pat xs - satRes@(SMT.IsSat{}) -> do - -- the remainder condition is satisfiable. - -- TODO construct the remainder branch and consider it - -- To construct the "remainder pattern", - -- we add the remainder condition to the predicates of the @pattr@ - throwRemainder (map (\(r, _p, _subst, _) -> r) xs) satRes remainder - satRes@SMT.IsUnknown{} -> do - -- solver cannot solve the remainder - -- TODO descend into the remainder branch anyway - throwRemainder (map (\(r, _p, _subst, _) -> r) xs) satRes remainder + satRes <- SMT.isSat solver (Set.toList $ pat.constraints <> groupRemainderPrediate) pat.substitution + throw $ + RewriteRemainderPredicate [rule] satRes . coerce . foldl1 AndTerm $ + map coerce . Set.toList $ + groupRemainderPrediate + -- a single rule applies, see if it's special and return an appropriate result + | labelOf rule `elem` cutLabels -> + pure $ RewriteCutPoint (labelOf rule) (uniqueId rule) pat applied.rewritten + | labelOf rule `elem` terminalLabels -> + pure $ RewriteTerminal (labelOf rule) (uniqueId rule) applied.rewritten + | otherwise -> + pure $ RewriteFinished (Just $ ruleLabelOrLocT rule) (Just $ uniqueId rule) applied.rewritten + AppliedRules + (RewriteGroupApplicationData{ruleApplicationData = xs, remainderPrediate = groupRemainderPrediate}) -> do + -- multiple rules apply, analyse brunching and remainders + if any isFalse groupRemainderPrediate + then do + logRemainder (map fst xs) SMT.IsUnsat groupRemainderPrediate + -- the remainder predicate is trivially false, return the branching result + pure $ mkBranch pat xs + else do + -- otherwise, we need to check the remainder predicate with the SMT solver + -- and construct an additional remainder branch if needed + solver <- getSolver + SMT.isSat solver (Set.toList $ pat.constraints <> groupRemainderPrediate) pat.substitution >>= \case + SMT.IsUnsat -> do + -- the remainder condition is unsatisfiable: no need to consider the remainder branch. + logRemainder (map fst xs) SMT.IsUnsat groupRemainderPrediate + pure $ mkBranch pat xs + satRes@(SMT.IsSat{}) -> do + -- the remainder condition is satisfiable. + -- TODO construct the remainder branch and consider it + -- To construct the "remainder pattern", + -- we add the remainder condition to the predicates of the @pattr@ + throwRemainder (map fst xs) satRes groupRemainderPrediate + satRes@SMT.IsUnknown{} -> do + -- solver cannot solve the remainder + -- TODO descend into the remainder branch anyway + throwRemainder (map fst xs) satRes groupRemainderPrediate mkBranch :: Pattern -> - [(RewriteRule "Rewrite", Pattern, Substitution, Maybe Predicate)] -> + [(RewriteRule "Rewrite", RewriteRuleAppliedData)] -> RewriteResult Pattern mkBranch base leafs = let ruleLabelOrLocT = renderOneLineText . ruleLabelOrLoc uniqueId = (.uniqueId) . (.attributes) in RewriteBranch base $ NE.fromList $ - map (\(r, p, subst, rulePred) -> (ruleLabelOrLocT r, uniqueId r, p, rulePred, subst)) leafs + map + ( \(rule, RewriteRuleAppliedData{rewritten, rulePredicate, ruleSubstitution}) -> (ruleLabelOrLocT rule, uniqueId rule, rewritten, rulePredicate, ruleSubstitution) + ) + leafs -- abort rewriting by throwing a remainder predicate as an exception, to be caught and processed in @performRewrite@ throwRemainder :: @@ -299,12 +308,25 @@ runRewriteRuleAppT action = Left RewriteRuleTrivial -> pure Trivial Right result -> pure (Applied result) +{- | Rewrite rule application result. + + Note that we only really every need the payload to be @'RewriteRuleAppliedData'@, + but we make this type parametric so that it can be the result of @'runRewriteRuleAppT'@ +-} data RewriteRuleAppResult a = Applied a | NotApplied | Trivial deriving (Show, Eq, Functor) +data RewriteRuleAppliedData = RewriteRuleAppliedData + { rewritten :: Pattern + , remainderPredicate :: Maybe Predicate + , ruleSubstitution :: Substitution + , rulePredicate :: Maybe Predicate + } + deriving (Show, Eq) + returnTrivial, returnNotApplied :: Monad m => RewriteRuleAppT m a returnTrivial = throwE RewriteRuleTrivial returnNotApplied = throwE RewriteRuleNotApplied @@ -325,7 +347,7 @@ applyRule :: LoggerMIO io => Pattern -> RewriteRule "Rewrite" -> - RewriteT io (RewriteRuleAppResult (Pattern, Predicate, Substitution, Maybe Predicate)) + RewriteT io (RewriteRuleAppResult RewriteRuleAppliedData) applyRule pat@Pattern{ceilConditions} rule = withRuleContext rule $ runRewriteRuleAppT $ @@ -439,7 +461,13 @@ applyRule pat@Pattern{ceilConditions} rule = case unclearRequiresAfterSmt of [] -> withPatternContext rewritten $ - pure (rewritten, Predicate FalseBool, modifiedPatternSubst `compose` ruleSubstitution, Nothing) + pure + RewriteRuleAppliedData + { rewritten + , remainderPredicate = Nothing + , ruleSubstitution = modifiedPatternSubst `compose` ruleSubstitution + , rulePredicate = Nothing + } _ -> do rulePredicate <- mkSimplifiedRulePredicate (modifiedPatternSubst `compose` ruleSubstitution) -- the requires clause was unclear: @@ -449,11 +477,12 @@ applyRule pat@Pattern{ceilConditions} rule = let rewritten' = rewritten{constraints = rewritten.constraints <> Set.fromList unclearRequiresAfterSmt} in withPatternContext rewritten' $ pure - ( rewritten' - , Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt - , modifiedPatternSubst `compose` ruleSubstitution - , Just rulePredicate - ) + RewriteRuleAppliedData + { rewritten = rewritten' + , remainderPredicate = Just . Predicate . NotBool . coerce $ collapseAndBools unclearRequiresAfterSmt + , ruleSubstitution = modifiedPatternSubst `compose` ruleSubstitution + , rulePredicate = Just rulePredicate + } where filterOutKnownConstraints :: Set.Set Predicate -> [Predicate] -> RewriteT io [Predicate] filterOutKnownConstraints priorKnowledge constraitns = do @@ -609,34 +638,46 @@ applyRule pat@Pattern{ceilConditions} rule = data RuleGroupApplication a = OnlyTrivial | AppliedRules a +data RewriteGroupApplicationData = RewriteGroupApplicationData + { ruleApplicationData :: [(RewriteRule "Rewrite", RewriteRuleAppliedData)] + , remainderPrediate :: Set.Set Predicate + } + ruleGroupPriority :: [RewriteRule a] -> Maybe Priority ruleGroupPriority = \case [] -> Nothing (rule : _) -> Just rule.attributes.priority {- | Given a list of rule application attempts, i.e. a result of applying a priority group of rules in parallel, - post process them: + post-process them: - filter-out trivial and failed applications - extract (possibly trivial) remainder predicates of every rule - and return them as a set relating to the whole group + and return them as a set relating to the whole group. -} postProcessRuleAttempts :: - [(RewriteRule "Rewrite", RewriteRuleAppResult (Pattern, Predicate, Substitution, Maybe Predicate))] -> - RuleGroupApplication - ([(RewriteRule "Rewrite", Pattern, Substitution, Maybe Predicate)], Set.Set Predicate) + [(RewriteRule "Rewrite", RewriteRuleAppResult RewriteRuleAppliedData)] -> + RuleGroupApplication RewriteGroupApplicationData postProcessRuleAttempts = \case - [] -> AppliedRules ([], mempty) + [] -> AppliedRules (RewriteGroupApplicationData{ruleApplicationData = [], remainderPrediate = mempty}) apps -> case filter ((/= NotApplied) . snd) apps of - [] -> AppliedRules ([], mempty) + [] -> AppliedRules (RewriteGroupApplicationData{ruleApplicationData = [], remainderPrediate = mempty}) xs | all ((== Trivial) . snd) xs -> OnlyTrivial - | otherwise -> go ([], mempty) xs + | otherwise -> + let (ruleApplicationData, remainderPrediate) = go ([], mempty) xs + in AppliedRules (RewriteGroupApplicationData{ruleApplicationData, remainderPrediate}) where go acc@(accPatterns, accRemainders) = \case - [] -> AppliedRules (reverse accPatterns, accRemainders) + [] -> (reverse accPatterns, accRemainders) ((rule, appRes) : xs) -> case appRes of - Applied (pat, remainder, subst, rulePred) -> go ((rule, pat, subst, rulePred) : accPatterns, Set.singleton remainder <> accRemainders) xs + Applied + appliedData -> + go + ( (rule, appliedData{remainderPredicate = Nothing}) : accPatterns + , (Set.singleton $ fromMaybe (Predicate FalseBool) appliedData.remainderPredicate) <> accRemainders + ) + xs NotApplied -> go acc xs Trivial -> go acc xs