Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extended Syntax: reducer #104

Merged
merged 15 commits into from
Jun 9, 2020
9 changes: 9 additions & 0 deletions grin/grin.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ library
Pipeline.Eval
Pipeline.Optimizations
Pipeline.Pipeline
Reducer.ExtendedSyntax.Base
Reducer.ExtendedSyntax.IO
Reducer.ExtendedSyntax.Pure
Reducer.ExtendedSyntax.PrimOps
Reducer.ExtendedSyntax.LLVM.JIT
Reducer.Base
Reducer.IO
Reducer.LLVM.Base
Expand Down Expand Up @@ -388,6 +393,10 @@ test-suite grin-test
AbstractInterpretation.SharingSpec
AbstractInterpretation.CreatedBySpec
Test.Hspec.PipelineExample

Reducer.ExtendedSyntax.BaseSpec
Reducer.ExtendedSyntax.IOSpec
Reducer.ExtendedSyntax.PureSpec
default-language: Haskell2010

benchmark grin-benchmark
Expand Down
7 changes: 7 additions & 0 deletions grin/src/Grin/ExtendedSyntax/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module Grin.ExtendedSyntax.Pretty
, showName
, showWidth
, showWide
, KeyValueMap(..)
) where

import Data.Char
Expand Down Expand Up @@ -227,3 +228,9 @@ prettyFunction (name, (ret, args)) = pretty name <> align (encloseSep (text " ::

prettyLocSet :: Set Loc -> Doc
prettyLocSet = semiBraces . map (cyan . int) . Set.toList

newtype KeyValueMap k v = KV (Map k v)
deriving (Eq, Ord, Show)

instance (Pretty k, Pretty v) => Pretty (KeyValueMap k v) where
pretty (KV m) = prettyKeyValue $ Map.toList m
63 changes: 63 additions & 0 deletions grin/src/Reducer/ExtendedSyntax/Base.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{-# LANGUAGE LambdaCase, TupleSections, BangPatterns #-}
module Reducer.ExtendedSyntax.Base where

import Data.Map (Map)
import qualified Data.Map as Map
import Data.Foldable (fold)

import Text.PrettyPrint.ANSI.Leijen

import Grin.ExtendedSyntax.Grin
import Grin.ExtendedSyntax.Pretty

-- models cpu registers
type Env = Map Name RTVal

type SimpleRTVal = RTVal
data RTVal
= RT_ConstTagNode Tag [SimpleRTVal]
| RT_Unit
| RT_Lit Lit
| RT_Var Name
| RT_Loc Int
| RT_Undefined
deriving (Show, Eq, Ord)


instance Pretty RTVal where
pretty = \case
RT_ConstTagNode tag args -> parens $ hsep (pretty tag : map pretty args)
RT_Unit -> parens empty
RT_Lit lit -> pretty lit
RT_Var name -> pretty name
RT_Loc a -> keyword "loc" <+> int a
RT_Undefined -> keyword "undefined"

keyword :: String -> Doc
keyword = yellow . text

selectNodeItem :: Maybe Int -> RTVal -> RTVal
selectNodeItem Nothing val = val
selectNodeItem (Just i) (RT_ConstTagNode tag args) = args !! (i - 1)

bindPat :: Env -> RTVal -> BPat -> Env
bindPat env !val bPat = case bPat of
VarPat var -> Map.insert var val env
p@(AsPat tag args var) -> case val of
RT_ConstTagNode vtag vargs
| tag == vtag
, env' <- Map.insert var val env
, newVars <- fold $ zipWith Map.singleton args vargs
-> newVars <> env'
_ -> error $ "bindPat - illegal value for ConstTagNode: " ++ show val ++ " vs " ++ show (PP p)

evalVar :: Env -> Name -> RTVal
evalVar env n = Map.findWithDefault (error $ "missing variable: " ++ unpackName n) n env

evalVal :: Env -> Val -> RTVal
evalVal env = \case
Lit lit -> RT_Lit lit
Var n -> evalVar env n
ConstTagNode t a -> RT_ConstTagNode t $ map (evalVar env) a
Unit -> RT_Unit
Undefined t -> RT_Undefined
132 changes: 132 additions & 0 deletions grin/src/Reducer/ExtendedSyntax/IO.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
{-# LANGUAGE LambdaCase, TupleSections, BangPatterns, OverloadedStrings #-}
{-# LANGUAGE Strict #-}
module Reducer.ExtendedSyntax.IO (reduceFun) where

import Debug.Trace

import Control.Monad.RWS.Strict hiding (Alt)

import Data.Foldable (fold)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Vector.Mutable as Vector
import Data.IORef

import Reducer.ExtendedSyntax.Base
import Reducer.ExtendedSyntax.PrimOps
import Grin.ExtendedSyntax.Grin

-- models computer memory
data IOStore = IOStore
{ sVector :: IOVector RTVal
, sLast :: IORef Int
}

emptyStore1 :: IO IOStore
emptyStore1 = IOStore <$> new (10 * 1024 * 1024) <*> newIORef 0

type Prog = Map Name Def
type GrinS a = RWST Prog () IOStore IO a

getProg :: GrinS Prog
getProg = reader id

getStore :: GrinS IOStore
getStore = get

-- TODO: Resize
insertStore :: RTVal -> GrinS Int
insertStore x = do
(IOStore v l) <- getStore
lift $ do
n <- readIORef l
Vector.write v n x
writeIORef l (n + 1)
pure n

lookupStore :: Int -> GrinS RTVal
lookupStore n = do
(IOStore v _) <- getStore
lift $ do
Vector.read v n

updateStore :: Int -> RTVal -> GrinS ()
updateStore n x = do
(IOStore v _) <- getStore
lift $ do
Vector.write v n x

pprint exp = trace (f exp) exp where
f = \case
EBind a b _ -> unwords ["Bind", "{",show a,"} to {", show b, "}"]
ECase a _ -> unwords ["Case", show a]
SBlock {} -> "Block"
a -> show a


evalExp :: [External] -> Env -> Exp -> GrinS RTVal
evalExp exts env exp = case {-pprint-} exp of
EBind op pat exp -> evalSimpleExp exts env op >>= \v -> evalExp exts (bindPat env v pat) exp
-- TODO:
ECase scrut alts ->
let defaultAlts = [exp | Alt DefaultPat _ exp <- alts]
defaultAlt = if Prelude.length defaultAlts > 1
then error "multiple default case alternative"
else Prelude.take 1 defaultAlts

altNames = [ name | Alt _ name _ <- alts ]
scrutVal = evalVar env scrut
boundAltNames = fold $ map (`Map.singleton` scrutVal) altNames
env' = boundAltNames <> env
in case evalVar env scrut of
RT_ConstTagNode t l ->
let (vars,exp) = head $ [(b,exp) | Alt (NodePat a b) _ exp <- alts, a == t] ++ map ([],) defaultAlt ++ error ("evalExp - missing Case Node alternative for: " ++ show t)
go a [] [] = a
go a (x:xs) (y:ys) = go (Map.insert x y a) xs ys
go _ x y = error $ "invalid pattern and constructor: " ++ show (t,x,y)
in evalExp exts (go env' vars l) exp
RT_Lit l -> evalExp exts env' $ head $ [exp | Alt (LitPat a) _ exp <- alts, a == l] ++ defaultAlt ++ error ("evalExp - missing Case Lit alternative for: " ++ show l)
x -> error $ "evalExp - invalid Case dispatch value: " ++ show x
exp -> evalSimpleExp exts env exp

evalSimpleExp :: [External] -> Env -> SimpleExp -> GrinS RTVal
evalSimpleExp exts env = \case
SApp n a -> do
let args = map (evalVar env) a
go a [] [] = a
go a (x:xs) (y:ys) = go (Map.insert x y a) xs ys
go _ x y = error $ "invalid pattern for function: " ++ show (n,x,y)
if isExternalName exts n
then evalPrimOp n [] args
else do
Def _ vars body <- (Map.findWithDefault (error $ "unknown function: " ++ unpackName n) n) <$> getProg
evalExp exts (go env vars args) body
SReturn v -> pure $ evalVal env v
SStore v -> do
let v' = evalVar env v
l <- insertStore v'
-- modify' (\(StoreMap m s) -> StoreMap (IntMap.insert l v' m) (s+1))
pure $ RT_Loc l
SFetch ptr -> case evalVar env ptr of
RT_Loc l -> lookupStore l
x -> error $ "evalSimpleExp - Fetch expected location, got: " ++ show x
-- | FetchI Name Int -- fetch node component
SUpdate ptr var -> do
let v' = evalVar env var
case evalVar env ptr of
RT_Loc l -> updateStore l v' >> pure v'
x -> error $ "evalSimpleExp - Update expected location, got: " ++ show x
SBlock a -> evalExp exts env a
x -> error $ "evalSimpleExp: " ++ show x

reduceFun :: Program -> Name -> IO RTVal
reduceFun (Program exts l) n = do
store <- emptyStore1
(val, _, _) <- runRWST (evalExp exts mempty e) m store
pure val
where
m = Map.fromList [(n,d) | d@(Def n _ _) <- l]
e = case Map.lookup n m of
Nothing -> error $ "missing function: " ++ unpackName n
Just (Def _ [] a) -> a
_ -> error $ "function " ++ unpackName n ++ " has arguments"
82 changes: 82 additions & 0 deletions grin/src/Reducer/ExtendedSyntax/LLVM/JIT.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ForeignFunctionInterface #-}

module Reducer.ExtendedSyntax.LLVM.JIT where

import Grin.ExtendedSyntax.Grin (Val(..))
import Reducer.ExtendedSyntax.Base (RTVal(..))
import Data.String

import LLVM.Target
import LLVM.Context
import LLVM.Module
import qualified LLVM.AST as AST

import LLVM.OrcJIT
import qualified LLVM.Internal.OrcJIT.CompileLayer as CL

import Control.Monad.Except
import qualified Data.ByteString.Char8 as BS

import Data.Int
import Data.IORef
import Foreign.Ptr
import Foreign.Storable
import Foreign.Marshal.Alloc
import qualified Data.Map.Strict as Map

foreign import ccall "dynamic"
mkMain :: FunPtr (IO Int64) -> IO Int64

foreign import ccall "wrapper"
wrapIntPrint :: (Int64 -> IO ()) -> IO (FunPtr (Int64 -> IO ()))

withTestModule :: AST.Module -> (LLVM.Module.Module -> IO a) -> IO a
withTestModule mod f = withContext $ \context -> withModuleFromAST context mod f

myIntPrintImpl :: Int64 -> IO ()
myIntPrintImpl i = print i

resolver :: CompileLayer l => MangledSymbol -> l -> MangledSymbol -> IO (Either JITSymbolError JITSymbol)
resolver intPrint compileLayer symbol
| symbol == intPrint = do
funPtr <- wrapIntPrint myIntPrintImpl
let addr = ptrToWordPtr (castFunPtrToPtr funPtr)
pure $ Right (JITSymbol addr defaultJITSymbolFlags)
| otherwise = CL.findSymbol compileLayer symbol True

nullResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
nullResolver s = putStrLn "nullresolver" >> pure (Left (JITSymbolError "unknown symbol"))

failInIO :: ExceptT String IO a -> IO a
failInIO = either fail pure <=< runExceptT

grinHeapSize :: Int
grinHeapSize = 100 * 1024 * 1024

eagerJit :: AST.Module -> String -> IO RTVal
eagerJit amod mainName = do
resolvers <- newIORef Map.empty
withTestModule amod $ \mod ->
withHostTargetMachine $ \tm ->
withExecutionSession $ \es ->
withObjectLinkingLayer es (\k -> fmap (\rs -> rs Map.! k) (readIORef resolvers)) $ \linkingLayer ->
withIRCompileLayer linkingLayer tm $ \compileLayer -> do
intPrint <- mangleSymbol compileLayer "_prim_int_print"
withModuleKey es $ \k ->
withSymbolResolver es (SymbolResolver (resolver intPrint compileLayer)) $ \resolver -> do
modifyIORef' resolvers (Map.insert k resolver)
withModule compileLayer k mod $ do
mainSymbol <- mangleSymbol compileLayer (fromString mainName)
Right (JITSymbol mainFn _) <- CL.findSymbol compileLayer mainSymbol True
heapSymbol <- mangleSymbol compileLayer (fromString "_heap_ptr_")
Right (JITSymbol heapWordPtr _) <- CL.findSymbol compileLayer heapSymbol True
-- allocate GRIN heap
heapPointer <- callocBytes grinHeapSize :: IO (Ptr Int8)
poke (wordPtrToPtr heapWordPtr :: Ptr Int64) (fromIntegral $ minusPtr heapPointer nullPtr)
-- run function
result <- mkMain (castPtrToFunPtr (wordPtrToPtr mainFn))
-- TODO: read back the result and build the haskell value represenation
-- free GRIN heap
free heapPointer
pure RT_Unit
Loading