Skip to content

Commit

Permalink
Extend protocol generator for new_id arguments in all argument positions
Browse files Browse the repository at this point in the history
Previously only the first argument was allowed to have type new_id,
which was an incorrect assumption.

Enables support for (at least) the following protocols:
- wp_presentation
- input-method-unstable-v2
  • Loading branch information
queezle42 committed Jun 6, 2024
1 parent 7eb19ea commit 11f886e
Showing 1 changed file with 37 additions and 26 deletions.
63 changes: 37 additions & 26 deletions quasar-wayland/src/Quasar/Wayland/Protocol/TH.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module Quasar.Wayland.Protocol.TH (
import Control.Monad (mapAndUnzipM)
import Control.Monad.Writer
import Data.ByteString qualified as BS
import Data.List (intersperse, singleton)
import Data.List (intersperse, singleton, find)
import Data.Void (absurd)
import GHC.Records
import Language.Haskell.TH
Expand Down Expand Up @@ -301,23 +301,29 @@ messageProxyInstanceDecs side messageContexts = mapM messageProxyInstanceD messa
returnT :: Q Type
returnT = maybe [t|()|] (argumentType side) (proxyReturnArgument msg.msgSpec)
applyArgTypes :: Q Type -> Q Type
applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType side <$> args)
applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType side <$> proxyArgs)

args :: [ArgumentSpec]
args = proxyArguments msg.msgSpec
proxyArgs :: [ArgumentSpec]
proxyArgs = filterProxyArguments msg.msgSpec

actionE :: Q Exp
actionE
| msg.msgSpec.isConstructor = ctorE
| msg.msgSpec.isDestructor = dtorE
| otherwise = normalE

-- Constructor: the first argument becomes the return value
-- Constructor: the argument with type new_id becomes the return value
ctorE :: Q Exp
ctorE = [|newObject Nothing (objectVersion $objectE) >>= \(newObj, newId) -> newObj <$ (sendMessage object =<< $(msgE [|pure newId|]))|]
where
msgE :: Q Exp -> Q Exp
msgE idArgE = mkWireMsgE (idArgE : (wireArgE <$> args))
msgE idArgE = mkWireMsgE do
-- Walk msgSpec arguments, which include the new_id argument
msg.msgSpec.arguments <&> \arg ->
if isNewId arg.argType
-- Inject the new_id at the correct position
then idArgE
else wireArgE arg

dtorE :: Q Exp
dtorE = [|handleDestructor object $normalE|]
Expand All @@ -327,7 +333,7 @@ messageProxyInstanceDecs side messageContexts = mapM messageProxyInstanceD messa
normalE = [|sendMessage object =<< $(msgE)|]
where
msgE :: Q Exp
msgE = mkWireMsgE (wireArgE <$> args)
msgE = mkWireMsgE (wireArgE <$> proxyArgs)

mkWireMsgE :: [Q Exp] -> Q Exp
mkWireMsgE mkWireArgEs = applyA (conE msg.msgConName) mkWireArgEs
Expand All @@ -336,17 +342,22 @@ messageProxyInstanceDecs side messageContexts = mapM messageProxyInstanceD messa
wireArgE arg = toWireArgument arg.argType (msgArgE msg arg)

toWireArgument :: ArgumentType -> Q Exp -> Q Exp
toWireArgument (ObjectArgument _) objectE = [|objectWireArgument $objectE|]
toWireArgument (NullableObjectArgument _) objectE = [|nullableObjectWireArgument $objectE|]
toWireArgument (ObjectArgument _) oe = [|objectWireArgument $oe|]
toWireArgument (NullableObjectArgument _) oe = [|nullableObjectWireArgument $oe|]
toWireArgument (NewIdArgument _) _ = unreachableCodePath -- The specification parser has a check to prevent this
toWireArgument _ x = [|pure $x|]

proxyArguments :: MessageSpec -> [ArgumentSpec]
proxyArguments msg = (if msg.isConstructor then drop 1 else id) msg.arguments
filterProxyArguments :: MessageSpec -> [ArgumentSpec]
filterProxyArguments msg =
if msg.isConstructor
then filter (not . isNewId . (.argType)) msg.arguments
else msg.arguments

proxyReturnArgument :: MessageSpec -> Maybe ArgumentSpec
proxyReturnArgument msg@MessageSpec{arguments=(firstArg:_)} = if msg.isConstructor then Just firstArg else Nothing
proxyReturnArgument _ = Nothing
proxyReturnArgument msg@MessageSpec{arguments} =
if msg.isConstructor
then find (isNewId . (.argType)) arguments
else Nothing


messageFieldName :: MessageContext -> Name
Expand Down Expand Up @@ -414,7 +425,7 @@ msgArgPats msg = varP . msgArgTempName <$> msg.msgSpec.arguments

-- | Pattern to match all arguments of a message (for a proxy). Arguments can then be accessed by using e.g. 'msgArgE'.
msgProxyArgPats :: MessageContext -> [Q Pat]
msgProxyArgPats msg = varP . msgArgTempName <$> proxyArguments msg.msgSpec
msgProxyArgPats msg = varP . msgArgTempName <$> filterProxyArguments msg.msgSpec

-- | Expression for accessing a message argument which has been matched from a request/event using 'msgArgConP'.
msgArgE :: MessageContext -> ArgumentSpec -> Q Exp
Expand All @@ -437,7 +448,7 @@ messageTypeDecs name msgs = execWriterT do
messageTypeD :: Q Dec
messageTypeD = dataD (pure []) name [] Nothing (con <$> msgs) []
con :: MessageContext -> Q Con
con msg = normalC (msg.msgConName) (conField <$> msg.msgSpec.arguments)
con msg = normalC msg.msgConName (conField <$> msg.msgSpec.arguments)
where
conField :: ArgumentSpec -> Q BangType
conField arg = defaultBangType (argumentWireType arg)
Expand All @@ -459,16 +470,16 @@ isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD,
opcodeNameD :: Q Dec
opcodeNameD = funD 'opcodeName ((opcodeNameClause <$> msgs) <> [opcodeNameInvalidClause])
opcodeNameClause :: MessageContext -> Q Clause
opcodeNameClause msg = clause [litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB ([|Just $(stringE msg.msgSpec.name)|])) []
opcodeNameClause msg = clause [litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB [|Just $(stringE msg.msgSpec.name)|]) []
opcodeNameInvalidClause :: Q Clause
opcodeNameInvalidClause = clause [wildP] (normalB ([|Nothing|])) []
opcodeNameInvalidClause = clause [wildP] (normalB [|Nothing|]) []
getMessageD :: Q Dec
getMessageD = funD 'getMessage ((getMessageClause <$> msgs) <> [getMessageInvalidOpcodeClause])
getMessageClause :: MessageContext -> Q Clause
getMessageClause msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) []
where
getMessageE :: Q Exp
getMessageE = applyALifted (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentWireType <$> msg.msgSpec.arguments)
getMessageE = applyALifted (conE msg.msgConName) ((\argT -> [|getArgument @($argT)|]) . argumentWireType <$> msg.msgSpec.arguments)
getMessageInvalidOpcodeClause :: Q Clause
getMessageInvalidOpcodeClause = do
let object = mkName "object"
Expand Down Expand Up @@ -628,22 +639,22 @@ parseMessage _isRequest interface (opcode, element) = do
Just "destructor" -> pure True
Just messageType -> fail $ "Unknown message type: " <> messageType

let isRegistryBind = interface == "wl_registry" && name == "bind"

forM_ arguments \arg -> do
when
do arg.argType == GenericNewIdArgument && (interface /= "wl_registry" || name /= "bind")
do arg.argType == GenericNewIdArgument && not isRegistryBind
do fail $ "Invalid \"new_id\" argument without \"interface\" attribute encountered on " <> loc <> " (only valid on wl_registry.bind)"
when
do arg.argType == GenericObjectArgument && (interface /= "wl_display" || name /= "error")
do fail $ "Invalid \"object\" argument without \"interface\" attribute encountered on " <> loc <> " (only valid on wl_display.error)"

isConstructor <- case arguments of
let newIdArguments = filter (isNewId . (.argType)) arguments

isConstructor <- case newIdArguments of
[] -> pure False
(firstArg:otherArgs) -> do
when
do any (isNewId . (.argType)) otherArgs && not (interface == "wl_registry" && name == "bind")
-- TODO incorrect assumption, needs to be supported for wp_presentation
do fail $ "Message uses NewId in unsupported position on: " <> loc <> " (NewId currently has to be the first argument, which is a parser bug)"
pure (isNewId firstArg.argType)
[_] -> pure (not isRegistryBind)
_ -> fail $ "Invalid wayland message specification: message has multiple NewId arguments at: " <> loc

pure MessageSpec {
name,
Expand Down

0 comments on commit 11f886e

Please sign in to comment.