diff --git a/src/Pinch/Generate.hs b/src/Pinch/Generate.hs index 9d0388c..f7f061a 100644 --- a/src/Pinch/Generate.hs +++ b/src/Pinch/Generate.hs @@ -8,7 +8,6 @@ import Control.Exception import Control.Monad.Reader import qualified Data.ByteString as BS import Data.Char -import Data.Foldable (forM_) import qualified Data.HashMap.Strict as Map import Data.List import Data.Maybe @@ -100,7 +99,9 @@ gProgram s inp (Program headers defs) = do (imports, tyMaps) <- unzip <$> traverse (gInclude s baseDir) incHeaders let tyMap = Map.unions tyMaps - let (typeDecls, clientDecls, serverDecls) = unzip3 $ runReader (traverse gDefinition defs) $ Context tyMap s + let (typeDecls, clientDecls, serverDecls, serverImports) = unzip4 $ + runReader (traverse gDefinition defs) $ + Context tyMap s headers let mkMod suffix = H.Module (H.ModuleName $ modBaseName <> suffix) [ H.PragmaLanguage "TypeFamilies, DeriveGeneric, TypeApplications, OverloadedStrings" , H.PragmaOptsGhc "-w" ] @@ -123,7 +124,7 @@ gProgram s inp (Program headers defs) = do mkMod ".Server" ( [ impTypes , H.ImportDecl (H.ModuleName "Pinch.Server") True H.IEverything - ] ++ imports ++ defaultImports) + ] ++ imports ++ concat serverImports ++ defaultImports) (concat serverDecls) ] @@ -140,7 +141,6 @@ gProgram s inp (Program headers defs) = do , H.ImportDecl (H.ModuleName "Control.Applicative") True H.IEverything , H.ImportDecl (H.ModuleName "Control.Exception") True H.IEverything , H.ImportDecl (H.ModuleName "Pinch") True H.IEverything - , H.ImportDecl (H.ModuleName "Pinch.Server") True H.IEverything , H.ImportDecl (H.ModuleName "Pinch.Internal.RPC") True H.IEverything , H.ImportDecl (H.ModuleName "Data.Text") True H.IEverything , H.ImportDecl (H.ModuleName "Data.ByteString") True H.IEverything @@ -159,6 +159,7 @@ data Context = Context { cModuleMap :: ModuleMap , cSettings :: Settings + , cHeaders :: [Header SourcePos] } type GenerateM = Reader Context @@ -171,10 +172,10 @@ gInclude s dir i = do let thriftModName = T.pack $ dropExtension $ T.unpack $ includePath i pure (H.ImportDecl modName True H.IEverything, Map.singleton thriftModName modName) -gDefinition :: Definition SourcePos -> GenerateM ([H.Decl], [H.Decl], [H.Decl]) +gDefinition :: Definition SourcePos -> GenerateM ([H.Decl], [H.Decl], [H.Decl], [H.ImportDecl]) gDefinition def = case def of - ConstDefinition c -> (\x -> (x, [], [])) <$> gConst c - TypeDefinition ty -> (\x -> (x, [], [])) <$> gType ty + ConstDefinition c -> (\x -> (x, [], [], [])) <$> gConst c + TypeDefinition ty -> (\x -> (x, [], [], [])) <$> gType ty ServiceDefinition s -> gService s gConst :: A.Const SourcePos -> GenerateM [H.Decl] @@ -439,28 +440,51 @@ gField prefix (i, f) = do pure (index, prefix <> "_" <> fieldName f, ty, req) -gService :: Service SourcePos -> GenerateM ([H.Decl], [H.Decl], [H.Decl]) +gService :: Service SourcePos -> GenerateM ([H.Decl], [H.Decl], [H.Decl], [H.ImportDecl]) gService s = do + headers <- asks cHeaders + settings <- asks cSettings (nms, tys, handlers, calls, tyDecls) <- unzip5 <$> traverse gFunction (serviceFunctions s) + + let (additionalImports, baseService, baseFunction) = case serviceExtends s of + Just baseServiceIdentifier -> do + case T.splitOn "." baseServiceIdentifier of + [importSource, baseServiceName] -> do + let importModule = (getModuleName settings headers $ T.unpack importSource) <> ".Server" + ([importModule], [("baseServer", H.TyCon $ importModule <> "." <> baseServiceName)], ".functions_" <> baseServiceName) + _ -> ([], [], "") + Nothing -> ([], [], "") + let extensionFunction = case additionalImports of + [] -> "" + imports -> head imports <> baseFunction <> " (baseServer server) `Data.HashMap.Strict.union` " let serverDecls = - [ H.DataDecl serviceTyName [ H.RecConDecl serviceConName $ zip nms tys ] [] + [ H.DataDecl serviceTyName [ H.RecConDecl serviceConName $ baseService <> zip nms tys ] [] + , H.TypeSigDecl + ("functions_" <> serviceConName) + ( H.TyLam + [H.TyCon serviceConName] + (H.TyCon "Data.HashMap.Strict.HashMap Data.Text.Text Pinch.Server.Handler") + ) + , H.FunBind + [ H.Match ("functions_" <> serviceConName) [H.PVar "server"] + ( H.EApp (H.EVar (extensionFunction <> "Data.HashMap.Strict.fromList")) [ H.EList handlers ] ) + ] , H.TypeSigDecl (prefix <> "_mkServer") (H.TyLam [H.TyCon serviceConName] (H.TyCon "Pinch.Server.ThriftServer")) , H.FunBind [ H.Match (prefix <> "_mkServer") [H.PVar "server"] - ( H.ELet "functions" - (H.EApp "Data.HashMap.Strict.fromList" [ H.EList handlers ] ) - ( H.EApp "Pinch.Server.createServer" - [ (H.ELam ["nm"] - (H.EApp "Data.HashMap.Strict.lookup" - [ "nm", "functions" ] - ) + ( H.EApp "Pinch.Server.createServer" + [ (H.ELam ["nm"] + (H.EApp "Data.HashMap.Strict.lookup" + [ "nm", H.EVar $ "functions_" <> serviceConName <> " server" ] ) - ] - ) + ) + ] ) ] ] - pure (concat tyDecls, concat calls, serverDecls) + let serverImports = (\imp -> H.ImportDecl (H.ModuleName imp) True H.IEverything) <$> additionalImports + + pure (concat tyDecls, concat calls, serverDecls, serverImports) where serviceTyName = capitalize $ serviceName s serviceConName = capitalize $ serviceName s