Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EXPERIMENT: Use transitive closure of tyvar substitutions #43

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 86 additions & 6 deletions src/GHC/TypeLits/KnownNat/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ import Control.Arrow ((&&&), first)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Control.Monad.Trans.Writer.Strict
import Data.Maybe (catMaybes,mapMaybe)
import Data.List (partition, intersect, union, foldl')
import GHC.TcPluginM.Extra (lookupModule, lookupName, newWanted,
tracePlugin)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens, mkSubst', substType)
import GHC.TcPluginM.Extra (flattenGivens, mkSubst')
#endif
import GHC.TypeLits.Normalise.SOP (SOP (..), Product (..), Symbol (..))
import GHC.TypeLits.Normalise.Unify (CType (..),normaliseNat,reifySOP)
Expand Down Expand Up @@ -133,8 +134,9 @@ import GHC.Core.Type
irrelevantMult)
import GHC.Data.FastString (fsLit)
import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin)
import GHC.Utils.Outputable (Outputable (..), ($$), text)
import GHC.Tc.Instance.Family (tcInstNewTyCon_maybe)
import GHC.Tc.Plugin (TcPluginM, tcLookupClass, getInstEnvs)
import GHC.Tc.Plugin (TcPluginM, tcLookupClass, tcPluginTrace, getInstEnvs)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult (..))
import GHC.Tc.Types.Constraint
(Ct, ctEvExpr, ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isWanted,
Expand All @@ -144,10 +146,11 @@ import GHC.Tc.Types.Evidence
import GHC.Types.Id (idType)
import GHC.Types.Name (nameModule_maybe, nameOccName)
import GHC.Types.Name.Occurrence (mkTcOcc, occNameString)
import GHC.Types.Var (DFunId)
import GHC.Types.Var (TcTyVar, DFunId)
import GHC.Unit.Module (mkModuleName, moduleName, moduleNameString)
#else
import Class (Class, classMethods, className, classTyCon)
import Outputable (Outputable (..), ($$), text)
#if MIN_VERSION_ghc(8,6,0)
import Coercion (Role (Representational), mkUnivCo)
#endif
Expand Down Expand Up @@ -177,7 +180,7 @@ import TcPluginM (unsafeTcPluginTcM)
#if !MIN_VERSION_ghc(8,4,0)
import TcPluginM (zonkCt)
#endif
import TcPluginM (TcPluginM, tcLookupClass, getInstEnvs)
import TcPluginM (TcPluginM, tcLookupClass, getInstEnvs, tcPluginTrace)
import TcRnTypes (TcPlugin(..), TcPluginResult (..))
import TcTypeNats (typeNatAddTyCon, typeNatSubTyCon)
#if MIN_VERSION_ghc(8,4,0)
Expand All @@ -193,7 +196,7 @@ import TyCoRep (Type (..), TyLit (..))
import TyCoRep (UnivCoProvenance (PluginProv))
import TysWiredIn (boolTy)
#endif
import Var (DFunId)
import Var (TcTyVar, DFunId)

#if MIN_VERSION_ghc(8,10,0)
import Constraint
Expand Down Expand Up @@ -226,6 +229,9 @@ data KnownNatDefs
-- knownnat from the un-flattened version that we work with internally.
newtype Orig a = Orig { unOrig :: a }

instance Outputable a => Outputable(Orig a) where
ppr (Orig a) = text "Outputable " $$ ppr a

-- | KnownNat constraints
type KnConstraint = (Ct -- The constraint
,Class -- KnownNat class
Expand Down Expand Up @@ -331,6 +337,74 @@ normalisePlugin = tracePlugin "ghc-typelits-knownnat"
, tcPluginStop = const (return ())
}

-- | give transitive closure of type variable rewrites + their non type variable
-- rewrites.
-- Ex. Given the following substitutions:
-- a ~ b, b ~ c, c ~ d, b ~ F z, x ~ z. x ~ u, u ~ G a, u ~ H b
-- Then then the result is
-- [ ([a,b,c,d], [F z])
-- , (x,z,u), [G a, H b]
-- ]
substTransClosures :: [(TcTyVar, Type)] -> [([TcTyVar], [Type])]
substTransClosures substs = zip closures nonTyVarSubs
where
initialClosure (tv1, TyVarTy tv2) = [tv1, tv2]
initialClosure (tv, _) = [tv]

merge tcs tc1 = merged:(map snd toKeep)
where
(toKeep, toMerge) = partition (null . fst)
$ map (\tc2 -> (intersect tc1 tc2, tc2)) tcs
merged = union tc1 (concat $ map snd toMerge)

closures = foldl' merge [] (map initialClosure substs)

getNonTyVarSubst _ (_, TyVarTy _) = Nothing
getNonTyVarSubst tv1 (tv2, subst) | tv1 == tv2 = Just subst
getNonTyVarSubst _ _ = Nothing
nonTyVarSubs = map
(concat . map (\tv -> mapMaybe (getNonTyVarSubst tv) substs))
closures

-- Base implementation copied from: https://github.com/clash-lang/ghc-tcplugins-extra
-- __NB:__ Doesn't substitute under binders
substType'
:: [([TcTyVar], [Type])]
-> Type
-> Type
substType' [] tv = tv
substType' ((_, []):subst) tv = substType' subst tv
substType' ((tvs, (t:_)):subst) tv@(TyVarTy v) | elem v tvs = t
| otherwise = substType' subst tv

substType' subst (AppTy t1 t2) =
AppTy (substType' subst t1) (substType' subst t2)
substType' subst (TyConApp tc xs) =
TyConApp tc (map (substType' subst) xs)
substType' _subst t@(ForAllTy _tv _ty) =
-- TODO: Is it safe to do "dumb" substitution under binders?
-- ForAllTy tv (substType' subst ty)
t
#if __GLASGOW_HASKELL__ >= 900
substType' subst (FunTy k1 k2 t1 t2) =
FunTy k1 k2 (substType' subst t1) (substType' subst t2)
#elif __GLASGOW_HASKELL__ >= 809
substType' subst (FunTy af t1 t2) =
FunTy af (substType' subst t1) (substType' subst t2)
#elif __GLASGOW_HASKELL__ >= 802
substType' subst (FunTy t1 t2) =
FunTy (substType' subst t1) (substType' subst t2)
#elif __GLASGOW_HASKELL__ < 711
substType' subst (FunTy t1 t2) =
FunTy (substType' subst t1) (substType' subst t2)
#endif
substType' _ l@(LitTy _) = l
#if __GLASGOW_HASKELL__ > 711
substType' subst (CastTy ty co) =
CastTy (substType' subst ty) co
substType' _ co@(CoercionTy _) = co
#endif

solveKnownNat :: KnownNatDefs -> [Ct] -> [Ct] -> [Ct]
-> TcPluginM TcPluginResult
solveKnownNat _defs _givens _deriveds [] = return (TcPluginOk [] [])
Expand All @@ -340,11 +414,16 @@ solveKnownNat defs givens _deriveds wanteds = do
#if MIN_VERSION_ghc(8,4,0)
subst = map fst
$ mkSubst' givens
kn_wanteds = map (\(x,y,z,orig) -> (x,y,substType subst z,orig))
tcSubst = substTransClosures subst
tcPluginTrace "subst" (ppr subst)
tcPluginTrace "transitive closure" (ppr tcSubst)
let kn_wanteds = map (\(x,y,z,orig) -> (x,y,substType' tcSubst z,orig))
$ mapMaybe (toKnConstraint defs) wanteds'
#else
kn_wanteds = mapMaybe (toKnConstraint defs) wanteds'

#endif
tcPluginTrace "kn_wanteds" (ppr kn_wanteds)
case kn_wanteds of
[] -> return (TcPluginOk [] [])
_ -> do
Expand All @@ -354,6 +433,7 @@ solveKnownNat defs givens _deriveds wanteds = do
#else
given_map <- mapM (fmap toGivenEntry . zonkCt) givens
#endif
tcPluginTrace "given_map" (ppr given_map)
-- Try to solve the wanted KnownNat constraints given the [G]iven
-- KnownNat constraints
(solved,new) <- (unzip . catMaybes) <$> (mapM (constraintToEvTerm defs given_map) kn_wanteds)
Expand Down
12 changes: 6 additions & 6 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,6 @@ test16 _ _ = natVal (Proxy @(Foo 1 + 7 + Foo 1))
test17 :: KnownNat (4 + 2 * Foo 1 + Foo 1) => Proxy (Foo 1) -> Proxy (4 + 2 * Foo 1 + Foo 1) -> Number
test17 _ _ = natVal (Proxy @(2 * Foo 1 + 7 + Foo 1))

data SNat :: Nat -> Type where
SNat :: KnownNat n => SNat n

instance Show (SNat n) where
show s@SNat = show (natVal s)

addSNat :: SNat a -> SNat b -> SNat (a + b)
addSNat SNat SNat = SNat

Expand Down Expand Up @@ -188,6 +182,9 @@ test28 :: forall m n . (KnownNat m, (2*n) ~ m) => Proxy m -> Natural
test28 _ = natVal @n Proxy
#endif

test29 :: forall a b . (b ~ (2^a)) => SNat a -> SNat (Log b)
test29 SNat = SNat @(Log b)

tests :: TestTree
tests = testGroup "ghc-typelits-natnormalise"
[ testGroup "Basic functionality"
Expand Down Expand Up @@ -229,6 +226,9 @@ tests = testGroup "ghc-typelits-natnormalise"
show (test28 (Proxy @10)) @?=
"5"
#endif
, testCase "forall a b . (b ~ (2^a)) => SNat a -> SNat (Log b)" $
show (test29 (SNat @0)) @?=
"0"
],
testGroup "Implications"
[ testCase "KnownNat m => KnownNat (m*m); @5" $
Expand Down
7 changes: 7 additions & 0 deletions tests/TestFunctions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

module TestFunctions where

import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Type.Bool (If)
import GHC.TypeLits.KnownNat
Expand Down Expand Up @@ -38,6 +39,12 @@ type family Min (a :: Nat) (b :: Nat) :: Nat where
Min 0 b = 0 -- See [Note: single equation TFs are treated like synonyms]
Min a b = If (a <=? b) a b

data SNat :: Nat -> Type where
SNat :: KnownNat n => SNat n

instance Show (SNat n) where
show s@SNat = show (natVal s)

-- Unary functions.
#if __GLASGOW_HASKELL__ >= 802
withNat :: Natural -> (forall n. (KnownNat n) => Proxy n -> r) -> r
Expand Down