Skip to content

Commit

Permalink
Merge pull request #239 from mlabs-haskell/bladyjoker/rust-phantoms
Browse files Browse the repository at this point in the history
Rust's phantom handling fixes
  • Loading branch information
bladyjoker authored Jul 30, 2024
2 parents 401f8a9 + a08502f commit 75f4411
Show file tree
Hide file tree
Showing 13 changed files with 142 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ printPEqInstanceDef :: MonadHaskellBackend t m => PC.Ty -> Doc ann -> m (Doc ann
printPEqInstanceDef ty implDefDoc = do
Print.importClass PlRefs.peqQClassName
Print.importClass PlRefs.pisDataQClassName
let freeVars = Haskell.collectTyVars ty
let freeVars = PC.collectTyVars ty
headDoc <- Haskell.printConstraint PlRefs.peqQClassName ty
case freeVars of
[] -> return $ "instance" <+> headDoc <+> "where" <> hardline <> space <> space <> implDefDoc
Expand Down Expand Up @@ -170,7 +170,7 @@ printPlutusTypeInstanceDef ty implDefDoc = do
Print.importType PlRefs.pdataQTyName
headDoc <- Haskell.printConstraint PlRefs.plutusTypeQClassName ty
tyDoc <- Haskell.printTyInner ty
let freeVars = Haskell.collectTyVars ty
let freeVars = PC.collectTyVars ty
pinnerDefDoc = "type PInner" <+> tyDoc <+> "=" <+> Haskell.printHsQTyName PlRefs.pdataQTyName
case freeVars of
[] ->
Expand Down Expand Up @@ -272,7 +272,7 @@ printPTryFromPAsDataInstanceDef ty implDefDoc = do
Haskell.printHsQClassName PlRefs.ptryFromQClassName
<+> Haskell.printHsQTyName PlRefs.pdataQTyName
<+> parens (Haskell.printHsQTyName PlRefs.pasDataQTyName <+> tyDoc)
freeVars = Haskell.collectTyVars ty
freeVars = PC.collectTyVars ty
pinnerDefDoc =
"type PTryFromExcess"
<+> Haskell.printHsQTyName PlRefs.pdataQTyName
Expand Down Expand Up @@ -332,7 +332,7 @@ printPTryFromInstanceDef ty = do
Haskell.printHsQClassName PlRefs.ptryFromQClassName
<+> Haskell.printHsQTyName PlRefs.pdataQTyName
<+> tyDoc
freeVars = Haskell.collectTyVars ty
freeVars = PC.collectTyVars ty

pinnerDefDoc =
"type PTryFromExcess"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
module LambdaBuffers.Codegen.Haskell.Print.InstanceDef (printInstanceDef, printConstraint, collectTyVars, printInstanceContext, printInstanceContext', printConstraint') where
module LambdaBuffers.Codegen.Haskell.Print.InstanceDef (printInstanceDef, printConstraint, printInstanceContext, printInstanceContext', printConstraint') where

import Control.Lens (view)
import Data.Foldable (Foldable (toList))
import Data.Set (Set)
import Data.Set qualified as Set
import LambdaBuffers.Codegen.Haskell.Backend (MonadHaskellBackend)
import LambdaBuffers.Codegen.Haskell.Print.Syntax qualified as HsSyntax
import LambdaBuffers.Codegen.Haskell.Print.TyDef (printTyInner)
Expand All @@ -23,7 +19,7 @@ instance (SomeClass a, SomeClass b, SomeClass c) => SomeClass (SomeTy a b c) whe
-}
printInstanceDef :: forall t m ann. MonadHaskellBackend t m => HsSyntax.QClassName -> PC.Ty -> m (Doc ann -> m (Doc ann))
printInstanceDef hsQClassName ty = do
let freeVars = collectTyVars ty
let freeVars = PC.collectTyVars ty
headDoc <- printConstraint hsQClassName ty
return $ case freeVars of
[] -> \implDoc -> return $ "instance" <+> headDoc <+> "where" <> hardline <> space <> space <> implDoc
Expand All @@ -47,14 +43,3 @@ printConstraint' qcn tys = do
let crefDoc = HsSyntax.printHsQClassName qcn
tyDocs <- traverse printTyInner tys
return $ crefDoc <+> hsep tyDocs

collectTyVars :: PC.Ty -> [PC.Ty]
collectTyVars = fmap (`PC.withInfoLess` (PC.TyVarI . PC.TyVar)) . toList . collectVars

collectVars :: PC.Ty -> Set (PC.InfoLess PC.VarName)
collectVars = collectVars' mempty

collectVars' :: Set (PC.InfoLess PC.VarName) -> PC.Ty -> Set (PC.InfoLess PC.VarName)
collectVars' collected (PC.TyVarI tv) = Set.insert (PC.mkInfoLess . view #varName $ tv) collected
collectVars' collected (PC.TyAppI (PC.TyApp _ args _)) = collected `Set.union` (Set.unions . fmap collectVars $ args)
collectVars' collected _ = collected
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ printRec parentTyN tyArgs (PC.Record fields _) = do
let iTyDefs = indexTyDefs ci
mn <- asks (view $ Print.ctxModule . #moduleName)
let phantomTyArgs = collectPhantomTyArgs iTyDefs mn parentTyN (recFieldTys fields) tyArgs
phantomFields = printPhantomDataField <$> phantomTyArgs
phantomFields = pub . printPhantomDataField <$> phantomTyArgs
if null fields && null phantomTyArgs
then return semi
else do
Expand Down
1 change: 1 addition & 0 deletions lambda-buffers-compiler/lambda-buffers-compiler.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ test-suite tests
Test.LambdaBuffers.Compiler.LamTy
Test.LambdaBuffers.Compiler.MiniLog
Test.LambdaBuffers.Compiler.Mutation
Test.LambdaBuffers.Compiler.Phantoms
Test.LambdaBuffers.Compiler.TypeClassCheck
Test.LambdaBuffers.Compiler.Utils
Test.LambdaBuffers.Compiler.Utils.Golden
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module LambdaBuffers.Compiler.LamTy (LT.Ty (..), LT.fromTy, LT.eval, LT.runEval, LT.runEval', LT.prettyTy) where
module LambdaBuffers.Compiler.LamTy (LT.Ty (..), LT.fromTy, LT.eval, LT.runEval, LT.runEval', LT.prettyTy, LT.runEvalWithGas) where

import LambdaBuffers.Compiler.LamTy.Eval qualified as LT
import LambdaBuffers.Compiler.LamTy.Pretty qualified as LT
Expand Down
19 changes: 18 additions & 1 deletion lambda-buffers-compiler/src/LambdaBuffers/Compiler/LamTy/Eval.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module LambdaBuffers.Compiler.LamTy.Eval (runEval, runEval', eval) where
module LambdaBuffers.Compiler.LamTy.Eval (runEval, runEval', eval, runEvalWithGas) where

import Control.Lens (makeLenses, view, (%~), (&), (.~), (^.))
import Control.Monad.Error.Class (MonadError (throwError))
Expand Down Expand Up @@ -35,6 +35,23 @@ runEval' mn tds ty =
let p = runReaderT (eval ty) (MkContext mn tds mempty)
in runExcept p

-- | `runEvalWithGas someGas ci mn ty` evaluates a `ty` repeatedly until the gas is reached or no further change (fixpoint reached)
runEvalWithGas :: Maybe Int -> PC.CompilerInput -> PC.ModuleName -> PC.Ty -> Either P.Error Ty
runEvalWithGas mayGas ci mn ty =
let tydefs = PC.indexTyDefs ci
in fix mayGas (runEval' mn tydefs) (fromTy ty)

fix :: Maybe Int -> (Ty -> Either e Ty) -> Ty -> Either e Ty
fix Nothing r x = case r x of
Left err -> Left err
Right x' -> if x == x' then Right x else fix Nothing r x'
fix (Just n) r x =
if n <= 0
then Right x
else case r x of
Left err -> Left err
Right x' -> if x == x' then Right x else fix (Just (n - 1)) r x'

eval :: MonadEval m => Ty -> m Ty
eval (TyRef tr) = eval (TyApp (TyRef tr) [] Nothing)
eval (TyApp (TyRef tr) args t) = resolveTyRef tr >>= (\f -> eval $ TyApp f args t)
Expand Down
36 changes: 35 additions & 1 deletion lambda-buffers-compiler/src/LambdaBuffers/ProtoCompat/Utils.hs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module LambdaBuffers.ProtoCompat.Utils (prettyModuleName, prettyModuleName', localRef2ForeignRef, classClosure, filterClassInModule) where
module LambdaBuffers.ProtoCompat.Utils (prettyModuleName, prettyModuleName', localRef2ForeignRef, classClosure, filterClassInModule, collectTyVars, collectVars, collectPhantomTyArgs) where

import Control.Lens (Getter, to, view, (&), (.~), (^.))
import Data.Foldable (Foldable (toList))
import Data.Map qualified as Map
import Data.Map.Ordered (OMap)
import Data.Map.Ordered qualified as OMap
import Data.ProtoLens (Message (defMessage))
import Data.Set (Set)
import Data.Set qualified as Set
import LambdaBuffers.ProtoCompat.Indexing qualified as PC
import LambdaBuffers.ProtoCompat.InfoLess qualified as PC
import LambdaBuffers.ProtoCompat.IsCompat.FromProto qualified as PC
import LambdaBuffers.ProtoCompat.IsCompat.Lang ()
import LambdaBuffers.ProtoCompat.Types qualified as PC
Expand Down Expand Up @@ -89,3 +93,33 @@ filterDerive cls m drv = filterConstraint cls m (drv ^. #constraint)

filterConstraint :: Set PC.QClassName -> PC.Module -> PC.Constraint -> Bool
filterConstraint cls m cnstr = PC.qualifyClassRef (m ^. #moduleName) (cnstr ^. #classRef) `Set.member` cls

-- | `collectTyVars ty` scans the `PC.Ty` expression and collects all the type variables
collectTyVars :: PC.Ty -> [PC.Ty]
collectTyVars = fmap (`PC.withInfoLess` (PC.TyVarI . PC.TyVar)) . Set.toList . collectVars

-- | `collectVars ty` is similar to `collectTyVars` but returns type variable names
collectVars :: PC.Ty -> Set (PC.InfoLess PC.VarName)
collectVars = collectVars' mempty

collectVars' :: Set (PC.InfoLess PC.VarName) -> PC.Ty -> Set (PC.InfoLess PC.VarName)
collectVars' collected (PC.TyVarI tv) = Set.insert (PC.mkInfoLess . view #varName $ tv) collected
collectVars' collected (PC.TyAppI (PC.TyApp _ args _)) = collected `Set.union` (Set.unions . fmap collectVars $ args)
collectVars' collected _ = collected

collectPhantomTyArgs :: PC.TyDef -> [PC.TyArg]
collectPhantomTyArgs tyDef =
let
PC.TyAbs tyArgs tyBody _si = PC.tyAbs tyDef
tys :: [PC.Ty] = go [] tyBody tyArgs
vars = Set.unions $ collectVars <$> tys
args = Set.fromList [varName | (varName, _) <- OMap.assocs tyArgs]
phantomArgs = Set.difference args vars
in
[tyArg | (varName, tyArg) <- OMap.assocs tyArgs, varName `Set.member` phantomArgs]
where
go :: [PC.Ty] -> PC.TyBody -> OMap (PC.InfoLess PC.VarName) PC.TyArg -> [PC.Ty]
go tys (PC.SumI (PC.Sum ctors _si)) tyArgs = tys <> mconcat [go tys (PC.ProductI $ PC.product ctor) tyArgs | ctor <- toList ctors]
go tys (PC.ProductI (PC.Product fields _si)) _tyArgs = tys <> toList fields
go tys (PC.RecordI (PC.Record fields _si)) _tyArgs = tys <> [PC.fieldTy field | field <- toList fields]
go tys (PC.OpaqueI _) tyArgs = tys <> [PC.TyVarI . PC.TyVar . PC.argName $ tyArg | (_, tyArg) <- OMap.assocs tyArgs]
2 changes: 2 additions & 0 deletions lambda-buffers-compiler/test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Test.LambdaBuffers.Compiler qualified as LBC
import Test.LambdaBuffers.Compiler.ClassClosure qualified as ClassClosure
import Test.LambdaBuffers.Compiler.LamTy qualified as LT
import Test.LambdaBuffers.Compiler.MiniLog qualified as ML
import Test.LambdaBuffers.Compiler.Phantoms qualified as Phantoms
import Test.LambdaBuffers.Compiler.TypeClassCheck qualified as TC
import Test.Tasty (defaultMain, testGroup)

Expand All @@ -21,4 +22,5 @@ main =
, ML.test
, TC.test
, ClassClosure.tests
, Phantoms.test
]
38 changes: 21 additions & 17 deletions lambda-buffers-compiler/test/Test/LambdaBuffers/Compiler/LamTy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ test =
(Just 1)
(U.lr "Beer" U.@ [U.fr ["Prelude"] "Text"])
"Prelude.Int8 Prelude.Text"
, fooCiTestCase
"Phantom a 1-> Prelude.Int8"
(Just 1)
(U.lr "Phantom" U.@ [U.tv "a"])
"Prelude.Int8"
, fooCiTestCase
"Phantom a *-> opq"
Nothing
(U.lr "Phantom" U.@ [U.tv "a"])
"opq"
, fooCiTestCase
"RecursiveA a 1-> opq"
(Just 1)
(U.lr "RecursiveA" U.@ [U.tv "a"])
"Prelude.Int8 (RecursiveA a)"
]

fooCi :: PC.CompilerInput
Expand All @@ -125,6 +140,8 @@ fooCi =
]
)
, U.td "Beer" (U.abs ["a"] $ U.prod' [U.fr ["Prelude"] "Int8", U.tv "a"])
, U.td "Phantom" (U.abs ["a"] $ U.prod' [U.fr ["Prelude"] "Int8"])
, U.td "RecursiveA" (U.abs ["a"] $ U.prod' [U.fr ["Prelude"] "Int8", U.lr "RecursiveA" U.@ [U.tv "a"]])
, U.td'maybe
, U.td'either
, U.td'list
Expand All @@ -135,20 +152,7 @@ fooCiTestCase :: TestName -> Maybe Int -> PC.Ty -> String -> TestTree
fooCiTestCase title mayGas ty want = testCase title $ runTestFix mayGas fooCi (U.mn ["Foo"]) ty want

runTestFix :: Maybe Int -> PC.CompilerInput -> PC.ModuleName -> PC.Ty -> String -> Assertion
runTestFix mayGas ci mn ty want =
let tydefs = PC.indexTyDefs ci
in case fix mayGas (LT.runEval' mn tydefs) (LT.fromTy ty) of
Left err -> assertFailure (show err)
Right res -> do
assertEqual "" want (show res)

fix :: Maybe Int -> (LT.Ty -> Either e LT.Ty) -> LT.Ty -> Either e LT.Ty
fix Nothing r x = case r x of
Left err -> Left err
Right x' -> if x == x' then Right x else fix Nothing r x'
fix (Just n) r x =
if n <= 0
then Right x
else case r x of
Left err -> Left err
Right x' -> if x == x' then Right x else fix (Just (n - 1)) r x'
runTestFix mayGas ci mn ty want = case LT.runEvalWithGas mayGas ci mn ty of
Left err -> assertFailure (show err)
Right res -> do
assertEqual "" want (show res)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module Test.LambdaBuffers.Compiler.Phantoms (test) where

import LambdaBuffers.ProtoCompat qualified as PC
import LambdaBuffers.ProtoCompat.Utils (collectPhantomTyArgs)
import Test.LambdaBuffers.ProtoCompat.Utils qualified as U
import Test.Tasty (TestTree, adjustOption, testGroup)
import Test.Tasty.HUnit (testCase, (@?=))
import Test.Tasty.Hedgehog qualified as H

test :: TestTree
test =
adjustOption (\_ -> H.HedgehogTestLimit $ Just 1000) $
testGroup
"Phantoms checks"
[ testCase "phantomA" $ phantoms U.td'phantomA @?= ["a"]
, testCase "phantomB" $ phantoms U.td'phantomB @?= ["a", "b"]
, testCase "phantomC" $ phantoms U.td'phantomC @?= ["a"]
, testCase "phantomD" $ phantoms U.td'phantomD @?= ["a", "b"]
, testCase "phantomE" $ phantoms U.td'phantomE @?= ["a"]
, testCase "phantomF" $ phantoms U.td'phantomF @?= ["a", "b"]
, testCase "Either" $ phantoms U.td'either @?= []
, testCase "Either opaque" $ phantoms U.td'eitherO @?= []
, testCase "List" $ phantoms U.td'list @?= []
, testCase "Maybe" $ phantoms U.td'maybe @?= []
, testCase "Maybe opaque" $ phantoms U.td'maybeO @?= []
]
where
phantoms = fmap ((\(PC.VarName name _) -> name) . PC.argName) . collectPhantomTyArgs
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ fails :: TestName -> PC.CompilerInput -> TestTree
fails title ci =
Golden.fails
goldensDir
(\tdir -> let fn = tdir </> "compiler_error" <.> "textproto" in (,) <$> (PbText.readMessageOrDie <$> Text.readFile fn) <*> pure fn)
(\tdir -> let fn = tdir </> "compiler_error" <.> "textproto" in ((,) . PbText.readMessageOrDie <$> Text.readFile fn) <*> pure fn)
(\otherFn gotErr -> Text.writeFile otherFn (Text.pack . show $ PbText.pprintMessage gotErr))
title
(fst $ TC.runCheck' ci)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ module Test.LambdaBuffers.ProtoCompat.Utils (
qcln',
mod'prelude'only'eq,
mod'prelude'noclass,
td'phantomA,
td'phantomB,
td'phantomC,
td'phantomD,
td'phantomE,
td'phantomF,
) where

import Control.Lens ((^.))
Expand Down Expand Up @@ -182,6 +188,24 @@ td'list = td "List" (abs ["a"] $ sum [("Nil", []), ("Cons", [tv "a", lr "List" @
td'listO :: PC.TyDef
td'listO = td "List" (abs ["a"] opq)

td'phantomA :: PC.TyDef
td'phantomA = td "PhantomA" (abs ["a"] $ sum [("A", [lr "Int", lr "Int"]), ("B", [lr "Int", lr "Int"])])

td'phantomB :: PC.TyDef
td'phantomB = td "PhantomB" (abs ["a", "b"] $ sum [("A", [lr "Int", lr "Int"]), ("B", [lr "Int", lr "Int"])])

td'phantomC :: PC.TyDef
td'phantomC = td "PhantomC" (abs ["a"] $ prod' [lr "Int", lr "Int"])

td'phantomD :: PC.TyDef
td'phantomD = td "PhantomD" (abs ["a", "b"] $ prod' [lr "Int", lr "Int"])

td'phantomE :: PC.TyDef
td'phantomE = td "PhantomE" (abs ["a"] $ recrd [("foo", lr "Int"), ("bar", lr "Int")])

td'phantomF :: PC.TyDef
td'phantomF = td "PhantomF" (abs ["a", "b"] $ recrd [("foo", lr "Int"), ("bar", lr "Int")])

-- | Some class definitions.
cd'eq :: PC.ClassDef
cd'eq = classDef ("Eq", "a") []
Expand Down
5 changes: 4 additions & 1 deletion settings.nix
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

index-state = lib.mkOption {
type = lib.types.str;
description = "Hackage index state to use when making a haskell.nix build environment";
description = "Hackage index state to use when making a haskell.nix
build environment";
};

compiler-nix-name = lib.mkOption {
Expand All @@ -50,6 +51,8 @@

tools = [

pkgs.haskellPackages.fourmolu
pkgs.haskellPackages.hlint
pkgs.haskellPackages.apply-refact

pkgs.nil
Expand Down

0 comments on commit 75f4411

Please sign in to comment.