diff --git a/quasar-wayland/src/Quasar/Wayland/Protocol/TH.hs b/quasar-wayland/src/Quasar/Wayland/Protocol/TH.hs index 7a5d295..c7cb386 100644 --- a/quasar-wayland/src/Quasar/Wayland/Protocol/TH.hs +++ b/quasar-wayland/src/Quasar/Wayland/Protocol/TH.hs @@ -8,7 +8,7 @@ module Quasar.Wayland.Protocol.TH ( import Control.Monad (mapAndUnzipM) import Control.Monad.Writer import Data.ByteString qualified as BS -import Data.List (intersperse, singleton) +import Data.List (intersperse, singleton, find) import Data.Void (absurd) import GHC.Records import Language.Haskell.TH @@ -301,10 +301,10 @@ messageProxyInstanceDecs side messageContexts = mapM messageProxyInstanceD messa returnT :: Q Type returnT = maybe [t|()|] (argumentType side) (proxyReturnArgument msg.msgSpec) applyArgTypes :: Q Type -> Q Type - applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType side <$> args) + applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType side <$> proxyArgs) - args :: [ArgumentSpec] - args = proxyArguments msg.msgSpec + proxyArgs :: [ArgumentSpec] + proxyArgs = filterProxyArguments msg.msgSpec actionE :: Q Exp actionE @@ -312,12 +312,18 @@ messageProxyInstanceDecs side messageContexts = mapM messageProxyInstanceD messa | msg.msgSpec.isDestructor = dtorE | otherwise = normalE - -- Constructor: the first argument becomes the return value + -- Constructor: the argument with type new_id becomes the return value ctorE :: Q Exp ctorE = [|newObject Nothing (objectVersion $objectE) >>= \(newObj, newId) -> newObj <$ (sendMessage object =<< $(msgE [|pure newId|]))|] where msgE :: Q Exp -> Q Exp - msgE idArgE = mkWireMsgE (idArgE : (wireArgE <$> args)) + msgE idArgE = mkWireMsgE do + -- Walk msgSpec arguments, which include the new_id argument + msg.msgSpec.arguments <&> \arg -> + if isNewId arg.argType + -- Inject the new_id at the correct position + then idArgE + else wireArgE arg dtorE :: Q Exp dtorE = [|handleDestructor object $normalE|] @@ -327,7 +333,7 @@ messageProxyInstanceDecs side messageContexts = mapM messageProxyInstanceD messa normalE = [|sendMessage object =<< $(msgE)|] where msgE :: Q Exp - msgE = mkWireMsgE (wireArgE <$> args) + msgE = mkWireMsgE (wireArgE <$> proxyArgs) mkWireMsgE :: [Q Exp] -> Q Exp mkWireMsgE mkWireArgEs = applyA (conE msg.msgConName) mkWireArgEs @@ -336,17 +342,22 @@ messageProxyInstanceDecs side messageContexts = mapM messageProxyInstanceD messa wireArgE arg = toWireArgument arg.argType (msgArgE msg arg) toWireArgument :: ArgumentType -> Q Exp -> Q Exp - toWireArgument (ObjectArgument _) objectE = [|objectWireArgument $objectE|] - toWireArgument (NullableObjectArgument _) objectE = [|nullableObjectWireArgument $objectE|] + toWireArgument (ObjectArgument _) oe = [|objectWireArgument $oe|] + toWireArgument (NullableObjectArgument _) oe = [|nullableObjectWireArgument $oe|] toWireArgument (NewIdArgument _) _ = unreachableCodePath -- The specification parser has a check to prevent this toWireArgument _ x = [|pure $x|] -proxyArguments :: MessageSpec -> [ArgumentSpec] -proxyArguments msg = (if msg.isConstructor then drop 1 else id) msg.arguments +filterProxyArguments :: MessageSpec -> [ArgumentSpec] +filterProxyArguments msg = + if msg.isConstructor + then filter (not . isNewId . (.argType)) msg.arguments + else msg.arguments proxyReturnArgument :: MessageSpec -> Maybe ArgumentSpec -proxyReturnArgument msg@MessageSpec{arguments=(firstArg:_)} = if msg.isConstructor then Just firstArg else Nothing -proxyReturnArgument _ = Nothing +proxyReturnArgument msg@MessageSpec{arguments} = + if msg.isConstructor + then find (isNewId . (.argType)) arguments + else Nothing messageFieldName :: MessageContext -> Name @@ -414,7 +425,7 @@ msgArgPats msg = varP . msgArgTempName <$> msg.msgSpec.arguments -- | Pattern to match all arguments of a message (for a proxy). Arguments can then be accessed by using e.g. 'msgArgE'. msgProxyArgPats :: MessageContext -> [Q Pat] -msgProxyArgPats msg = varP . msgArgTempName <$> proxyArguments msg.msgSpec +msgProxyArgPats msg = varP . msgArgTempName <$> filterProxyArguments msg.msgSpec -- | Expression for accessing a message argument which has been matched from a request/event using 'msgArgConP'. msgArgE :: MessageContext -> ArgumentSpec -> Q Exp @@ -437,7 +448,7 @@ messageTypeDecs name msgs = execWriterT do messageTypeD :: Q Dec messageTypeD = dataD (pure []) name [] Nothing (con <$> msgs) [] con :: MessageContext -> Q Con - con msg = normalC (msg.msgConName) (conField <$> msg.msgSpec.arguments) + con msg = normalC msg.msgConName (conField <$> msg.msgSpec.arguments) where conField :: ArgumentSpec -> Q BangType conField arg = defaultBangType (argumentWireType arg) @@ -459,16 +470,16 @@ isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD, opcodeNameD :: Q Dec opcodeNameD = funD 'opcodeName ((opcodeNameClause <$> msgs) <> [opcodeNameInvalidClause]) opcodeNameClause :: MessageContext -> Q Clause - opcodeNameClause msg = clause [litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB ([|Just $(stringE msg.msgSpec.name)|])) [] + opcodeNameClause msg = clause [litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB [|Just $(stringE msg.msgSpec.name)|]) [] opcodeNameInvalidClause :: Q Clause - opcodeNameInvalidClause = clause [wildP] (normalB ([|Nothing|])) [] + opcodeNameInvalidClause = clause [wildP] (normalB [|Nothing|]) [] getMessageD :: Q Dec getMessageD = funD 'getMessage ((getMessageClause <$> msgs) <> [getMessageInvalidOpcodeClause]) getMessageClause :: MessageContext -> Q Clause getMessageClause msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) [] where getMessageE :: Q Exp - getMessageE = applyALifted (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentWireType <$> msg.msgSpec.arguments) + getMessageE = applyALifted (conE msg.msgConName) ((\argT -> [|getArgument @($argT)|]) . argumentWireType <$> msg.msgSpec.arguments) getMessageInvalidOpcodeClause :: Q Clause getMessageInvalidOpcodeClause = do let object = mkName "object" @@ -628,22 +639,22 @@ parseMessage _isRequest interface (opcode, element) = do Just "destructor" -> pure True Just messageType -> fail $ "Unknown message type: " <> messageType + let isRegistryBind = interface == "wl_registry" && name == "bind" + forM_ arguments \arg -> do when - do arg.argType == GenericNewIdArgument && (interface /= "wl_registry" || name /= "bind") + do arg.argType == GenericNewIdArgument && not isRegistryBind do fail $ "Invalid \"new_id\" argument without \"interface\" attribute encountered on " <> loc <> " (only valid on wl_registry.bind)" when do arg.argType == GenericObjectArgument && (interface /= "wl_display" || name /= "error") do fail $ "Invalid \"object\" argument without \"interface\" attribute encountered on " <> loc <> " (only valid on wl_display.error)" - isConstructor <- case arguments of + let newIdArguments = filter (isNewId . (.argType)) arguments + + isConstructor <- case newIdArguments of [] -> pure False - (firstArg:otherArgs) -> do - when - do any (isNewId . (.argType)) otherArgs && not (interface == "wl_registry" && name == "bind") - -- TODO incorrect assumption, needs to be supported for wp_presentation - do fail $ "Message uses NewId in unsupported position on: " <> loc <> " (NewId currently has to be the first argument, which is a parser bug)" - pure (isNewId firstArg.argType) + [_] -> pure (not isRegistryBind) + _ -> fail $ "Invalid wayland message specification: message has multiple NewId arguments at: " <> loc pure MessageSpec { name,