diff --git a/grin/grin.cabal b/grin/grin.cabal index 9a645a83..fc58e35a 100644 --- a/grin/grin.cabal +++ b/grin/grin.cabal @@ -128,6 +128,11 @@ library Reducer.LLVM.TypeGen Reducer.PrimOps Reducer.Pure + Reducer.ExtendedSyntax.LLVM.Base + Reducer.ExtendedSyntax.LLVM.CodeGen + Reducer.ExtendedSyntax.LLVM.InferType + Reducer.ExtendedSyntax.LLVM.PrimOps + Reducer.ExtendedSyntax.LLVM.TypeGen Test.Assertions Test.Check Test.Grammar diff --git a/grin/src/Reducer/ExtendedSyntax/LLVM/Base.hs b/grin/src/Reducer/ExtendedSyntax/LLVM/Base.hs new file mode 100644 index 00000000..725278b5 --- /dev/null +++ b/grin/src/Reducer/ExtendedSyntax/LLVM/Base.hs @@ -0,0 +1,170 @@ +{-# LANGUAGE LambdaCase, TupleSections, DataKinds, RecursiveDo, RecordWildCards, OverloadedStrings, TemplateHaskell #-} + +module Reducer.ExtendedSyntax.LLVM.Base where + +import Text.Printf +import Control.Monad as M +import Control.Monad.State +import Data.Functor.Foldable as Foldable +import Lens.Micro.Platform + +import Data.Word +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Text (Text) +import Data.Vector (Vector) + +import Grin.ExtendedSyntax.Grin as Grin +import qualified Grin.ExtendedSyntax.TypeEnv as TypeEnv + +import LLVM.AST as LLVM hiding (callingConvention) +import LLVM.AST.Type as LLVM +import LLVM.AST.AddrSpace +import LLVM.AST.Constant hiding (Add, ICmp) +import LLVM.AST.IntegerPredicate +import qualified LLVM.AST.CallingConvention as CC +import qualified LLVM.AST.Linkage as L +import qualified LLVM.AST as AST +import LLVM.AST.Global +import LLVM.Context +import LLVM.Module + +import Control.Monad.Except +import qualified Data.ByteString.Char8 as BS + +heapPointerName :: String +heapPointerName = "_heap_ptr_" + +tagLLVMType :: LLVM.Type +tagLLVMType = i64 + +locationLLVMType :: LLVM.Type +locationLLVMType = ptr tagLLVMType + +mkNameG :: Grin.Name -> AST.Name +mkNameG = mkName . Grin.unpackName + +data Env + = Env + { _envDefinitions :: [Definition] -- Program state + , _envBasicBlocks :: Map Int BasicBlock -- Def state ; order -> basic block + , _envInstructions :: [Named Instruction] -- Def state + , _constantMap :: Map Grin.Name Operand -- Def state + , _currentBlockName :: AST.Name -- Def state + , _envBlockInstructions :: Map AST.Name [Named Instruction] -- Def state + , _envBlockOrder :: Map AST.Name Int -- Def state + , _envTempCounter :: Int + , _envTypeEnv :: TypeEnv.TypeEnv + , _envTagMap :: Map Tag Constant + , _envStringMap :: Map Text AST.Name -- Grin String Literal -> AST.Name + , _envStringCounter :: Int + } + +emptyEnv = Env + { _envDefinitions = mempty + , _envBasicBlocks = mempty + , _envInstructions = mempty + , _constantMap = mempty + , _currentBlockName = mkName "" + , _envBlockInstructions = mempty + , _envBlockOrder = mempty + , _envTempCounter = 0 + , _envTypeEnv = TypeEnv.emptyTypeEnv + , _envTagMap = mempty + , _envStringMap = mempty + , _envStringCounter = 0 + } + +concat <$> mapM makeLenses [''Env] + +-- Tagged union +{- + HINT: tagged union LLVM representation + + struct { + Int64 tag; + Int64[N1]; + Word64[N2]; + ... + } +-} +data TUIndex + = TUIndex + { tuStructIndex :: Word32 + , tuArrayIndex :: Word32 + , tuItemLLVMType :: LLVM.Type + } + deriving (Eq, Ord, Show) + +data TaggedUnion + = TaggedUnion + { tuLLVMType :: LLVM.Type -- struct of arrays of SimpleType with size + , tuMapping :: Map Tag (Vector TUIndex) + } + deriving (Eq, Ord, Show) + +data CGType + = CG_SimpleType + { cgLLVMType :: LLVM.Type + , cgType :: TypeEnv.Type + } + | CG_NodeSet + { cgLLVMType :: LLVM.Type + , cgType :: TypeEnv.Type + , cgTaggedUnion :: TaggedUnion + } + deriving (Eq, Ord, Show) + +type CG = State Env + +emit :: [Named Instruction] -> CG () +emit instructions = modify' (\env@Env{..} -> env {_envInstructions = _envInstructions ++ instructions}) + +addConstant :: Grin.Name -> Operand -> CG () +addConstant name operand = modify' (\env@Env{..} -> env {_constantMap = Map.insert name operand _constantMap}) + +unit :: Operand +unit = ConstantOperand $ Undef VoidType + +undef :: Type -> Operand +undef = ConstantOperand . Undef + +data Result + = I CGType Instruction + | O CGType Operand + +-- utils +closeBlock :: Terminator -> CG () +closeBlock tr = modify' $ \env@Env{..} -> env + { _envInstructions = mempty + , _envBasicBlocks = Map.insert (Map.findWithDefault undefined _currentBlockName _envBlockOrder) (BasicBlock _currentBlockName _envInstructions (Do tr)) _envBasicBlocks + , _envBlockInstructions = Map.delete _currentBlockName _envBlockInstructions + , _currentBlockName = mkName "" + } + +activeBlock :: AST.Name -> CG () +activeBlock name = modify' f where + f env@Env{..} + | name == _currentBlockName = env + | otherwise = env + { _envInstructions = Map.findWithDefault mempty name _envBlockInstructions + , _currentBlockName = name + , _envBlockInstructions = Map.insert _currentBlockName _envInstructions _envBlockInstructions + , _envBlockOrder = Map.insert name (Map.findWithDefault (Map.size _envBlockOrder) name _envBlockOrder) _envBlockOrder + } + +uniqueName :: Grin.Name -> CG AST.Name +uniqueName name = state (\env@Env{..} -> (mkName $ printf "%s.%d" (unpackName name) _envTempCounter, env {_envTempCounter = succ _envTempCounter})) + +getOperand :: Grin.Name -> Result -> CG (CGType, Operand) +getOperand name = \case + O cgTy a -> pure (cgTy, a) + I cgTy i -> case cgLLVMType cgTy of + VoidType -> emit [Do i] >> pure (cgTy, unit) + t -> (cgTy,) <$> codeGenLocalVar name t i + +codeGenLocalVar :: Grin.Name -> LLVM.Type -> AST.Instruction -> CG LLVM.Operand +codeGenLocalVar name ty instruction = do + varName <- uniqueName name + emit [varName := instruction] + pure $ LocalReference ty varName diff --git a/grin/src/Reducer/ExtendedSyntax/LLVM/CodeGen.hs b/grin/src/Reducer/ExtendedSyntax/LLVM/CodeGen.hs new file mode 100644 index 00000000..558001d6 --- /dev/null +++ b/grin/src/Reducer/ExtendedSyntax/LLVM/CodeGen.hs @@ -0,0 +1,602 @@ +{-# LANGUAGE LambdaCase, TupleSections, DataKinds, RecursiveDo, RecordWildCards, OverloadedStrings #-} + +module Reducer.ExtendedSyntax.LLVM.CodeGen + ( codeGen + , toLLVM + ) where + +import Text.Printf +import Control.Monad as M +import Control.Monad.State +import Data.Functor.Foldable as Foldable +import Lens.Micro.Platform + +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Vector (Vector) +import qualified Data.Vector as V +import qualified Data.List as List +import qualified Data.Text as Text +import qualified Data.ByteString.Short as ShortByteString +import Data.String (fromString) +import Text.Printf (printf) +import Lens.Micro.Mtl + +import LLVM.AST hiding (callingConvention, functionAttributes) +import LLVM.AST.AddrSpace +import LLVM.AST.Type as LLVM +import qualified LLVM.AST.Typed as LLVM +import LLVM.AST.Constant as C hiding (Add, ICmp) +import LLVM.AST.IntegerPredicate +import qualified LLVM.AST.CallingConvention as CC +import qualified LLVM.AST.Linkage as L +import qualified LLVM.AST as AST +import qualified LLVM.AST.Float as F +import qualified LLVM.AST.FunctionAttribute as FA +import qualified LLVM.AST.RMWOperation as RMWOperation +import LLVM.AST.Global as Global +import LLVM.Context +import LLVM.Module + +import Control.Monad.Except +import qualified Data.ByteString.Char8 as BS +import qualified Data.ByteString.Short as BSShort + +import Grin.ExtendedSyntax.Grin as Grin +import Grin.ExtendedSyntax.Pretty +import Grin.ExtendedSyntax.TypeEnv hiding (Type, typeOfVal) +import qualified Grin.ExtendedSyntax.TypeEnv as TypeEnv +import Reducer.ExtendedSyntax.LLVM.Base +import Reducer.ExtendedSyntax.LLVM.PrimOps +import Reducer.ExtendedSyntax.LLVM.TypeGen +import Reducer.ExtendedSyntax.LLVM.InferType + + + +debugMode :: Bool +debugMode = True + +toLLVM :: String -> AST.Module -> IO BS.ByteString +toLLVM fname mod = withContext $ \ctx -> do + llvm <- withModuleFromAST ctx mod moduleLLVMAssembly + BS.writeFile fname llvm + pure llvm + +codeGenLit :: Lit -> CG C.Constant +codeGenLit = \case + LInt64 v -> pure $ Int {integerBits=64, integerValue=fromIntegral v} + LWord64 v -> pure $ Int {integerBits=64, integerValue=fromIntegral v} + LFloat v -> pure $ C.Float {floatValue=F.Single v} + LBool v -> pure $ Int {integerBits=1, integerValue=if v then 1 else 0} + LChar v -> pure $ Int {integerBits=8, integerValue=fromIntegral $ fromEnum v} + LString v -> C.GlobalReference stringType <$> strName v + +strName :: Text.Text -> CG AST.Name +strName str = do + mName <- use $ envStringMap . at str + case mName of + Just n -> pure n + Nothing -> do + counter <- envStringCounter <<%= succ + let n = Name $ fromString $ "str." ++ show counter + envStringMap %= Map.insert str n + pure n + +codeGenVar :: Grin.Name -> CG Operand +codeGenVar var = do + Map.lookup var <$> gets _constantMap >>= \case + -- QUESTION: what is this? + Nothing -> do + ty <- getVarType var + pure $ LocalReference (cgLLVMType ty) (mkNameG var) + Just operand -> pure operand + +codeGenVal :: Val -> CG Operand +codeGenVal val = case val of + ConstTagNode tag args -> do + opArgs <- mapM codeGenVar args + + valT <- typeOfVal val + let T_NodeSet ns = valT + + ty <- typeOfVal val + let cgTy = toCGType ty + TaggedUnion{..} = cgTaggedUnion cgTy + nodeName = packName $ printf "node_%s" (show $ PP tag) + -- set node items + build agg (item, TUIndex{..}) = do + codeGenLocalVar nodeName (cgLLVMType cgTy) $ AST.InsertValue + { aggregate = agg + , element = item + , indices' = [1 + tuStructIndex, tuArrayIndex] + , metadata = [] + } + + -- set tag + tagId <- getTagId tag + let agg0 = ConstantOperand $ C.Struct + { structName = Nothing + , isPacked = True + , memberValues = tagId : [C.Undef t | t <- tail $ elementTypes tuLLVMType] + } + foldM build agg0 $ zip opArgs $ V.toList $ Map.findWithDefault undefined tag tuMapping + + Unit -> pure unit + Lit lit -> ConstantOperand <$> codeGenLit lit + Var name -> codeGenVar name + + Undefined t -> pure . ConstantOperand . Undef . cgLLVMType. toCGType $ t + + _ -> error $ printf "codeGenVal: %s" (show $ pretty val) + +getCPatConstant :: CPat -> CG Constant +getCPatConstant = \case + LitPat lit -> codeGenLit lit + NodePat tag args -> getTagId tag + DefaultPat -> pure C.TokenNone + +getCPatName :: CPat -> Grin.Name +getCPatName = \case + LitPat lit -> case lit of + LInt64 v -> "int_" <> showTS v + LWord64 v -> "word_" <> showTS v + LBool v -> "bool_" <> showTS v + LChar v -> "char_" <> showTS v + LString v -> error "pattern match on string is not supported" + LFloat v -> error "pattern match on float is not supported" + other -> error $ "pattern match not implemented: " ++ show other + NodePat tag _ -> tagName tag + DefaultPat -> "default" + where + tagName (Tag c name) = showTS c <> name + +-- https://stackoverflow.com/questions/6374355/llvm-assembly-assign-integer-constant-to-register +{- + NOTE: if the cata result is a monad then it means the codegen is sequential + + IDEA: write an untyped codegen ; which does not rely on HPT calculated types +-} + +toModule :: Env -> AST.Module +toModule Env{..} = defaultModule + { moduleName = "basic" + , moduleDefinitions = heapPointerDef : (stringDefinitions) ++ (reverse _envDefinitions) + } + where + heapPointerDef = GlobalDefinition globalVariableDefaults + { name = mkName (heapPointerName) + , Global.type' = i64 + , initializer = Just $ Int 64 0 + } + + stringDefinitions = concat + [ [ GlobalDefinition globalVariableDefaults + { name = valAstName + , Global.type' = ArrayType (fromIntegral (length stringVal)) i8 + , initializer = Just $ C.Array i8 $ [Int 8 $ fromIntegral $ fromEnum v0 | v0 <- stringVal] + } + , GlobalDefinition globalVariableDefaults + { name = astName + , Global.type' = stringStructType + , initializer = Just $ C.Struct Nothing False -- TODO: Set struct name + [ C.GetElementPtr + { inBounds = True + , address = GlobalReference (PointerType (ArrayType (fromIntegral (length stringVal)) i8) (AddrSpace 0)) valAstName + , indices = [Int {integerBits=64, integerValue=0}, Int {integerBits=64, integerValue=0}] + } + , Int 64 $ fromIntegral $ length stringVal + ] + } + ] + | (stringVal0, astName@(Name astNameBS)) <- Map.toList _envStringMap + , let stringVal = Text.unpack stringVal0 + , let valAstName = Name $ BSShort.pack $ (BSShort.unpack astNameBS) ++ (BSShort.unpack ".val") -- Append ShortByteStrings + ] + +{- + type of: + ok - SApp Name [SimpleVal] + ok - SReturn Val + ?? - SStore Val + ?? - SFetchI Name (Maybe Int) -- fetch a full node or a single node item in low level GRIN + ok - SUpdate Name Val +-} +codeGen :: TypeEnv -> Exp -> AST.Module +codeGen typeEnv exp = toModule $ flip execState (emptyEnv {_envTypeEnv = typeEnv}) $ para folder exp where + folder :: ExpF (Exp, CG Result) -> CG Result + folder = \case + SReturnF val -> do + ty <- typeOfVal val + O (toCGType ty) <$> codeGenVal val + SBlockF a -> snd $ a + + EBindF (leftExp, leftResultM) bpat (_,rightResultM) -> do + leftResult <- case leftExp of + -- FIXME: this is an ugly hack to compile SStore ; because it requires the binder name to for type lookup + -- QUESTION: can't we just use the type of var here? (varT <- (typeOfVar >=> toCGType) var) + SStore var -> do + let varName = _bPatVar bpat + varT <- getVarType varName + nodeLocation <- codeGenIncreaseHeapPointer varT + codeGenStoreNode var nodeLocation -- TODO + pure $ O locationCGType nodeLocation + + -- normal case ; this should be the only case here normally + _ -> leftResultM + case bpat of + AsPat tag args asVarName -> do + (cgTy,operand) <- getOperand ("node_" <> showTS (PP tag)) leftResult + addConstant asVarName operand + let mapping = tuMapping $ cgTaggedUnion cgTy + -- bind node pattern variables + forM_ (zip (V.toList $ Map.findWithDefault undefined tag mapping) args) $ \(TUIndex{..}, arg) -> do + let indices = [1 + tuStructIndex, tuArrayIndex] + emit [(mkNameG arg) := AST.ExtractValue {aggregate = operand, indices' = indices, metadata = []}] + VarPat name -> do + getOperand name leftResult >>= addConstant name . snd + _ -> getOperand "tmp" leftResult >> pure () + rightResultM + + SAppF name args -> do + (retType, argTypes) <- getFunctionType name + operands <- mapM codeGenVar args + operandsTypes <- mapM (fmap toCGType . typeOfVar) args + -- convert values to function argument type + convertedArgs <- sequence $ zipWith3 codeGenValueConversion operandsTypes operands argTypes + let findExternalName :: TypeEnv.Name -> Maybe External + findExternalName n = List.find ((n ==) . eName) (externals exp) + case findExternalName name of + Just e -> codeExternal e convertedArgs + Nothing -> do + -- call to top level functions + let functionType = FunctionType + { resultType = cgLLVMType retType + , argumentTypes = map cgLLVMType argTypes + , isVarArg = False + } + pure . I retType $ AST.Call + { tailCallKind = Just Tail + , callingConvention = CC.Fast + , returnAttributes = [] + , function = Right . ConstantOperand $ GlobalReference (ptr functionType) (mkNameG name) + , arguments = zip convertedArgs (repeat []) + , functionAttributes = [] + , metadata = [] + } + + AltF _ _ a -> snd a + + ECaseF scrut alts -> typeOfVar scrut >>= \case -- distinct implementation for tagged unions and simple types + T_SimpleType{} -> do + opVal <- codeGenVar scrut + codeGenCase opVal alts $ \_ -> pure () + + T_NodeSet nodeSet -> do + tuScrut <- codeGenVar scrut + tagVal <- codeGenExtractTag tuScrut + let valTU = taggedUnion nodeSet + codeGenCase tagVal alts $ \case + NodePat tag args -> case Map.lookup tag $ tuMapping valTU of + Nothing -> pure () + Just mapping -> do + --let mapping = Map.findWithDefault undefined tag $ tuMapping valTU + -- bind cpat variables + forM_ (zip args $ V.toList mapping) $ \(argName, TUIndex{..}) -> do + let indices = [1 + tuStructIndex, tuArrayIndex] + emit [(mkNameG argName) := AST.ExtractValue {aggregate = tuScrut, indices' = indices, metadata = []}] + DefaultPat -> pure () + _ -> error "not implemented" + + DefF name args (_,body) -> do + -- clear def local state + let clearDefState = modify' $ \env -> env + { _envBasicBlocks = mempty + , _envInstructions = mempty + , _constantMap = mempty + , _currentBlockName = mkName "" + , _envBlockInstructions = mempty + , _envBlockOrder = mempty + } + clearDefState + activeBlock (mkNameG $ name <> ".entry") + (cgTy, result) <- body >>= getOperand ("result." <> name) + (cgRetType, argTypes) <- getFunctionType name + let llvmReturnType = cgLLVMType cgRetType + + returnValue <- codeGenValueConversion cgTy result cgRetType + + closeBlock $ Ret + { returnOperand = if llvmReturnType == VoidType then Nothing else Just returnValue + , metadata' = [] + } + + when debugMode $ do + errorBlock + blockInstructions <- Map.delete (mkName "") <$> gets _envBlockInstructions + unless (Map.null blockInstructions) $ error $ printf "unclosed blocks in %s\n %s" name (show blockInstructions) + blocks <- gets _envBasicBlocks + let def = GlobalDefinition functionDefaults + { name = mkNameG name + , parameters = ([Parameter (cgLLVMType argType) (mkNameG a) [] | (a, argType) <- zip args argTypes], False) -- HINT: False - no var args + , returnType = llvmReturnType + , basicBlocks = Map.elems blocks + , callingConvention = if name == "grinMain" then CC.C else CC.Fast + , linkage = if name == "grinMain" then L.External else L.Private + , functionAttributes = [Right $ FA.StringAttribute "no-jump-tables" "true"] + } + clearDefState + modify' (\env@Env{..} -> env {_envDefinitions = def : _envDefinitions}) + pure $ O unitCGType unit + + ProgramF exts defs -> do + -- register prim fun lib + mapM registerPrimFunLib exts + sequence_ (map snd defs) >> pure (O unitCGType unit) + + SFetchF name -> do + -- load tag + tagAddress <- codeGenVal $ Var name + tagVal <- codeGenLocalVar "tag" tagLLVMType $ Load + { volatile = False + , address = tagAddress + , maybeAtomicity = Nothing + , alignment = 1 + , metadata = [] + } + -- switch on possible tags + TypeEnv{..} <- gets _envTypeEnv + let locs = case Map.lookup name _variable of + Just (T_SimpleType (T_Location l)) -> l + Just (T_SimpleType T_UnspecifiedLocation) -> [] + x -> error $ printf "variable %s can not be fetched, %s is not a location type" name (show $ pretty x) + nodeSet = mconcat [_location V.! loc | loc <- locs] + resultCGType = toCGType $ T_NodeSet nodeSet + resultTU = cgTaggedUnion resultCGType + codeGenTagSwitch tagVal nodeSet $ \tag items -> do + let nodeCGType = toCGType $ T_NodeSet $ Map.singleton tag items + nodeTU = cgTaggedUnion nodeCGType + nodeAddress <- codeGenBitCast ("ptr_" <> showTS (PP tag)) tagAddress (ptr $ tuLLVMType nodeTU) + nodeVal <- codeGenLocalVar ("node_" <> showTS (PP tag)) (cgLLVMType nodeCGType) $ Load + { volatile = False + , address = nodeAddress + , maybeAtomicity = Nothing + , alignment = 1 + , metadata = [] + } + (resultCGType,) <$> copyTaggedUnion nodeVal nodeTU resultTU + + SUpdateF name var -> do + nodeLocation <- codeGenVal $ Var name + codeGenStoreNode var nodeLocation + pure $ O unitCGType unit + + SStoreF var -> do + varTy <- typeOfVar var + nodeLocation <- codeGenIncreaseHeapPointer $ toCGType varTy + codeGenStoreNode var nodeLocation + pure $ O locationCGType nodeLocation + + expF -> error $ printf "missing codegen for:\n%s" (show $ pretty $ embed $ fmap fst expF) + +codeGenStoreNode :: Grin.Name -> Operand -> CG () +codeGenStoreNode var nodeLocation = do + tuVal <- codeGenVar var + tagVal <- codeGenExtractTag tuVal + varT <- typeOfVar var + let T_NodeSet nodeSet = varT + + let valueTU = taggedUnion nodeSet + codeGenTagSwitch tagVal nodeSet $ \tag items -> do + let nodeTU = taggedUnion $ Map.singleton tag items + nodeVal <- copyTaggedUnion tuVal valueTU nodeTU + nodeAddress <- codeGenBitCast ("ptr_" <> showTS (PP tag)) nodeLocation (ptr $ tuLLVMType nodeTU) + emit [Do Store + { volatile = False + , address = nodeAddress + , value = nodeVal + , maybeAtomicity = Nothing + , alignment = 1 + , metadata = [] + }] + pure $ (unitCGType, unit) + pure () + +convertStringOperand t o = case (cgType t,o) of + (T_SimpleType T_String, ConstantOperand stringRef@(GlobalReference{})) + -> ConstantOperand $ C.GetElementPtr + { inBounds = False + , address = stringRef + , indices = [Int {integerBits=64, integerValue=0}, Int {integerBits=64, integerValue=0}] + } + _ -> o + +codeGenCase :: Operand -> [(Alt, CG Result)] -> (CPat -> CG ()) -> CG Result +codeGenCase opVal alts bindingGen = do + curBlockName <- gets _currentBlockName + + let isDefault = \case + (Alt DefaultPat _ _, _) -> True + _ -> False + (defaultAlts, normalAlts) = List.partition isDefault alts + altNames = [ altName | Alt _ altName _ <- map fst alts ] + when (length defaultAlts > 1) $ fail "multiple default patterns" + let orderedAlts = defaultAlts ++ normalAlts + + mapM_ (flip addConstant opVal) altNames + + (altDests, altValues, altCGTypes) <- fmap List.unzip3 . forM orderedAlts $ \(Alt cpat _ _altName, altBody) -> do + altCPatVal <- getCPatConstant cpat + altEntryBlock <- uniqueName ("block." <> getCPatName cpat) + activeBlock altEntryBlock + + bindingGen cpat + + altResult <- altBody + (altCGTy, altOp) <- getOperand ("result." <> getCPatName cpat) altResult + + lastAltBlock <- gets _currentBlockName + + pure ((altCPatVal, altEntryBlock), (altOp, lastAltBlock, altCGTy), altCGTy) + + let resultCGType = commonCGType altCGTypes + switchExit <- uniqueName "block.exit" -- this is the next block + + altConvertedValues <- forM altValues $ \(altOp, lastAltBlock, altCGTy) -> do + activeBlock lastAltBlock + -- HINT: convert alt result to common type + convertedAltOp <- codeGenValueConversion altCGTy altOp resultCGType + closeBlock $ Br + { dest = switchExit + , metadata' = [] + } + pure (convertedAltOp, lastAltBlock) + + activeBlock curBlockName + let (defaultDest, normalAltDests) = if null defaultAlts + then (if debugMode then mkName "error_block" else switchExit, altDests) + else (snd $ head altDests, tail altDests) + closeBlock $ Switch + { operand0' = opVal + , defaultDest = defaultDest -- QUESTION: do we want to catch this error? + , dests = normalAltDests + , metadata' = [] + } + + activeBlock switchExit + + pure . I resultCGType $ Phi + { type' = cgLLVMType resultCGType + , incomingValues = altConvertedValues ++ if debugMode then [] else [(undef (cgLLVMType resultCGType), curBlockName)] + , metadata = [] + } + +-- merge heap pointers from alt branches +codeGenTagSwitch :: Operand -> NodeSet -> (Tag -> Vector SimpleType -> CG (CGType, Operand)) -> CG Result +codeGenTagSwitch tagVal nodeSet tagAltGen | Map.size nodeSet > 1 = do + let possibleNodes = Map.toList nodeSet + curBlockName <- gets _currentBlockName + + (altDests, altValues, altCGTypes) <- fmap List.unzip3 . forM possibleNodes $ \(tag, items) -> do + altEntryBlock <- uniqueName ("block." <> tagName tag) + altCPatVal <- getTagId tag + activeBlock altEntryBlock + + (altCGTy, altOp) <- tagAltGen tag items + + lastAltBlock <- gets _currentBlockName + + pure ((altCPatVal, altEntryBlock), (altOp, lastAltBlock, altCGTy), altCGTy) + + let resultCGType = commonCGType altCGTypes + switchExit <- uniqueName "block.exit" -- this is the next block + + altConvertedValues <- forM altValues $ \(altOp, lastAltBlock, altCGTy) -> do + activeBlock lastAltBlock + -- HINT: convert alt result to common type + convertedAltOp <- codeGenValueConversion altCGTy altOp resultCGType + + closeBlock $ Br + { dest = switchExit + , metadata' = [] + } + pure (convertedAltOp, lastAltBlock) + + activeBlock curBlockName + closeBlock $ Switch + { operand0' = tagVal + , defaultDest = if debugMode then mkName "error_block" else switchExit + , dests = altDests + , metadata' = [] + } + + activeBlock switchExit + + pure . I resultCGType $ Phi + { type' = cgLLVMType resultCGType + , incomingValues = altConvertedValues ++ if debugMode then [] else [(undef (cgLLVMType resultCGType), curBlockName)] + , metadata = [] + } + +codeGenTagSwitch tagVal nodeSet tagAltGen | [(tag, items)] <- Map.toList nodeSet = do + uncurry O <$> tagAltGen tag items + +codeGenTagSwitch tagVal nodeSet tagAltGen = error $ "LLVM codegen: empty node set for " ++ show tagVal + +-- heap pointer related functions + +codeGenIncreaseHeapPointer :: CGType -> CG Operand -- TODO +codeGenIncreaseHeapPointer varT = do + -- increase heap pointer and return the old value which points to the first free block + nodeSet <- case varT of + CG_SimpleType {cgType = T_SimpleType (T_Location locs)} -> mconcat <$> mapM (\loc -> use $ envTypeEnv.location.ix loc) locs + CG_NodeSet {cgType = T_NodeSet ns} -> pure ns + _ -> error $ show varT + + let tuPtrTy = ptr $ tuLLVMType $ taggedUnion nodeSet + tuSizePtr <- codeGenLocalVar "alloc_bytes" tuPtrTy $ AST.GetElementPtr + { inBounds = True + , address = ConstantOperand $ Null tuPtrTy + , indices = [ConstantOperand $ C.Int 32 1] + , metadata = [] + } + tuSizeInt <- codeGenLocalVar "alloc_bytes" i64 $ AST.PtrToInt + { operand0 = tuSizePtr + , type' = i64 + , metadata = [] + } + heapInt <- codeGenLocalVar "new_node_ptr" i64 $ AST.AtomicRMW + { volatile = False + , rmwOperation = RMWOperation.Add + , address = ConstantOperand $ GlobalReference (ptr i64) (mkName heapPointerName) + , value = tuSizeInt + , atomicity = (System, Monotonic) + , metadata = [] + } + codeGenLocalVar "new_node_ptr" (ptr i64) $ AST.IntToPtr + { operand0 = heapInt + , type' = ptr i64 + , metadata = [] + } + +external :: Type -> AST.Name -> [(Type, AST.Name)] -> CG () +external retty label argtys = modify' (\env@Env{..} -> env {_envDefinitions = def : _envDefinitions}) where + def = GlobalDefinition $ functionDefaults + { name = label + , linkage = L.External + , parameters = ([Parameter ty nm [] | (ty, nm) <- argtys], False) + , returnType = retty + , basicBlocks = [] + } + +-- available primitive functions +registerPrimFunLib :: External -> CG () +registerPrimFunLib ext = do + external + (toLLVMType $ eRetType ext) + (mkName $ Text.unpack $ unNM $ eName ext) + [ (toLLVMType t, mkName ("x" ++ show n)) | (t,n) <- (eArgsType ext) `zip` [1..] ] + where + toLLVMType = \case + TySimple t -> typeGenSimpleType t + rest -> error $ "Unsupported type:" ++ show rest + +errorBlock = do + activeBlock $ mkName "error_block" + let functionType = FunctionType + { resultType = VoidType + , argumentTypes = [i64] + , isVarArg = False + } + + emit [Do Call + { tailCallKind = Just Tail + , callingConvention = CC.C + , returnAttributes = [] + , function = Right . ConstantOperand $ GlobalReference (ptr functionType) (mkName "_prim_int_print") + , arguments = zip [ConstantOperand $ C.Int 64 666] (repeat []) + , functionAttributes = [] + , metadata = [] + }] + closeBlock $ Unreachable [] diff --git a/grin/src/Reducer/ExtendedSyntax/LLVM/InferType.hs b/grin/src/Reducer/ExtendedSyntax/LLVM/InferType.hs new file mode 100644 index 00000000..1da6cbbe --- /dev/null +++ b/grin/src/Reducer/ExtendedSyntax/LLVM/InferType.hs @@ -0,0 +1,49 @@ +{-# LANGUAGE LambdaCase, TupleSections, RecordWildCards, OverloadedStrings, TemplateHaskell #-} + +module Reducer.ExtendedSyntax.LLVM.InferType where + +import Text.Printf + +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Vector (Vector) +import qualified Data.Vector as V + +import Control.Monad.State +import Lens.Micro.Platform + +import Reducer.ExtendedSyntax.LLVM.Base +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeEnv hiding (typeOfVal) +import Grin.ExtendedSyntax.Pretty + +-- TODO: replace this module with a more generic one that could be used by other components also + +-- allows simple type singletons or locations +validateNodeItem :: Type -> CG () +validateNodeItem ts@T_NodeSet{} = error $ printf "LLVM codegen: illegal node item type %s" (show $ pretty ts) +validateNodeItem _ = pure () + +nodeType :: Tag -> [Type] -> CG Type +nodeType tag items = do + mapM_ validateNodeItem items + pure $ T_NodeSet $ Map.singleton tag $ V.fromList $ map _simpleType items + +typeOfVar :: Name -> CG Type +typeOfVar var = use (envTypeEnv.variable.at var) >>= \case + Nothing -> error $ printf "unknown variable %s" var + Just ty -> pure ty + +typeOfVal :: Val -> CG Type +typeOfVal val = do + case val of + ConstTagNode tag args -> mapM typeOfVar args >>= nodeType tag + {- + VarTagNode Name [SimpleVal] -- complete node (variable tag) + ValTag Tag + -} + Unit -> pure $ T_SimpleType T_Unit + Lit lit -> pure $ typeOfLit lit + Var name -> typeOfVar name + Undefined ty -> pure ty + _ -> error $ printf "unsupported val %s" (show $ pretty val) diff --git a/grin/src/Reducer/ExtendedSyntax/LLVM/PrimOps.hs b/grin/src/Reducer/ExtendedSyntax/LLVM/PrimOps.hs new file mode 100644 index 00000000..d31fd390 --- /dev/null +++ b/grin/src/Reducer/ExtendedSyntax/LLVM/PrimOps.hs @@ -0,0 +1,102 @@ +{-# LANGUAGE OverloadedStrings #-} +module Reducer.ExtendedSyntax.LLVM.PrimOps where + +import Control.Monad (when) +import LLVM.AST +import qualified LLVM.AST.IntegerPredicate as I +import qualified LLVM.AST.FloatingPointPredicate as F +import qualified LLVM.AST.CallingConvention as CC +import LLVM.AST.Type as LLVM +import LLVM.AST.AddrSpace +import qualified LLVM.AST.Constant as C + +import qualified Grin.ExtendedSyntax.Grin as Grin +import Grin.ExtendedSyntax.TypeEnv hiding (function) +import Reducer.ExtendedSyntax.LLVM.Base +import Reducer.ExtendedSyntax.LLVM.TypeGen +import Grin.ExtendedSyntax.PrimOpsPrelude + + +cgUnit = toCGType $ T_SimpleType T_Unit :: CGType +cgInt64 = toCGType $ T_SimpleType T_Int64 :: CGType +cgWord64 = toCGType $ T_SimpleType T_Word64 :: CGType +cgFloat = toCGType $ T_SimpleType T_Float :: CGType +cgBool = toCGType $ T_SimpleType T_Bool :: CGType +cgString = toCGType $ T_SimpleType T_String :: CGType +cgChar = toCGType $ T_SimpleType T_Char :: CGType + +codeExternal :: Grin.External -> [Operand] -> CG Result +codeExternal e ops = case Grin.eKind e of + Grin.PrimOp -> codeGenPrimOp (Grin.eName e) ops + Grin.FFI -> codeGenFFI e ops + +codeGenPrimOp :: Grin.Name -> [Operand] -> CG Result +codeGenPrimOp name [opA, opB] = pure $ case name of + -- Int + "_prim_int_add" -> I cgInt64 $ Add {nsw=False, nuw=False, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_sub" -> I cgInt64 $ Sub {nsw=False, nuw=False, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_mul" -> I cgInt64 $ Mul {nsw=False, nuw=False, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_div" -> I cgInt64 $ SDiv {exact=False, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_ashr" -> I cgInt64 $ AShr {exact=False, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_eq" -> I cgBool $ ICmp {iPredicate=I.EQ, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_ne" -> I cgBool $ ICmp {iPredicate=I.NE, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_gt" -> I cgBool $ ICmp {iPredicate=I.SGT, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_ge" -> I cgBool $ ICmp {iPredicate=I.SGE, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_lt" -> I cgBool $ ICmp {iPredicate=I.SLT, operand0=opA, operand1=opB, metadata=[]} + "_prim_int_le" -> I cgBool $ ICmp {iPredicate=I.SLE, operand0=opA, operand1=opB, metadata=[]} + + -- Word + "_prim_word_add" -> I cgWord64 $ Add {nsw=False, nuw=False, operand0=opA, operand1=opB, metadata=[]} + "_prim_word_sub" -> I cgWord64 $ Sub {nsw=False, nuw=False, operand0=opA, operand1=opB, metadata=[]} + "_prim_word_mul" -> I cgWord64 $ Mul {nsw=False, nuw=False, operand0=opA, operand1=opB, metadata=[]} + "_prim_word_div" -> I cgWord64 $ UDiv {exact=False, operand0=opA, operand1=opB, metadata=[]} + "_prim_word_eq" -> I cgBool $ ICmp {iPredicate=I.EQ, operand0=opA, operand1=opB, metadata=[]} + "_prim_word_ne" -> I cgBool $ ICmp {iPredicate=I.NE, operand0=opA, operand1=opB, metadata=[]} + "_prim_word_gt" -> I cgBool $ ICmp {iPredicate=I.UGT, operand0=opA, operand1=opB, metadata=[]} + "_prim_word_ge" -> I cgBool $ ICmp {iPredicate=I.UGE, operand0=opA, operand1=opB, metadata=[]} + "_prim_word_lt" -> I cgBool $ ICmp {iPredicate=I.ULT, operand0=opA, operand1=opB, metadata=[]} + "_prim_word_le" -> I cgBool $ ICmp {iPredicate=I.ULE, operand0=opA, operand1=opB, metadata=[]} + + -- Float + "_prim_float_add" -> I cgFloat $ FAdd {fastMathFlags=noFastMathFlags, operand0=opA, operand1=opB, metadata=[]} + "_prim_float_sub" -> I cgFloat $ FSub {fastMathFlags=noFastMathFlags, operand0=opA, operand1=opB, metadata=[]} + "_prim_float_mul" -> I cgFloat $ FMul {fastMathFlags=noFastMathFlags, operand0=opA, operand1=opB, metadata=[]} + "_prim_float_div" -> I cgFloat $ FDiv {fastMathFlags=noFastMathFlags, operand0=opA, operand1=opB, metadata=[]} + "_prim_float_eq" -> I cgBool $ FCmp {fpPredicate=F.OEQ, operand0=opA, operand1=opB, metadata=[]} + "_prim_float_ne" -> I cgBool $ FCmp {fpPredicate=F.ONE, operand0=opA, operand1=opB, metadata=[]} + "_prim_float_gt" -> I cgBool $ FCmp {fpPredicate=F.OGT, operand0=opA, operand1=opB, metadata=[]} + "_prim_float_ge" -> I cgBool $ FCmp {fpPredicate=F.OGE, operand0=opA, operand1=opB, metadata=[]} + "_prim_float_lt" -> I cgBool $ FCmp {fpPredicate=F.OLT, operand0=opA, operand1=opB, metadata=[]} + "_prim_float_le" -> I cgBool $ FCmp {fpPredicate=F.OLE, operand0=opA, operand1=opB, metadata=[]} + + -- Bool + "_prim_bool_eq" -> I cgBool $ ICmp {iPredicate=I.EQ, operand0=opA, operand1=opB, metadata=[]} + "_prim_bool_ne" -> I cgBool $ ICmp {iPredicate=I.NE, operand0=opA, operand1=opB, metadata=[]} + + _ -> error $ "unknown primop: " ++ show name + +codeGenFFI :: Grin.External -> [Operand] -> CG Result +codeGenFFI e ops = do + if (length ops /= length (Grin.eArgsType e)) + then error $ "Non saturated function call: " ++ show (e, ops) + else mkFunction (Grin.nameString $ Grin.eName e) (ops `zip` (Grin.eArgsType e)) (Grin.eRetType e) + +mkFunction name ops_params_ty ret_ty = pure . I (tyToCGType ret_ty) $ Call + { tailCallKind = Nothing + , callingConvention = CC.C + , returnAttributes = [] + , function = Right $ ConstantOperand $ C.GlobalReference (fun (tyToLLVMType ret_ty) (tyToLLVMType <$> params_ty)) (mkName name) + , arguments = ops `zip` repeat [] + , functionAttributes = [] + , metadata = [] + } + where + (ops, params_ty) = unzip ops_params_ty + tyToLLVMType t = case t of + Grin.TySimple st -> typeGenSimpleType st + _ -> error $ "Non simple type in: " ++ show (name, t) + tyToCGType t = case t of + Grin.TySimple st -> toCGType (T_SimpleType st) + _ -> error $ "Non simple type in: " ++ show (name, t) + fptr ty = PointerType { pointerReferent = ty, pointerAddrSpace = AddrSpace 0} + fun ret args = fptr FunctionType {resultType = ret, argumentTypes = args, isVarArg = False} diff --git a/grin/src/Reducer/ExtendedSyntax/LLVM/TypeGen.hs b/grin/src/Reducer/ExtendedSyntax/LLVM/TypeGen.hs new file mode 100644 index 00000000..d0c4a231 --- /dev/null +++ b/grin/src/Reducer/ExtendedSyntax/LLVM/TypeGen.hs @@ -0,0 +1,222 @@ +{-# LANGUAGE LambdaCase, TupleSections, RecordWildCards, OverloadedStrings, TemplateHaskell #-} + +module Reducer.ExtendedSyntax.LLVM.TypeGen where + +import Text.Printf + +import Data.Word +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe) +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Vector (Vector) +import qualified Data.Vector as V +import qualified Data.List as List +import qualified Data.Foldable + +import Control.Monad.State +import Lens.Micro.Platform + +import LLVM.AST as AST hiding (Type, void) +import LLVM.AST.Constant as C hiding (Add, ICmp) +import LLVM.AST.Type hiding (Type, void) +import qualified LLVM.AST.Type as LLVM + +import Reducer.ExtendedSyntax.LLVM.Base +import Grin.ExtendedSyntax.Grin as Grin +import Grin.ExtendedSyntax.TypeEnv +import Grin.ExtendedSyntax.Pretty + +stringStructType :: LLVM.Type +stringStructType = LLVM.StructureType False [ptr i8, i64] + +stringType :: LLVM.Type +stringType = ptr stringStructType + +typeGenSimpleType :: SimpleType -> LLVM.Type +typeGenSimpleType = \case + T_Int64 -> i64 + T_Word64 -> i64 + T_Float -> float + T_Bool -> i1 + T_String -> stringType + T_Char -> i8 + T_Unit -> LLVM.void + T_Location _ -> locationLLVMType + T_UnspecifiedLocation -> locationLLVMType + T_Dead -> error $ "Dead/unused type was given." + +locationCGType :: CGType +locationCGType = toCGType $ T_SimpleType $ T_Location [] + +tagCGType :: CGType +tagCGType = toCGType $ T_SimpleType $ T_Int64 + +unitCGType :: CGType +unitCGType = toCGType $ T_SimpleType $ T_Unit + +voidLLVMType :: LLVM.Type +voidLLVMType = LLVM.void + +data TUBuild + = TUBuild + { tubStructIndexMap :: Map LLVM.Type Word32 + , tubArraySizeMap :: Map LLVM.Type Word32 + , tubArrayPosMap :: Map LLVM.Type Word32 + } + +emptyTUBuild = TUBuild mempty mempty mempty + +type TU = State TUBuild + +taggedUnion :: NodeSet -> TaggedUnion +taggedUnion ns = TaggedUnion (tuLLVMType tub) tuMapping where + + mapNode :: Vector SimpleType -> TU (Vector TUIndex) + mapNode v = do + nodeMapping <- mapM allocIndex v + modify $ \tub@TUBuild{..} -> tub {tubArraySizeMap = Map.unionWith max tubArrayPosMap tubArraySizeMap, tubArrayPosMap = mempty} + pure nodeMapping + + getStructIndex :: LLVM.Type -> TU Word32 + getStructIndex ty = state $ \tub@TUBuild{..} -> + let i = Map.findWithDefault (fromIntegral $ Map.size tubStructIndexMap) ty tubStructIndexMap + in (i, tub {tubStructIndexMap = Map.insert ty i tubStructIndexMap}) + + getArrayIndex :: LLVM.Type -> TU Word32 + getArrayIndex ty = state $ \tub@TUBuild{..} -> + let i = Map.findWithDefault 0 ty tubArrayPosMap + in (i, tub {tubArrayPosMap = Map.insert ty (succ i) tubArrayPosMap}) + + allocIndex :: SimpleType -> TU TUIndex + allocIndex sTy = TUIndex <$> getStructIndex t <*> getArrayIndex t <*> pure t where t = typeGenSimpleType sTy + + (tuMapping, tub) = runState (mapM mapNode ns) emptyTUBuild + + tuLLVMType TUBuild{..} = StructureType + { isPacked = True + , elementTypes = tagLLVMType : + [ ArrayType (fromIntegral $ Map.findWithDefault undefined ty tubArraySizeMap) ty + | (ty, _idx) <- List.sortBy (\(_,a) (_,b) -> compare a b) $ Map.toList tubStructIndexMap + ] + } + +isCompatibleTaggedUnion :: TaggedUnion -> TaggedUnion -> Bool +isCompatibleTaggedUnion (TaggedUnion tuLLVMTypeA tuMappingA) (TaggedUnion tuLLVMTypeB tuMappingB) + = tuLLVMTypeA == tuLLVMTypeB && Data.Foldable.and (Map.intersectionWith (==) tuMappingA tuMappingB) + +copyTaggedUnion :: Operand -> TaggedUnion -> TaggedUnion -> CG Operand +copyTaggedUnion srcVal srcTU dstTU | isCompatibleTaggedUnion srcTU dstTU = pure srcVal +copyTaggedUnion srcVal srcTU dstTU = do + let -- calculate mapping + mapping :: [(TUIndex, TUIndex)] -- src dst + mapping = concat . map V.toList . Map.elems $ Map.intersectionWith V.zip (tuMapping srcTU) (tuMapping dstTU) + validatedMapping = fst $ foldl validate mempty mapping + validate (l,m) x@(src, dst) = case Map.lookup dst m of + Nothing -> ((x:l), Map.insert dst src m) + Just prevSrc | prevSrc == src && tuItemLLVMType src == tuItemLLVMType dst -> (l,m) + | otherwise -> error $ printf "invalid tagged union mapping: %s" (show mapping) + -- set node items + build agg (itemType, srcIndex, dstIndex) = do + item <- codeGenLocalVar "src" itemType $ AST.ExtractValue + { aggregate = srcVal + , indices' = srcIndex + , metadata = [] + } + codeGenLocalVar "dst" dstTULLVMType $ AST.InsertValue + { aggregate = agg + , element = item + , indices' = dstIndex + , metadata = [] + } + tagIndex = [0] + dstTULLVMType = tuLLVMType dstTU + agg0 = undef dstTULLVMType + foldM build agg0 $ (tagLLVMType, tagIndex,tagIndex) : + [ ( tuItemLLVMType src + , [1 + tuStructIndex src, tuArrayIndex src] + , [1 + tuStructIndex dst, tuArrayIndex dst] + ) + | (src,dst) <- validatedMapping + ] + +codeGenExtractTag :: Operand -> CG Operand +codeGenExtractTag tuVal = do + codeGenLocalVar "tag" tagLLVMType $ AST.ExtractValue + { aggregate = tuVal + , indices' = [0] -- tag index + , metadata = [] + } + +codeGenBitCast :: Grin.Name -> Operand -> LLVM.Type -> CG Operand +codeGenBitCast name value dstType = do + codeGenLocalVar name dstType $ AST.BitCast + { operand0 = value + , type' = dstType + , metadata = [] + } + +{- + NEW approach: everything is tagged union + + compilation: + if type sets does not match then convert them +-} + +codeGenValueConversion :: CGType -> Operand -> CGType -> CG Operand +codeGenValueConversion srcCGType srcOp dstCGType = case srcCGType of + CG_SimpleType{} | srcCGType == dstCGType -> pure srcOp + _ | isLocation srcCGType && isLocation dstCGType -> pure srcOp + _ -> copyTaggedUnion srcOp (cgTaggedUnion srcCGType) (cgTaggedUnion dstCGType) + where isLocation = \case + CG_SimpleType{cgType = T_SimpleType T_Location{}} -> True + CG_SimpleType{cgType = T_SimpleType T_UnspecifiedLocation} -> True + _ -> False + +commonCGType :: [CGType] -> CGType +commonCGType tys | Just ty <- foldM joinSimpleType (head tys) tys = ty where + joinSimpleType :: CGType -> CGType -> Maybe CGType + joinSimpleType t@(CG_SimpleType l1 (T_SimpleType t1)) (CG_SimpleType l2 (T_SimpleType t2)) | l1 == l2 = case (t1, t2) of + -- join locations + (T_Location p1, T_Location p2) -> Just . CG_SimpleType l1 . T_SimpleType $ T_Location (List.nub $ p1 ++ p2) + _ | t1 == t1 -> Just t + | otherwise -> Nothing + joinSimpleType _ _ = Nothing + +commonCGType tys | all isNodeSet tys = toCGType $ T_NodeSet $ mconcat [ns | CG_NodeSet _ (T_NodeSet ns) _ <- tys] where + isNodeSet = \case + CG_NodeSet{} -> True + _ -> False +commonCGType tys = error $ printf "no common type for %s" (show $ pretty $ map cgType tys) + +toCGType :: Type -> CGType +toCGType t = case t of + T_SimpleType sTy -> CG_SimpleType (typeGenSimpleType sTy) t + T_NodeSet ns -> CG_NodeSet (tuLLVMType tu) t tu where tu = taggedUnion ns + +getVarType :: Grin.Name -> CG CGType +getVarType name = do + TypeEnv{..} <- gets _envTypeEnv + pure $ maybe (error ("unknown variable " ++ unpackName name)) toCGType + $ Map.lookup name _variable + +getFunctionType :: Grin.Name -> CG (CGType, [CGType]) +getFunctionType name = do + TypeEnv{..} <- gets _envTypeEnv + case Map.lookup name _function of + Nothing -> error $ printf "unknown function %s" name + Just (retValue, argValues) -> do + retType <- pure $ toCGType retValue + argTypes <- pure $ map toCGType $ V.toList argValues + pure (retType, argTypes) + +getTagId :: Tag -> CG Constant +getTagId tag = do + tagMap <- use envTagMap + case Map.lookup tag tagMap of + Just c -> pure c + Nothing -> do + let c = Int 64 $ fromIntegral $ Map.size tagMap + envTagMap %= (Map.insert tag c) + pure c