From 6de31070c86ad451abe329940ba4a0a4b571fb58 Mon Sep 17 00:00:00 2001 From: Felix Klein Date: Sun, 1 Sep 2024 20:27:26 +0200 Subject: [PATCH] Add CLog with well-defined zero case --- src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs | 1 + src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs | 1 + src/GHC/TypeLits/Extra.hs | 29 +++++++++- src/GHC/TypeLits/Extra/Solver/Operations.hs | 19 +++++++ src/GHC/TypeLits/Extra/Solver/Unify.hs | 13 +++++ tests/Main.hs | 60 ++++++++++++++++++++ 6 files changed, 122 insertions(+), 1 deletion(-) diff --git a/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs b/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs index 8d2ed0e..755ff9a 100644 --- a/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs +++ b/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs @@ -327,6 +327,7 @@ lookupExtraDefs = do <*> look ''GHC.TypeLits.Extra.LCM <*> look ''Data.Type.Ord.OrdCond <*> look ''GHC.TypeError.Assert + <*> look ''GHC.TypeLits.Extra.CLogWZ where look nm = tcLookupTyCon =<< lookupTHName nm diff --git a/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs b/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs index 3d9a9d3..8e4ab86 100644 --- a/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs +++ b/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs @@ -331,6 +331,7 @@ lookupExtraDefs = do <*> pure typeNatLeqTyCon <*> pure typeNatLeqTyCon #endif + <*> look md "CLogWZ" where look md s = tcLookupTyCon =<< lookupName md (mkTcOcc s) myModule = mkModuleName "GHC.TypeLits.Extra" diff --git a/src/GHC/TypeLits/Extra.hs b/src/GHC/TypeLits/Extra.hs index 9cebcd0..2e76329 100644 --- a/src/GHC/TypeLits/Extra.hs +++ b/src/GHC/TypeLits/Extra.hs @@ -70,6 +70,7 @@ module GHC.TypeLits.Extra -- ** Logarithm , FLog , CLog + , CLogWZ -- *** Exact logarithm , Log -- Numeric @@ -101,7 +102,8 @@ import GHC.TypeLits as N #if MIN_VERSION_ghc(8,4,0) import GHC.TypeLits (Div, Mod) #endif -import GHC.TypeLits.KnownNat (KnownNat2 (..), SNatKn (..), nameToSymbol) +import GHC.TypeLits.KnownNat (KnownNat2 (..), KnownNat3 (..) + ,SNatKn (..), nameToSymbol) #if MIN_VERSION_ghc(8,2,0) intToNumber :: Int# -> Natural @@ -195,6 +197,31 @@ instance (KnownNat x, KnownNat y, 2 <= x, 1 <= y) => KnownNat2 $(nameToSymbol '' _ | isTrue# (z1 ==# z2) -> SNatKn (intToNumber (z1 +# 1#)) | otherwise -> SNatKn (intToNumber z1) +-- | Extended version of 'CLog', which is also well-defined in case the non-base argument is zero. The additional third argument argument is returned in this particular case. dThe particular value is chosen the user. +-- +-- Note that additional equations are provided by the type-checker plugin solver +-- "GHC.TypeLits.Extra.Solver". +type family CLogWZ (base :: Nat) (value :: Nat) (ifzero :: Nat) :: Nat where + CLogWZ 2 0 z = z + CLogWZ 2 1 _ = 0 -- Additional equations are provided by the custom solver + +#if MIN_VERSION_ghc(9,4,0) +instance (KnownNat x, KnownNat y, KnownNat z, (2 <= x) ~ (() :: Constraint)) => KnownNat3 $(nameToSymbol ''CLogWZ) x y z where +#else +instance (KnownNat x, KnownNat y, KnownNat z, 2 <= x) => KnownNat3 $(nameToSymbol ''CLogWZ) x y z where +#endif + natSing3 = let x = natVal (Proxy @x) + y = natVal (Proxy @y) + z = natVal (Proxy @z) + z1 = integerLogBase# x y + z2 = integerLogBase# x (y-1) + in case y of + 0 -> SNatKn $ fromInteger z + 1 -> SNatKn 0 + _ | isTrue# (z1 ==# z2) -> SNatKn (intToNumber (z1 +# 1#)) + | otherwise -> SNatKn (intToNumber z1) + + -- | Type-level equivalent of -- where the operation only reduces when: -- diff --git a/src/GHC/TypeLits/Extra/Solver/Operations.hs b/src/GHC/TypeLits/Extra/Solver/Operations.hs index b6041ac..693ffe6 100644 --- a/src/GHC/TypeLits/Extra/Solver/Operations.hs +++ b/src/GHC/TypeLits/Extra/Solver/Operations.hs @@ -21,6 +21,7 @@ module GHC.TypeLits.Extra.Solver.Operations , mergeMod , mergeFLog , mergeCLog + , mergeCLogWZ , mergeLog , mergeGCD , mergeLCM @@ -83,6 +84,7 @@ data ExtraOp | GCD ExtraOp ExtraOp | LCM ExtraOp ExtraOp | Exp ExtraOp ExtraOp + | CLogWZ ExtraOp ExtraOp ExtraOp deriving Eq instance Outputable ExtraOp where @@ -99,6 +101,12 @@ instance Outputable ExtraOp where ppr (GCD x y) = text "GCD (" <+> ppr x <+> text "," <+> ppr y <+> text ")" ppr (LCM x y) = text "GCD (" <+> ppr x <+> text "," <+> ppr y <+> text ")" ppr (Exp x y) = text "Exp (" <+> ppr x <+> text "," <+> ppr y <+> text ")" + ppr (CLogWZ x y z) = + text "CLogWZ " + <+> text "(" <+> ppr x + <+> text "," <+> ppr y + <+> text "," <+> ppr z + <+> text ")" data ExtraDefs = ExtraDefs { maxTyCon :: TyCon @@ -112,6 +120,7 @@ data ExtraDefs = ExtraDefs , lcmTyCon :: TyCon , ordTyCon :: TyCon , assertTC :: TyCon + , clogWZTyCon :: TyCon } reifyEOP :: ExtraDefs -> ExtraOp -> Type @@ -128,6 +137,8 @@ reifyEOP defs (Mod x y) = mkTyConApp (modTyCon defs) [reifyEOP defs x ,reifyEOP defs y] reifyEOP defs (CLog x y) = mkTyConApp (clogTyCon defs) [reifyEOP defs x ,reifyEOP defs y] +reifyEOP defs (CLogWZ x y z) = mkTyConApp (clogTyCon defs) + $ reifyEOP defs <$> [x, y, z] reifyEOP defs (FLog x y) = mkTyConApp (flogTyCon defs) [reifyEOP defs x ,reifyEOP defs y] reifyEOP defs (Log x y) = mkTyConApp (logTyCon defs) [reifyEOP defs x @@ -195,6 +206,14 @@ mergeCLog i (Exp j k) | i == j = Just (k, Normalised) mergeCLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (clogBase i j) mergeCLog x y = Just (CLog x y, Untouched) +mergeCLogWZ :: ExtraOp -> ExtraOp -> ExtraOp -> Maybe NormaliseResult +mergeCLogWZ (I i) _ _ | i < 2 = Nothing +mergeCLogWZ _ (I 0) z = Just (z, Normalised) +mergeCLogWZ i (Exp j k) _ | i == j = Just (k, Normalised) +mergeCLogWZ x y@(I _) _ = do (res, _) <- mergeCLog x y + pure (res, Normalised) +mergeCLogWZ x y z = Just (CLogWZ x y z, Untouched) + mergeLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult mergeLog (I i) _ | i < 2 = Nothing mergeLog b (Exp b' y) | b == b' = Just (y, Normalised) diff --git a/src/GHC/TypeLits/Extra/Solver/Unify.hs b/src/GHC/TypeLits/Extra/Solver/Unify.hs index b365044..9f18571 100644 --- a/src/GHC/TypeLits/Extra/Solver/Unify.hs +++ b/src/GHC/TypeLits/Extra/Solver/Unify.hs @@ -97,6 +97,14 @@ normaliseNat defs (TyConApp tc [x,y]) (normaliseNat defs x) (normaliseNat defs y) +normaliseNat defs (TyConApp tc [x,y,z]) + | tc == clogWZTyCon defs = do + (x', n1) <- normaliseNat defs x + (y', n2) <- normaliseNat defs y + (z', n3) <- normaliseNat defs z + (res, n4) <- MaybeT $ return $ mergeCLogWZ x' y' z' + pure (res, foldl mergeNormalised Untouched [n1,n2,n3,n4]) + normaliseNat defs (TyConApp tc tys) = do let mergeExtraOp [] = [] mergeExtraOp ((Just (op, Normalised), _):xs) = reifyEOP defs op:mergeExtraOp xs @@ -162,6 +170,10 @@ fvOP (Log x y) = fvOP x `unionUniqSets` fvOP y fvOP (GCD x y) = fvOP x `unionUniqSets` fvOP y fvOP (LCM x y) = fvOP x `unionUniqSets` fvOP y fvOP (Exp x y) = fvOP x `unionUniqSets` fvOP y +fvOP (CLogWZ x y z) = + fvOP x `unionUniqSets` + fvOP y `unionUniqSets` + fvOP z eqFV :: ExtraOp -> ExtraOp -> Bool eqFV = (==) `on` fvOP @@ -180,3 +192,4 @@ containsConstants (Log x y) = containsConstants x || containsConstants y containsConstants (GCD x y) = containsConstants x || containsConstants y containsConstants (LCM x y) = containsConstants x || containsConstants y containsConstants (Exp x y) = containsConstants x || containsConstants y +containsConstants (CLogWZ x y z) = or $ map containsConstants [x, y, z] diff --git a/tests/Main.hs b/tests/Main.hs index d260275..8c93962 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -234,6 +234,36 @@ test58b -> Proxy (Max (n+2) 1) test58b = test58a +test59 :: Proxy (CLogWZ 3 10 9) -> Proxy 3 +test59 = id + +test60 :: Proxy ((CLogWZ 3 10 3) + x) -> Proxy (x + (CLogWZ 2 7 8)) +test60 = id + +test61 :: Proxy (CLogWZ x (x^y) 8) -> Proxy y +test61 = id + +test62 :: Integer +test62 = natVal (Proxy :: Proxy (CLogWZ 6 8 3)) + +test63 :: Integer +test63 = natVal (Proxy :: Proxy (CLogWZ 3 10 9)) + +test64 :: Integer +test64 = natVal (Proxy :: Proxy ((CLogWZ 2 4 11) * (3 ^ (CLogWZ 2 4 8)))) + +test65 :: Integer +test65 = natVal (Proxy :: Proxy (Max (CLogWZ 2 4 8) (CLogWZ 4 20 5))) + +test66 :: Proxy (CLogWZ 3 0 8) -> Proxy 8 +test66 = id + +test67 :: Proxy (CLogWZ 2 0 x) -> Proxy x +test67 = id + +test68 :: Proxy (CLogWZ 5 0 0) -> Proxy 0 +test68 = id + main :: IO () main = defaultMain tests @@ -411,6 +441,36 @@ tests = testGroup "ghc-typelits-natnormalise" , testCase "forall n p . n + 1 <= Max (n + p + 1) p" $ show (test57 Proxy Proxy Proxy) @?= "Proxy" + , testCase "CLogWZ 3 10 9 ~ 3" $ + show (test59 Proxy) @?= + "Proxy" + , testCase "forall x . CLogWZ 3 10 3 + x ~ x + CLogWZ 2 7 8" $ + show (test60 Proxy) @?= + "Proxy" + , testCase "forall x>1 . CLogWZ x (x^y) 8 ~ y" $ + show (test61 Proxy) @?= + "Proxy" + , testCase "KnownNat (CLogWZ 6 8 3) ~ 2" $ + show test62 @?= + "2" + , testCase "KnownNat (CLogWZ 3 10 9) ~ 3" $ + show test63 @?= + "3" + , testCase "KnownNat ((CLogWZ 2 4 11) * (3 ^ (CLogWZ 2 4 8)))) ~ 18" $ + show test64 @?= + "18" + , testCase "KnownNat (Max (CLogWZ 2 4 8) (CLogWZ 4 20 5)) ~ 3" $ + show test65 @?= + "3" + , testCase "CLogWZ 3 0 8 ~ 8" $ + show (test66 Proxy) @?= + "Proxy" + , testCase "forall x. CLogWZ 2 0 x ~ x" $ + show (test67 Proxy) @?= + "Proxy" + , testCase "CLogWZ 5 0 0 ~ 0" $ + show (test68 Proxy) @?= + "Proxy" ] , testGroup "errors" [ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors