Skip to content

Commit

Permalink
Add support for GHC 9.11.20240522
Browse files Browse the repository at this point in the history
  • Loading branch information
christiaanb committed May 27, 2024
1 parent 4dadc82 commit 0cebb9c
Show file tree
Hide file tree
Showing 3 changed files with 357 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog for the [`ghc-typelits-extra`](http://hackage.haskell.org/package/ghc-typelits-extra) package

# 0.4.8
* Add support for GHC 9.11.20240522

# 0.4.7 *May 22nd, 2024*
* Add support for GHC 9.10.1
* Fix Plugin silently fails when normalizing <= in GHC 9.4+ [#50](https://github.com/clash-lang/ghc-typelits-extra/issues/50)
Expand Down
7 changes: 5 additions & 2 deletions ghc-typelits-extra.cabal
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: ghc-typelits-extra
version: 0.4.7
version: 0.4.8
synopsis: Additional type-level operations on GHC.TypeLits.Nat
description:
Additional type-level operations on @GHC.TypeLits.Nat@:
Expand Down Expand Up @@ -82,9 +82,12 @@ library
hs-source-dirs: src
if impl(ghc >= 8.0) && impl(ghc < 9.4)
hs-source-dirs: src-pre-ghc-9.4
if impl(ghc >= 9.4) && impl(ghc < 9.12)
if impl(ghc >= 9.4) && impl(ghc < 9.11)
hs-source-dirs: src-ghc-9.4
build-depends: template-haskell >= 2.17 && <2.23
if impl(ghc >= 9.11) && impl(ghc < 9.13)
hs-source-dirs: src-ghc-9.12
build-depends: template-haskell >= 2.17 && <2.23
default-language: Haskell2010
other-extensions: DataKinds
FlexibleInstances
Expand Down
349 changes: 349 additions & 0 deletions src-ghc-9.12/GHC/TypeLits/Extra/Solver.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,349 @@
{-|
Copyright : (C) 2015-2016, University of Twente
License : BSD2 (see the file LICENSE)
Maintainer : Christiaan Baaij <[email protected]>
To use the plugin, add the
@
{\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver \#-\}
@
pragma to the header of your file
-}

{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module GHC.TypeLits.Extra.Solver
( plugin )
where

-- external
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Maybe (catMaybes)
import GHC.TcPluginM.Extra (evByFiatWithDependencies, tracePlugin, newWanted)
import qualified Data.Type.Ord
import qualified GHC.TypeError

-- GHC API
import GHC.Builtin.Names (eqPrimTyConKey, hasKey, getUnique)
import GHC.Builtin.Types (promotedTrueDataCon, promotedFalseDataCon)
import GHC.Builtin.Types (boolTy, naturalTy, cTupleDataCon, cTupleTyCon)
import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon)
import GHC.Core.Coercion (mkUnivCo)
import GHC.Core.DataCon (dataConWrapId)
import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred, IrredPred), classifyPredType)
import GHC.Core.Reduction (Reduction(..))
import GHC.Core.TyCon (TyCon)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..), UnivCoProvenance (PluginProv))
import GHC.Core.Type (Kind, mkTyConApp, splitTyConApp_maybe, typeKind)
import GHC.Core.TyCo.Compare (eqType)
import GHC.Data.IOEnv (getEnv)
import GHC.Driver.Env (hsc_NC)
import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin)
import GHC.Plugins (DCoVarSet, Var, emptyDVarSet, extendDVarSet, thNameToGhcNameIO)
import GHC.Tc.Plugin
(TcPluginM, tcLookupTyCon, tcPluginTrace, tcPluginIO, unsafeTcPluginTcM)
import GHC.Tc.Types
(TcPlugin(..), TcPluginSolveResult (..), TcPluginRewriter, TcPluginRewriteResult (..),
Env (env_top))
import GHC.Tc.Types.Constraint
(Ct, ctEvidence, ctEvId, ctEvPred, ctLoc, isWantedCt)
import GHC.Tc.Types.Constraint (Ct (..), DictCt(..), EqCt(..), IrredCt(..), qci_ev)
import GHC.Tc.Types.Evidence (EvTerm, EvBindsVar, Role(..), evCast, evId)
import GHC.Types.Unique.FM (UniqFM, listToUFM)
import GHC.Utils.Outputable (Outputable (..), (<+>), ($$), text)
import GHC (Name)

-- template-haskell
import qualified Language.Haskell.TH as TH

-- internal
import GHC.TypeLits.Extra.Solver.Operations
import GHC.TypeLits.Extra.Solver.Unify
import GHC.TypeLits.Extra

-- | A solver implement as a type-checker plugin for:
--
-- * 'Div': type-level 'div'
--
-- * 'Mod': type-level 'mod'
--
-- * 'FLog': type-level equivalent of <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
-- .i.e. the exact integer equivalent to "@'floor' ('logBase' x y)@"
--
-- * 'CLog': type-level equivalent of /the ceiling of/ <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
-- .i.e. the exact integer equivalent to "@'ceiling' ('logBase' x y)@"
--
-- * 'Log': type-level equivalent of <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
-- where the operation only reduces when "@'floor' ('logBase' b x) ~ 'ceiling' ('logBase' b x)@"
--
-- * 'GCD': a type-level 'gcd'
--
-- * 'LCM': a type-level 'lcm'
--
-- To use the plugin, add
--
-- @
-- {\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver \#-\}
-- @
--
-- To the header of your file.
plugin :: Plugin
plugin
= defaultPlugin
{ tcPlugin = const $ Just normalisePlugin
, pluginRecompile = purePlugin
}

normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-extra"
TcPlugin { tcPluginInit = lookupExtraDefs
, tcPluginSolve = decideEqualSOP
, tcPluginRewrite = extraRewrite
, tcPluginStop = const (return ())
}

extraRewrite :: ExtraDefs -> UniqFM TyCon TcPluginRewriter
extraRewrite defs = listToUFM
[ (gcdTyCon defs, gcdRewrite)
, (lcmTyCon defs, lcmRewrite)
]
where
gcdRewrite _ _ args@[LitTy (NumTyLit i), LitTy (NumTyLit j)] = pure $
TcPluginRewriteTo (reduce (gcdTyCon defs) args (LitTy (NumTyLit (i `gcd` j)))) []
gcdRewrite _ _ _ = pure TcPluginNoRewrite

lcmRewrite _ _ args@[LitTy (NumTyLit i), LitTy (NumTyLit j)] = pure $
TcPluginRewriteTo (reduce (lcmTyCon defs) args (LitTy (NumTyLit (i `lcm` j)))) []
lcmRewrite _ _ _ = pure TcPluginNoRewrite

reduce tc args res = Reduction co res
where
co = mkUnivCo (PluginProv "ghc-typelits-extra" emptyDVarSet) Nominal
(mkTyConApp tc args) res


decideEqualSOP :: ExtraDefs -> EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult
decideEqualSOP _ _ _givens [] = return (TcPluginOk [] [])
decideEqualSOP defs _ givens wanteds = do
unit_wanteds <- catMaybes <$> mapM (runMaybeT . toSolverConstraint defs) wanteds
case unit_wanteds of
[] -> return (TcPluginOk [] [])
_ -> do
unit_givens <- catMaybes <$> mapM (runMaybeT . toSolverConstraint defs) givens
sr <- simplifyExtra defs (unit_givens ++ unit_wanteds)
tcPluginTrace "normalised" (ppr sr)
case sr of
Simplified evs new -> return (TcPluginOk (filter (isWantedCt . snd) evs) new)
Impossible eq -> return (TcPluginContradiction [fromSolverConstraint eq])

data SolverConstraint
= NatEquality Ct ExtraOp ExtraOp Normalised
| NatInequality Ct DCoVarSet ExtraOp ExtraOp Bool Normalised

instance Outputable SolverConstraint where
ppr (NatEquality ct op1 op2 norm) =
text "NatEquality" $$ ppr ct $$ ppr op1 $$ ppr op2 $$ ppr norm
ppr (NatInequality _ _ op1 op2 b norm) =
text "NatInequality" $$ ppr op1 $$ ppr op2 $$ ppr b $$ ppr norm

data SimplifyResult
= Simplified [(EvTerm,Ct)] [Ct]
| Impossible SolverConstraint

instance Outputable SimplifyResult where
ppr (Simplified evs new) =
text "Simplified" $$ text "Solved:" $$ ppr evs $$ text "New:" $$ ppr new
ppr (Impossible sct) =
text "Impossible" <+> ppr sct

simplifyExtra :: ExtraDefs -> [SolverConstraint] -> TcPluginM SimplifyResult
simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [] eqs
where
simples :: [Maybe (EvTerm, Ct)] -> [Ct] -> [SolverConstraint] -> TcPluginM SimplifyResult
simples evs news [] = return (Simplified (catMaybes evs) news)
simples evs news (eq@(NatEquality ct u v norm):eqs') = do
ur <- unifyExtra ct u v
tcPluginTrace "unifyExtra result" (ppr ur)
case ur of
Win -> simples (((,) <$> evMagic ct emptyDVarSet <*> pure ct):evs) news eqs'
Lose | null evs && null eqs' -> return (Impossible eq)
_ | norm == Normalised && isWantedCt ct -> do
newCt <- createWantedFromNormalised defs eq
simples (((,) <$> evMagic ct emptyDVarSet <*> pure ct):evs) (newCt:news) eqs'
Lose -> simples evs news eqs'
Draw -> simples evs news eqs'
simples evs news (eq@(NatInequality ct deps u v b norm):eqs') = do
tcPluginTrace "unifyExtra leq result" (ppr (u,v,b))
case (u,v) of
(I i,I j)
| (i <= j) == b -> simples (((,) <$> evMagic ct deps <*> pure ct):evs) news eqs'
| otherwise -> return (Impossible eq)
(p, Max x y)
| b && (p == x || p == y)
-> simples (((,) <$> evMagic ct deps <*> pure ct):evs) news eqs'

-- transform: q ~ Max x y => (p <=? q ~ True)
-- to: (p <=? Max x y) ~ True
-- and try to solve that along with the rest of the eqs'
(p, q@(V _))
| b -> case findMax q eqs of
Just (i,m) ->
simples evs news
(NatInequality ct (extendDVarSet deps i) p m b norm:eqs')
Nothing -> simples evs news eqs'
_ | norm == Normalised && isWantedCt ct -> do
newCt <- createWantedFromNormalised defs eq
simples (((,) <$> evMagic ct deps <*> pure ct):evs) (newCt:news) eqs'
_ -> simples evs news eqs'

-- look for given constraint with the form: c ~ Max x y
findMax :: ExtraOp -> [SolverConstraint] -> Maybe (Var, ExtraOp)
findMax c = go
where
go [] = Nothing
go ((NatEquality ct a b@(Max _ _) _) :_)
| c == a && not (isWantedCt ct)
= Just (ctEvId ct, b)
go ((NatEquality ct a@(Max _ _) b _) :_)
| c == b && not (isWantedCt ct)
= Just (ctEvId ct, a)
go (_:rest) = go rest


-- Extract the Nat equality constraints
toSolverConstraint :: ExtraDefs -> Ct -> MaybeT TcPluginM SolverConstraint
toSolverConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
EqPred NomEq t1 t2
| isNatKind (typeKind t1) || isNatKind (typeKind t2)
-> do
(t1', n1) <- normaliseNat defs t1
(t2', n2) <- normaliseNat defs t2
pure (NatEquality ct t1' t2' (mergeNormalised n1 n2))
| TyConApp tc [_,cmpNat,TyConApp tt1 [],TyConApp tt2 [],TyConApp ff1 []] <- t1
, tc == ordTyCon defs
, TyConApp cmpNatTc [x,y] <- cmpNat
, cmpNatTc == typeNatCmpTyCon
, tt1 == promotedTrueDataCon
, tt2 == promotedTrueDataCon
, ff1 == promotedFalseDataCon
, TyConApp tc' [] <- t2
-> do
(x', n1) <- normaliseNat defs x
(y', n2) <- normaliseNat defs y
let res | tc' == promotedTrueDataCon
= pure (NatInequality ct emptyDVarSet x' y' True
(mergeNormalised n1 n2))
| tc' == promotedFalseDataCon
= pure (NatInequality ct emptyDVarSet x' y' False
(mergeNormalised n1 n2))
| otherwise = fail "Nothing"
res
| TyConApp tc [TyConApp ordCondTc zs, _] <- t1
, tc == assertTC defs
, TyConApp tc' [] <- t2
, tc' == cTupleTyCon 0
, ordCondTc == ordTyCon defs
, [_,cmp,lt,eq,gt] <- zs
, TyConApp tcCmpNat [x,y] <- cmp
, tcCmpNat == typeNatCmpTyCon
, TyConApp ltTc [] <- lt
, ltTc == promotedTrueDataCon
, TyConApp eqTc [] <- eq
, eqTc == promotedTrueDataCon
, TyConApp gtTc [] <- gt
, gtTc == promotedFalseDataCon
-> do
(x', n1) <- normaliseNat defs x
(y', n2) <- normaliseNat defs y
pure (NatInequality ct emptyDVarSet x' y' True (mergeNormalised n1 n2))
IrredPred (TyConApp tc [TyConApp ordCondTc zs, _])
| tc == assertTC defs
, ordCondTc == ordTyCon defs
, [_,cmp,lt,eq,gt] <- zs
, TyConApp tcCmpNat [x,y] <- cmp
, tcCmpNat == typeNatCmpTyCon
, TyConApp ltTc [] <- lt
, ltTc == promotedTrueDataCon
, TyConApp eqTc [] <- eq
, eqTc == promotedTrueDataCon
, TyConApp gtTc [] <- gt
, gtTc == promotedFalseDataCon
-> do
(x', n1) <- normaliseNat defs x
(y', n2) <- normaliseNat defs y
pure (NatInequality ct emptyDVarSet x' y' True (mergeNormalised n1 n2))
_ -> fail "Nothing"
where
isNatKind :: Kind -> Bool
isNatKind = (`eqType` naturalTy)

createWantedFromNormalised :: ExtraDefs -> SolverConstraint -> TcPluginM Ct
createWantedFromNormalised defs sct = do
let extractCtSides (NatEquality ct t1 t2 _) = (ct, reifyEOP defs t1, reifyEOP defs t2)
extractCtSides (NatInequality ct _ x y b _) =
let tc = if b then promotedTrueDataCon else promotedFalseDataCon
t1 = TyConApp (ordTyCon defs)
[ boolTy
, TyConApp typeNatCmpTyCon [reifyEOP defs x, reifyEOP defs y]
, TyConApp promotedTrueDataCon []
, TyConApp promotedTrueDataCon []
, TyConApp promotedFalseDataCon []
]
t2 = TyConApp tc []
in (ct, t1, t2)
let (ct, t1, t2) = extractCtSides sct
newPredTy <- case splitTyConApp_maybe $ ctEvPred $ ctEvidence ct of
Just (tc, [a, b, _, _]) | tc `hasKey` eqPrimTyConKey -> pure (mkTyConApp tc [a, b, t1, t2])
Just (tc, [_, b]) | tc `hasKey` getUnique (assertTC defs) -> pure (mkTyConApp tc [t1,b])
_ -> error "Impossible: neither (<=?) nor Assert"
ev <- newWanted (ctLoc ct) newPredTy
let ctN = case ct of
CQuantCan qc -> CQuantCan (qc { qci_ev = ev})
CDictCan di -> CDictCan (di { di_ev = ev})
CIrredCan ir -> CIrredCan (ir { ir_ev = ev})
CEqCan eq -> CEqCan (eq { eq_ev = ev})
CNonCanonical _ -> CNonCanonical ev
return ctN

fromSolverConstraint :: SolverConstraint -> Ct
fromSolverConstraint (NatEquality ct _ _ _) = ct
fromSolverConstraint (NatInequality ct _ _ _ _ _) = ct

lookupExtraDefs :: TcPluginM ExtraDefs
lookupExtraDefs = do
ExtraDefs <$> look ''GHC.TypeLits.Extra.Max
<*> look ''GHC.TypeLits.Extra.Min
<*> pure typeNatDivTyCon
<*> pure typeNatModTyCon
<*> look ''GHC.TypeLits.Extra.FLog
<*> look ''GHC.TypeLits.Extra.CLog
<*> look ''GHC.TypeLits.Extra.Log
<*> look ''GHC.TypeLits.Extra.GCD
<*> look ''GHC.TypeLits.Extra.LCM
<*> look ''Data.Type.Ord.OrdCond
<*> look ''GHC.TypeError.Assert
where
look nm = tcLookupTyCon =<< lookupTHName nm

lookupTHName :: TH.Name -> TcPluginM Name
lookupTHName th = do
nc <- unsafeTcPluginTcM (hsc_NC . env_top <$> getEnv)
res <- tcPluginIO $ thNameToGhcNameIO nc th
maybe (fail $ "Failed to lookup " ++ show th) return res

-- Utils
evMagic :: Ct -> DCoVarSet -> Maybe EvTerm
evMagic ct deps = case classifyPredType $ ctEvPred $ ctEvidence ct of
EqPred NomEq t1 t2 -> Just (evByFiatWithDependencies "ghc-typelits-extra" deps t1 t2)
IrredPred p ->
let t1 = mkTyConApp (cTupleTyCon 0) []
co = mkUnivCo (PluginProv "ghc-typelits-extra" deps) Representational t1 p
dcApp = evId (dataConWrapId (cTupleDataCon 0))
in Just (evCast dcApp co)
_ -> Nothing

0 comments on commit 0cebb9c

Please sign in to comment.