Skip to content

Commit

Permalink
Adds phantom handling utilities to the Compiler with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bladyjoker committed Jul 10, 2024
1 parent 938f167 commit 785d2c6
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 21 deletions.
1 change: 1 addition & 0 deletions lambda-buffers-compiler/lambda-buffers-compiler.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ test-suite tests
Test.LambdaBuffers.Compiler.LamTy
Test.LambdaBuffers.Compiler.MiniLog
Test.LambdaBuffers.Compiler.Mutation
Test.LambdaBuffers.Compiler.Phantoms
Test.LambdaBuffers.Compiler.TypeClassCheck
Test.LambdaBuffers.Compiler.Utils
Test.LambdaBuffers.Compiler.Utils.Golden
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module LambdaBuffers.Compiler.LamTy (LT.Ty (..), LT.fromTy, LT.eval, LT.runEval, LT.runEval', LT.prettyTy) where
module LambdaBuffers.Compiler.LamTy (LT.Ty (..), LT.fromTy, LT.eval, LT.runEval, LT.runEval', LT.prettyTy, LT.runEvalWithGas) where

import LambdaBuffers.Compiler.LamTy.Eval qualified as LT
import LambdaBuffers.Compiler.LamTy.Pretty qualified as LT
Expand Down
19 changes: 18 additions & 1 deletion lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy/Eval.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module LambdaBuffers.Compiler.LamTy.Eval (runEval, runEval', eval) where
module LambdaBuffers.Compiler.LamTy.Eval (runEval, runEval', eval, runEvalWithGas) where

import Control.Lens (makeLenses, view, (%~), (&), (.~), (^.))
import Control.Monad.Error.Class (MonadError (throwError))
Expand Down Expand Up @@ -35,6 +35,23 @@ runEval' mn tds ty =
let p = runReaderT (eval ty) (MkContext mn tds mempty)
in runExcept p

-- | `runEvalWithGas someGas ci mn ty` evaluates a `ty` repeatedly until the gas is reached or no further change (fixpoint reached)
runEvalWithGas :: Maybe Int -> PC.CompilerInput -> PC.ModuleName -> PC.Ty -> Either P.Error Ty
runEvalWithGas mayGas ci mn ty =
let tydefs = PC.indexTyDefs ci
in fix mayGas (runEval' mn tydefs) (fromTy ty)

fix :: Maybe Int -> (Ty -> Either e Ty) -> Ty -> Either e Ty
fix Nothing r x = case r x of
Left err -> Left err
Right x' -> if x == x' then Right x else fix Nothing r x'
fix (Just n) r x =
if n <= 0
then Right x
else case r x of
Left err -> Left err
Right x' -> if x == x' then Right x else fix (Just (n - 1)) r x'

eval :: MonadEval m => Ty -> m Ty
eval (TyRef tr) = eval (TyApp (TyRef tr) [] Nothing)
eval (TyApp (TyRef tr) args t) = resolveTyRef tr >>= (\f -> eval $ TyApp f args t)
Expand Down
36 changes: 35 additions & 1 deletion lambda-buffers-compiler/src/LambdaBuffers/ProtoCompat/Utils.hs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module LambdaBuffers.ProtoCompat.Utils (prettyModuleName, prettyModuleName', localRef2ForeignRef, classClosure, filterClassInModule) where
module LambdaBuffers.ProtoCompat.Utils (prettyModuleName, prettyModuleName', localRef2ForeignRef, classClosure, filterClassInModule, collectTyVars, collectVars, collectPhantomTyArgs) where

import Control.Lens (Getter, to, view, (&), (.~), (^.))
import Data.Foldable (Foldable (toList))
import Data.Map qualified as Map
import Data.Map.Ordered (OMap)
import Data.Map.Ordered qualified as OMap
import Data.ProtoLens (Message (defMessage))
import Data.Set (Set)
import Data.Set qualified as Set
import LambdaBuffers.ProtoCompat.Indexing qualified as PC
import LambdaBuffers.ProtoCompat.InfoLess qualified as PC
import LambdaBuffers.ProtoCompat.IsCompat.FromProto qualified as PC
import LambdaBuffers.ProtoCompat.IsCompat.Lang ()
import LambdaBuffers.ProtoCompat.Types qualified as PC
Expand Down Expand Up @@ -89,3 +93,33 @@ filterDerive cls m drv = filterConstraint cls m (drv ^. #constraint)

filterConstraint :: Set PC.QClassName -> PC.Module -> PC.Constraint -> Bool
filterConstraint cls m cnstr = PC.qualifyClassRef (m ^. #moduleName) (cnstr ^. #classRef) `Set.member` cls

-- | `collectTyVars ty` scans the `PC.Ty` expression and collects all the type variables
collectTyVars :: PC.Ty -> [PC.Ty]
collectTyVars = fmap (`PC.withInfoLess` (PC.TyVarI . PC.TyVar)) . Set.toList . collectVars

-- | `collectVars ty` is similar to `collectTyVars` but returns type variable names
collectVars :: PC.Ty -> Set (PC.InfoLess PC.VarName)
collectVars = collectVars' mempty

collectVars' :: Set (PC.InfoLess PC.VarName) -> PC.Ty -> Set (PC.InfoLess PC.VarName)
collectVars' collected (PC.TyVarI tv) = Set.insert (PC.mkInfoLess . view #varName $ tv) collected
collectVars' collected (PC.TyAppI (PC.TyApp _ args _)) = collected `Set.union` (Set.unions . fmap collectVars $ args)
collectVars' collected _ = collected

collectPhantomTyArgs :: PC.TyDef -> [PC.TyArg]
collectPhantomTyArgs tyDef =
let
PC.TyAbs tyArgs tyBody _si = PC.tyAbs tyDef
tys :: [PC.Ty] = go [] tyBody tyArgs
vars = Set.unions $ collectVars <$> tys
args = Set.fromList [varName | (varName, _) <- OMap.assocs tyArgs]
phantomArgs = Set.difference args vars
in
[tyArg | (varName, tyArg) <- OMap.assocs tyArgs, varName `Set.member` phantomArgs]
where
go :: [PC.Ty] -> PC.TyBody -> OMap (PC.InfoLess PC.VarName) PC.TyArg -> [PC.Ty]
go tys (PC.SumI (PC.Sum ctors _si)) tyArgs = tys <> mconcat [go tys (PC.ProductI $ PC.product ctor) tyArgs | ctor <- toList ctors]
go tys (PC.ProductI (PC.Product fields _si)) _tyArgs = tys <> toList fields
go tys (PC.RecordI (PC.Record fields _si)) _tyArgs = tys <> [PC.fieldTy field | field <- toList fields]
go tys (PC.OpaqueI _) tyArgs = tys <> [PC.TyVarI . PC.TyVar . PC.argName $ tyArg | (_, tyArg) <- OMap.assocs tyArgs]
2 changes: 2 additions & 0 deletions lambda-buffers-compiler/test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Test.LambdaBuffers.Compiler qualified as LBC
import Test.LambdaBuffers.Compiler.ClassClosure qualified as ClassClosure
import Test.LambdaBuffers.Compiler.LamTy qualified as LT
import Test.LambdaBuffers.Compiler.MiniLog qualified as ML
import Test.LambdaBuffers.Compiler.Phantoms qualified as Phantoms
import Test.LambdaBuffers.Compiler.TypeClassCheck qualified as TC
import Test.Tasty (defaultMain, testGroup)

Expand All @@ -21,4 +22,5 @@ main =
, ML.test
, TC.test
, ClassClosure.tests
, Phantoms.test
]
38 changes: 21 additions & 17 deletions lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/LamTy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ test =
(Just 1)
(U.lr "Beer" U.@ [U.fr ["Prelude"] "Text"])
"Prelude.Int8 Prelude.Text"
, fooCiTestCase
"Phantom a 1-> Prelude.Int8"
(Just 1)
(U.lr "Phantom" U.@ [U.tv "a"])
"Prelude.Int8"
, fooCiTestCase
"Phantom a *-> opq"
Nothing
(U.lr "Phantom" U.@ [U.tv "a"])
"opq"
, fooCiTestCase
"RecursiveA a 1-> opq"
(Just 1)
(U.lr "RecursiveA" U.@ [U.tv "a"])
"Prelude.Int8 (RecursiveA a)"
]

fooCi :: PC.CompilerInput
Expand All @@ -125,6 +140,8 @@ fooCi =
]
)
, U.td "Beer" (U.abs ["a"] $ U.prod' [U.fr ["Prelude"] "Int8", U.tv "a"])
, U.td "Phantom" (U.abs ["a"] $ U.prod' [U.fr ["Prelude"] "Int8"])
, U.td "RecursiveA" (U.abs ["a"] $ U.prod' [U.fr ["Prelude"] "Int8", U.lr "RecursiveA" U.@ [U.tv "a"]])
, U.td'maybe
, U.td'either
, U.td'list
Expand All @@ -135,20 +152,7 @@ fooCiTestCase :: TestName -> Maybe Int -> PC.Ty -> String -> TestTree
fooCiTestCase title mayGas ty want = testCase title $ runTestFix mayGas fooCi (U.mn ["Foo"]) ty want

runTestFix :: Maybe Int -> PC.CompilerInput -> PC.ModuleName -> PC.Ty -> String -> Assertion
runTestFix mayGas ci mn ty want =
let tydefs = PC.indexTyDefs ci
in case fix mayGas (LT.runEval' mn tydefs) (LT.fromTy ty) of
Left err -> assertFailure (show err)
Right res -> do
assertEqual "" want (show res)

fix :: Maybe Int -> (LT.Ty -> Either e LT.Ty) -> LT.Ty -> Either e LT.Ty
fix Nothing r x = case r x of
Left err -> Left err
Right x' -> if x == x' then Right x else fix Nothing r x'
fix (Just n) r x =
if n <= 0
then Right x
else case r x of
Left err -> Left err
Right x' -> if x == x' then Right x else fix (Just (n - 1)) r x'
runTestFix mayGas ci mn ty want = case LT.runEvalWithGas mayGas ci mn ty of
Left err -> assertFailure (show err)
Right res -> do
assertEqual "" want (show res)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module Test.LambdaBuffers.Compiler.Phantoms (test) where

import LambdaBuffers.ProtoCompat qualified as PC
import LambdaBuffers.ProtoCompat.Utils (collectPhantomTyArgs)
import Test.LambdaBuffers.ProtoCompat.Utils qualified as U
import Test.Tasty (TestTree, adjustOption, testGroup)
import Test.Tasty.HUnit (testCase, (@?=))
import Test.Tasty.Hedgehog qualified as H

test :: TestTree
test =
adjustOption (\_ -> H.HedgehogTestLimit $ Just 1000) $
testGroup
"Phantoms checks"
[ testCase "phantomA" $ phantoms U.td'phantomA @?= ["a"]
, testCase "phantomB" $ phantoms U.td'phantomB @?= ["a", "b"]
, testCase "phantomC" $ phantoms U.td'phantomC @?= ["a"]
, testCase "phantomD" $ phantoms U.td'phantomD @?= ["a", "b"]
, testCase "phantomE" $ phantoms U.td'phantomE @?= ["a"]
, testCase "phantomF" $ phantoms U.td'phantomF @?= ["a", "b"]
, testCase "Either" $ phantoms U.td'either @?= []
, testCase "Either opaque" $ phantoms U.td'eitherO @?= []
, testCase "List" $ phantoms U.td'list @?= []
, testCase "Maybe" $ phantoms U.td'maybe @?= []
, testCase "Maybe opaque" $ phantoms U.td'maybeO @?= []
]
where
phantoms = fmap ((\(PC.VarName name _) -> name) . PC.argName) . collectPhantomTyArgs
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ fails :: TestName -> PC.CompilerInput -> TestTree
fails title ci =
Golden.fails
goldensDir
(\tdir -> let fn = tdir </> "compiler_error" <.> "textproto" in (,) <$> (PbText.readMessageOrDie <$> Text.readFile fn) <*> pure fn)
(\tdir -> let fn = tdir </> "compiler_error" <.> "textproto" in ((,) . PbText.readMessageOrDie <$> Text.readFile fn) <*> pure fn)
(\otherFn gotErr -> Text.writeFile otherFn (Text.pack . show $ PbText.pprintMessage gotErr))
title
(fst $ TC.runCheck' ci)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ module Test.LambdaBuffers.ProtoCompat.Utils (
qcln',
mod'prelude'only'eq,
mod'prelude'noclass,
td'phantomA,
td'phantomB,
td'phantomC,
td'phantomD,
td'phantomE,
td'phantomF,
) where

import Control.Lens ((^.))
Expand Down Expand Up @@ -182,6 +188,24 @@ td'list = td "List" (abs ["a"] $ sum [("Nil", []), ("Cons", [tv "a", lr "List" @
td'listO :: PC.TyDef
td'listO = td "List" (abs ["a"] opq)

td'phantomA :: PC.TyDef
td'phantomA = td "PhantomA" (abs ["a"] $ sum [("A", [lr "Int", lr "Int"]), ("B", [lr "Int", lr "Int"])])

td'phantomB :: PC.TyDef
td'phantomB = td "PhantomB" (abs ["a", "b"] $ sum [("A", [lr "Int", lr "Int"]), ("B", [lr "Int", lr "Int"])])

td'phantomC :: PC.TyDef
td'phantomC = td "PhantomC" (abs ["a"] $ prod' [lr "Int", lr "Int"])

td'phantomD :: PC.TyDef
td'phantomD = td "PhantomD" (abs ["a", "b"] $ prod' [lr "Int", lr "Int"])

td'phantomE :: PC.TyDef
td'phantomE = td "PhantomE" (abs ["a"] $ recrd [("foo", lr "Int"), ("bar", lr "Int")])

td'phantomF :: PC.TyDef
td'phantomF = td "PhantomF" (abs ["a", "b"] $ recrd [("foo", lr "Int"), ("bar", lr "Int")])

-- | Some class definitions.
cd'eq :: PC.ClassDef
cd'eq = classDef ("Eq", "a") []
Expand Down

0 comments on commit 785d2c6

Please sign in to comment.