From 785d2c66422dfa729aed27210fa45800bb95e6c9 Mon Sep 17 00:00:00 2001 From: Drazen Popovic Date: Wed, 10 Jul 2024 17:44:02 +0200 Subject: [PATCH] Adds phantom handling utilities to the Compiler with tests --- .../lambda-buffers-compiler.cabal | 1 + .../src/LambdaBuffers/Compiler/LamTy.hs | 2 +- .../src/LambdaBuffers/Compiler/LamTy/Eval.hs | 19 +++++++++- .../src/LambdaBuffers/ProtoCompat/Utils.hs | 36 +++++++++++++++++- lambda-buffers-compiler/test/Test.hs | 2 + .../test/Test/LambdaBuffers/Compiler/LamTy.hs | 38 ++++++++++--------- .../Test/LambdaBuffers/Compiler/Phantoms.hs | 28 ++++++++++++++ .../LambdaBuffers/Compiler/TypeClassCheck.hs | 2 +- .../Test/LambdaBuffers/ProtoCompat/Utils.hs | 24 ++++++++++++ 9 files changed, 131 insertions(+), 21 deletions(-) create mode 100644 lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/Phantoms.hs diff --git a/lambda-buffers-compiler/lambda-buffers-compiler.cabal b/lambda-buffers-compiler/lambda-buffers-compiler.cabal index c8d2fe12..00f3a2f8 100644 --- a/lambda-buffers-compiler/lambda-buffers-compiler.cabal +++ b/lambda-buffers-compiler/lambda-buffers-compiler.cabal @@ -199,6 +199,7 @@ test-suite tests Test.LambdaBuffers.Compiler.LamTy Test.LambdaBuffers.Compiler.MiniLog Test.LambdaBuffers.Compiler.Mutation + Test.LambdaBuffers.Compiler.Phantoms Test.LambdaBuffers.Compiler.TypeClassCheck Test.LambdaBuffers.Compiler.Utils Test.LambdaBuffers.Compiler.Utils.Golden diff --git a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy.hs b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy.hs index 03489ac3..0579e107 100644 --- a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy.hs +++ b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy.hs @@ -1,4 +1,4 @@ -module LambdaBuffers.Compiler.LamTy (LT.Ty (..), LT.fromTy, LT.eval, LT.runEval, LT.runEval', LT.prettyTy) where +module LambdaBuffers.Compiler.LamTy (LT.Ty (..), LT.fromTy, LT.eval, LT.runEval, LT.runEval', LT.prettyTy, LT.runEvalWithGas) where import LambdaBuffers.Compiler.LamTy.Eval qualified as LT import LambdaBuffers.Compiler.LamTy.Pretty qualified as LT diff --git a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy/Eval.hs b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy/Eval.hs index 6b07e8db..977b9d44 100644 --- a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy/Eval.hs +++ b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy/Eval.hs @@ -1,4 +1,4 @@ -module LambdaBuffers.Compiler.LamTy.Eval (runEval, runEval', eval) where +module LambdaBuffers.Compiler.LamTy.Eval (runEval, runEval', eval, runEvalWithGas) where import Control.Lens (makeLenses, view, (%~), (&), (.~), (^.)) import Control.Monad.Error.Class (MonadError (throwError)) @@ -35,6 +35,23 @@ runEval' mn tds ty = let p = runReaderT (eval ty) (MkContext mn tds mempty) in runExcept p +-- | `runEvalWithGas someGas ci mn ty` evaluates a `ty` repeatedly until the gas is reached or no further change (fixpoint reached) +runEvalWithGas :: Maybe Int -> PC.CompilerInput -> PC.ModuleName -> PC.Ty -> Either P.Error Ty +runEvalWithGas mayGas ci mn ty = + let tydefs = PC.indexTyDefs ci + in fix mayGas (runEval' mn tydefs) (fromTy ty) + +fix :: Maybe Int -> (Ty -> Either e Ty) -> Ty -> Either e Ty +fix Nothing r x = case r x of + Left err -> Left err + Right x' -> if x == x' then Right x else fix Nothing r x' +fix (Just n) r x = + if n <= 0 + then Right x + else case r x of + Left err -> Left err + Right x' -> if x == x' then Right x else fix (Just (n - 1)) r x' + eval :: MonadEval m => Ty -> m Ty eval (TyRef tr) = eval (TyApp (TyRef tr) [] Nothing) eval (TyApp (TyRef tr) args t) = resolveTyRef tr >>= (\f -> eval $ TyApp f args t) diff --git a/lambda-buffers-compiler/src/LambdaBuffers/ProtoCompat/Utils.hs b/lambda-buffers-compiler/src/LambdaBuffers/ProtoCompat/Utils.hs index 526f0635..43b0bb6a 100644 --- a/lambda-buffers-compiler/src/LambdaBuffers/ProtoCompat/Utils.hs +++ b/lambda-buffers-compiler/src/LambdaBuffers/ProtoCompat/Utils.hs @@ -1,14 +1,18 @@ {-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -Wno-redundant-constraints #-} -module LambdaBuffers.ProtoCompat.Utils (prettyModuleName, prettyModuleName', localRef2ForeignRef, classClosure, filterClassInModule) where +module LambdaBuffers.ProtoCompat.Utils (prettyModuleName, prettyModuleName', localRef2ForeignRef, classClosure, filterClassInModule, collectTyVars, collectVars, collectPhantomTyArgs) where import Control.Lens (Getter, to, view, (&), (.~), (^.)) +import Data.Foldable (Foldable (toList)) import Data.Map qualified as Map +import Data.Map.Ordered (OMap) +import Data.Map.Ordered qualified as OMap import Data.ProtoLens (Message (defMessage)) import Data.Set (Set) import Data.Set qualified as Set import LambdaBuffers.ProtoCompat.Indexing qualified as PC +import LambdaBuffers.ProtoCompat.InfoLess qualified as PC import LambdaBuffers.ProtoCompat.IsCompat.FromProto qualified as PC import LambdaBuffers.ProtoCompat.IsCompat.Lang () import LambdaBuffers.ProtoCompat.Types qualified as PC @@ -89,3 +93,33 @@ filterDerive cls m drv = filterConstraint cls m (drv ^. #constraint) filterConstraint :: Set PC.QClassName -> PC.Module -> PC.Constraint -> Bool filterConstraint cls m cnstr = PC.qualifyClassRef (m ^. #moduleName) (cnstr ^. #classRef) `Set.member` cls + +-- | `collectTyVars ty` scans the `PC.Ty` expression and collects all the type variables +collectTyVars :: PC.Ty -> [PC.Ty] +collectTyVars = fmap (`PC.withInfoLess` (PC.TyVarI . PC.TyVar)) . Set.toList . collectVars + +-- | `collectVars ty` is similar to `collectTyVars` but returns type variable names +collectVars :: PC.Ty -> Set (PC.InfoLess PC.VarName) +collectVars = collectVars' mempty + +collectVars' :: Set (PC.InfoLess PC.VarName) -> PC.Ty -> Set (PC.InfoLess PC.VarName) +collectVars' collected (PC.TyVarI tv) = Set.insert (PC.mkInfoLess . view #varName $ tv) collected +collectVars' collected (PC.TyAppI (PC.TyApp _ args _)) = collected `Set.union` (Set.unions . fmap collectVars $ args) +collectVars' collected _ = collected + +collectPhantomTyArgs :: PC.TyDef -> [PC.TyArg] +collectPhantomTyArgs tyDef = + let + PC.TyAbs tyArgs tyBody _si = PC.tyAbs tyDef + tys :: [PC.Ty] = go [] tyBody tyArgs + vars = Set.unions $ collectVars <$> tys + args = Set.fromList [varName | (varName, _) <- OMap.assocs tyArgs] + phantomArgs = Set.difference args vars + in + [tyArg | (varName, tyArg) <- OMap.assocs tyArgs, varName `Set.member` phantomArgs] + where + go :: [PC.Ty] -> PC.TyBody -> OMap (PC.InfoLess PC.VarName) PC.TyArg -> [PC.Ty] + go tys (PC.SumI (PC.Sum ctors _si)) tyArgs = tys <> mconcat [go tys (PC.ProductI $ PC.product ctor) tyArgs | ctor <- toList ctors] + go tys (PC.ProductI (PC.Product fields _si)) _tyArgs = tys <> toList fields + go tys (PC.RecordI (PC.Record fields _si)) _tyArgs = tys <> [PC.fieldTy field | field <- toList fields] + go tys (PC.OpaqueI _) tyArgs = tys <> [PC.TyVarI . PC.TyVar . PC.argName $ tyArg | (_, tyArg) <- OMap.assocs tyArgs] diff --git a/lambda-buffers-compiler/test/Test.hs b/lambda-buffers-compiler/test/Test.hs index cf76ecd6..482e43d3 100644 --- a/lambda-buffers-compiler/test/Test.hs +++ b/lambda-buffers-compiler/test/Test.hs @@ -6,6 +6,7 @@ import Test.LambdaBuffers.Compiler qualified as LBC import Test.LambdaBuffers.Compiler.ClassClosure qualified as ClassClosure import Test.LambdaBuffers.Compiler.LamTy qualified as LT import Test.LambdaBuffers.Compiler.MiniLog qualified as ML +import Test.LambdaBuffers.Compiler.Phantoms qualified as Phantoms import Test.LambdaBuffers.Compiler.TypeClassCheck qualified as TC import Test.Tasty (defaultMain, testGroup) @@ -21,4 +22,5 @@ main = , ML.test , TC.test , ClassClosure.tests + , Phantoms.test ] diff --git a/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/LamTy.hs b/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/LamTy.hs index 0f5fa3f3..f402f7e2 100644 --- a/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/LamTy.hs +++ b/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/LamTy.hs @@ -105,6 +105,21 @@ test = (Just 1) (U.lr "Beer" U.@ [U.fr ["Prelude"] "Text"]) "Prelude.Int8 Prelude.Text" + , fooCiTestCase + "Phantom a 1-> Prelude.Int8" + (Just 1) + (U.lr "Phantom" U.@ [U.tv "a"]) + "Prelude.Int8" + , fooCiTestCase + "Phantom a *-> opq" + Nothing + (U.lr "Phantom" U.@ [U.tv "a"]) + "opq" + , fooCiTestCase + "RecursiveA a 1-> opq" + (Just 1) + (U.lr "RecursiveA" U.@ [U.tv "a"]) + "Prelude.Int8 (RecursiveA a)" ] fooCi :: PC.CompilerInput @@ -125,6 +140,8 @@ fooCi = ] ) , U.td "Beer" (U.abs ["a"] $ U.prod' [U.fr ["Prelude"] "Int8", U.tv "a"]) + , U.td "Phantom" (U.abs ["a"] $ U.prod' [U.fr ["Prelude"] "Int8"]) + , U.td "RecursiveA" (U.abs ["a"] $ U.prod' [U.fr ["Prelude"] "Int8", U.lr "RecursiveA" U.@ [U.tv "a"]]) , U.td'maybe , U.td'either , U.td'list @@ -135,20 +152,7 @@ fooCiTestCase :: TestName -> Maybe Int -> PC.Ty -> String -> TestTree fooCiTestCase title mayGas ty want = testCase title $ runTestFix mayGas fooCi (U.mn ["Foo"]) ty want runTestFix :: Maybe Int -> PC.CompilerInput -> PC.ModuleName -> PC.Ty -> String -> Assertion -runTestFix mayGas ci mn ty want = - let tydefs = PC.indexTyDefs ci - in case fix mayGas (LT.runEval' mn tydefs) (LT.fromTy ty) of - Left err -> assertFailure (show err) - Right res -> do - assertEqual "" want (show res) - -fix :: Maybe Int -> (LT.Ty -> Either e LT.Ty) -> LT.Ty -> Either e LT.Ty -fix Nothing r x = case r x of - Left err -> Left err - Right x' -> if x == x' then Right x else fix Nothing r x' -fix (Just n) r x = - if n <= 0 - then Right x - else case r x of - Left err -> Left err - Right x' -> if x == x' then Right x else fix (Just (n - 1)) r x' +runTestFix mayGas ci mn ty want = case LT.runEvalWithGas mayGas ci mn ty of + Left err -> assertFailure (show err) + Right res -> do + assertEqual "" want (show res) diff --git a/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/Phantoms.hs b/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/Phantoms.hs new file mode 100644 index 00000000..295026d2 --- /dev/null +++ b/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/Phantoms.hs @@ -0,0 +1,28 @@ +module Test.LambdaBuffers.Compiler.Phantoms (test) where + +import LambdaBuffers.ProtoCompat qualified as PC +import LambdaBuffers.ProtoCompat.Utils (collectPhantomTyArgs) +import Test.LambdaBuffers.ProtoCompat.Utils qualified as U +import Test.Tasty (TestTree, adjustOption, testGroup) +import Test.Tasty.HUnit (testCase, (@?=)) +import Test.Tasty.Hedgehog qualified as H + +test :: TestTree +test = + adjustOption (\_ -> H.HedgehogTestLimit $ Just 1000) $ + testGroup + "Phantoms checks" + [ testCase "phantomA" $ phantoms U.td'phantomA @?= ["a"] + , testCase "phantomB" $ phantoms U.td'phantomB @?= ["a", "b"] + , testCase "phantomC" $ phantoms U.td'phantomC @?= ["a"] + , testCase "phantomD" $ phantoms U.td'phantomD @?= ["a", "b"] + , testCase "phantomE" $ phantoms U.td'phantomE @?= ["a"] + , testCase "phantomF" $ phantoms U.td'phantomF @?= ["a", "b"] + , testCase "Either" $ phantoms U.td'either @?= [] + , testCase "Either opaque" $ phantoms U.td'eitherO @?= [] + , testCase "List" $ phantoms U.td'list @?= [] + , testCase "Maybe" $ phantoms U.td'maybe @?= [] + , testCase "Maybe opaque" $ phantoms U.td'maybeO @?= [] + ] + where + phantoms = fmap ((\(PC.VarName name _) -> name) . PC.argName) . collectPhantomTyArgs diff --git a/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/TypeClassCheck.hs b/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/TypeClassCheck.hs index 4d9b246e..703d2783 100644 --- a/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/TypeClassCheck.hs +++ b/lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/TypeClassCheck.hs @@ -762,7 +762,7 @@ fails :: TestName -> PC.CompilerInput -> TestTree fails title ci = Golden.fails goldensDir - (\tdir -> let fn = tdir "compiler_error" <.> "textproto" in (,) <$> (PbText.readMessageOrDie <$> Text.readFile fn) <*> pure fn) + (\tdir -> let fn = tdir "compiler_error" <.> "textproto" in ((,) . PbText.readMessageOrDie <$> Text.readFile fn) <*> pure fn) (\otherFn gotErr -> Text.writeFile otherFn (Text.pack . show $ PbText.pprintMessage gotErr)) title (fst $ TC.runCheck' ci) diff --git a/lambda-buffers-compiler/test/Test/LambdaBuffers/ProtoCompat/Utils.hs b/lambda-buffers-compiler/test/Test/LambdaBuffers/ProtoCompat/Utils.hs index 325bb7f7..da4a2343 100644 --- a/lambda-buffers-compiler/test/Test/LambdaBuffers/ProtoCompat/Utils.hs +++ b/lambda-buffers-compiler/test/Test/LambdaBuffers/ProtoCompat/Utils.hs @@ -29,6 +29,12 @@ module Test.LambdaBuffers.ProtoCompat.Utils ( qcln', mod'prelude'only'eq, mod'prelude'noclass, + td'phantomA, + td'phantomB, + td'phantomC, + td'phantomD, + td'phantomE, + td'phantomF, ) where import Control.Lens ((^.)) @@ -182,6 +188,24 @@ td'list = td "List" (abs ["a"] $ sum [("Nil", []), ("Cons", [tv "a", lr "List" @ td'listO :: PC.TyDef td'listO = td "List" (abs ["a"] opq) +td'phantomA :: PC.TyDef +td'phantomA = td "PhantomA" (abs ["a"] $ sum [("A", [lr "Int", lr "Int"]), ("B", [lr "Int", lr "Int"])]) + +td'phantomB :: PC.TyDef +td'phantomB = td "PhantomB" (abs ["a", "b"] $ sum [("A", [lr "Int", lr "Int"]), ("B", [lr "Int", lr "Int"])]) + +td'phantomC :: PC.TyDef +td'phantomC = td "PhantomC" (abs ["a"] $ prod' [lr "Int", lr "Int"]) + +td'phantomD :: PC.TyDef +td'phantomD = td "PhantomD" (abs ["a", "b"] $ prod' [lr "Int", lr "Int"]) + +td'phantomE :: PC.TyDef +td'phantomE = td "PhantomE" (abs ["a"] $ recrd [("foo", lr "Int"), ("bar", lr "Int")]) + +td'phantomF :: PC.TyDef +td'phantomF = td "PhantomF" (abs ["a", "b"] $ recrd [("foo", lr "Int"), ("bar", lr "Int")]) + -- | Some class definitions. cd'eq :: PC.ClassDef cd'eq = classDef ("Eq", "a") []