From 481493fa3b00c9aa6fe49ecab97f6ee6b4a235c6 Mon Sep 17 00:00:00 2001 From: M Aswin Kishore <60577077+mak626@users.noreply.github.com> Date: Sun, 17 Mar 2024 18:23:25 +0530 Subject: [PATCH] refact: apply black linting --- docs/conf.py | 135 ++++++++--------- examples/fastapi/graphql.py | 22 ++- examples/fastapi/schema.py | 34 +++-- graphql_ws/contexts/async_context.py | 8 +- graphql_ws/contexts/context.py | 7 +- graphql_ws/integrations/fastapi/context.py | 8 +- graphql_ws/integrations/fastapi/server.py | 28 +++- .../message_types/connection_init.py | 1 - .../message_types/invalid.py | 4 +- .../message_types/subscribe.py | 4 +- .../protocols/messages/bi_directional.py | 2 +- .../protocols/messages/client_to_server.py | 3 +- .../protocols/messages/message_parser.py | 19 ++- graphql_ws/servers/async_server.py | 138 +++++++++++++----- graphql_ws/servers/server.py | 88 ++++++++--- .../async_subscription_manager.py | 50 ++++--- .../integrations/google_pubsub.py | 9 +- .../subscription_manager.py | 8 +- .../sync_subscription_manager.py | 34 +++-- setup.py | 13 +- 20 files changed, 403 insertions(+), 212 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index a14bd67..244eec2 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -20,7 +20,7 @@ # directory, add these directories to sys.path here. If the directory is # relative to the documentation root, use os.path.abspath to make it # absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('.')) # Get the project root dir, which is the parent dir of this cwd = os.getcwd() @@ -36,27 +36,27 @@ # -- General configuration --------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'GraphQL AioWS' -copyright = u"2017, Syrus Akbary" +project = "GraphQL AioWS" +copyright = "2017, Syrus Akbary" # The version info for the project you're documenting, acts as replacement # for |version| and |release|, also used in various other places throughout @@ -69,126 +69,126 @@ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to # some non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built # documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output ------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = "default" # Theme options are theme-specific and customize the look and feel of a # theme further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as # html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the # top of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon # of the docs. This file should be a Windows icon file (.ico) being # 16x16 or 32x32 pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) # here, relative to this directory. They are copied after the builtin # static files, so a file named "default.css" will overwrite the builtin # "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # If not '', a 'Last updated on:' timestamp is inserted at every page # bottom, using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names # to template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. # Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. # Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages # will contain a tag referring to it. The value of this option # must be the base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'graphql_wsdoc' +htmlhelp_basename = "graphql_wsdoc" # -- Options for LaTeX output ------------------------------------------ @@ -196,10 +196,8 @@ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. #'preamble': '', } @@ -208,30 +206,34 @@ # (source start file, target name, title, author, documentclass # [howto/manual]). latex_documents = [ - ('index', 'graphql_ws.tex', - u'GraphQL AioWS Documentation', - u'Syrus Akbary', 'manual'), + ( + "index", + "graphql_ws.tex", + "GraphQL AioWS Documentation", + "Syrus Akbary", + "manual", + ), ] # The name of an image file (relative to this directory) to place at # the top of the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings # are parts, not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output ------------------------------------ @@ -239,13 +241,11 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'graphql_ws', - u'GraphQL AioWS Documentation', - [u'Syrus Akbary'], 1) + ("index", "graphql_ws", "GraphQL AioWS Documentation", ["Syrus Akbary"], 1) ] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ---------------------------------------- @@ -254,22 +254,25 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'graphql_ws', - u'GraphQL AioWS Documentation', - u'Syrus Akbary', - 'graphql_ws', - 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "graphql_ws", + "GraphQL AioWS Documentation", + "Syrus Akbary", + "graphql_ws", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False diff --git a/examples/fastapi/graphql.py b/examples/fastapi/graphql.py index 849b7f7..c3f0fc9 100644 --- a/examples/fastapi/graphql.py +++ b/examples/fastapi/graphql.py @@ -1,5 +1,5 @@ -from starlette.requests import Request from fastapi import APIRouter +from starlette.requests import Request from starlette.websockets import WebSocket from graphql_ws.integrations.fastapi.server import FastAPISubscriptionServer @@ -9,11 +9,17 @@ router = APIRouter(tags=["GraphQL Server"]) subscription_manager = AsyncSubscriptionManager() -subscription_server = FastAPISubscriptionServer(schema, subscription_manager=subscription_manager) +subscription_server = FastAPISubscriptionServer( + schema, subscription_manager=subscription_manager +) class Context: - def __init__(self, request: Request | WebSocket, subscription_manager_: AsyncSubscriptionManager) -> None: + def __init__( + self, + request: Request | WebSocket, + subscription_manager_: AsyncSubscriptionManager, + ) -> None: self.subscription_manager = subscription_manager_ @@ -25,13 +31,17 @@ async def graphql_post(request: Request): # noqa request_json.get("query"), operation_name=request_json.get("operationName"), variables=request_json.get("variables"), - context=Context(request=request, subscription_manager_=subscription_manager) + context=Context(request=request, subscription_manager_=subscription_manager), ) return result.formatted @router.websocket("", name="Subscription Endpoint") async def websocket_endpoint(websocket: WebSocket): - await subscription_server.handle(websocket, request_context=Context(request=websocket, - subscription_manager_=subscription_manager)) + await subscription_server.handle( + websocket, + request_context=Context( + request=websocket, subscription_manager_=subscription_manager + ), + ) return websocket diff --git a/examples/fastapi/schema.py b/examples/fastapi/schema.py index c09f44a..f948a32 100644 --- a/examples/fastapi/schema.py +++ b/examples/fastapi/schema.py @@ -1,5 +1,5 @@ -import random import asyncio +import random from typing import AsyncGenerator import graphene @@ -23,11 +23,12 @@ class Input: Output = UserCreatePayload # noinspection PyMethodMayBeStatic - async def mutate_and_get_payload(self, - info: GraphQLResolveInfo, name: str, client_mutation_id: str | None = None): - await info.context.subscription_manager.publish(topic="user_created", payload=UserObjectType( - name=name - )) + async def mutate_and_get_payload( + self, info: GraphQLResolveInfo, name: str, client_mutation_id: str | None = None + ): + await info.context.subscription_manager.publish( + topic="user_created", payload=UserObjectType(name=name) + ) # noinspection PyArgumentList return UserCreatePayload(client_mutation_id=client_mutation_id, name=name) @@ -50,16 +51,21 @@ class Subscription(graphene.ObjectType): count_seconds = graphene.Int(up_to=graphene.Int()) random_int = graphene.Field(RandomType) - async def subscribe_count_seconds(root, info: GraphQLResolveInfo, up_to: int = 100) -> AsyncGenerator[int, None]: + async def subscribe_count_seconds( + self, info: GraphQLResolveInfo, up_to: int = 100 + ) -> AsyncGenerator[int, None]: for i in range(up_to): yield i await asyncio.sleep(0.5) - async def subscribe_user_created(root, info: GraphQLResolveInfo) -> AsyncGenerator[str, None]: - return info.context.subscription_manager.subscribe(topic="user_created", - connection_context=info.context.connection_context) + async def subscribe_user_created( + self, info: GraphQLResolveInfo + ) -> AsyncGenerator[str, None]: + return info.context.subscription_manager.subscribe( + topic="user_created", connection_context=info.context.connection_context + ) - async def subscribe_random_int(root, info): + async def subscribe_random_int(self, info): i = 0 while True: yield RandomType(seconds=i, random_int=random.randint(0, 500)) @@ -67,4 +73,8 @@ async def subscribe_random_int(root, info): i += 1 -schema = graphene.Schema(query=Query, mutation=Mutation, subscription=Subscription, ) +schema = graphene.Schema( + query=Query, + mutation=Mutation, + subscription=Subscription, +) diff --git a/graphql_ws/contexts/async_context.py b/graphql_ws/contexts/async_context.py index fe9e3b1..729dd6c 100644 --- a/graphql_ws/contexts/async_context.py +++ b/graphql_ws/contexts/async_context.py @@ -8,8 +8,12 @@ # noinspection DuplicatedCode class AsyncConnectionContext(BaseConnectionContext): - def __init__(self, websocket: Any, request_context=None, - protocol: ProtocolEnum = ProtocolEnum.GRAPHQL_WS): + def __init__( + self, + websocket: Any, + request_context=None, + protocol: ProtocolEnum = ProtocolEnum.GRAPHQL_WS, + ): self.websocket = websocket super().__init__(websocket, request_context=request_context, protocol=protocol) diff --git a/graphql_ws/contexts/context.py b/graphql_ws/contexts/context.py index 340360a..eb8cd29 100644 --- a/graphql_ws/contexts/context.py +++ b/graphql_ws/contexts/context.py @@ -3,7 +3,12 @@ class BaseConnectionContext(object): - def __init__(self, websocket, request_context=None, protocol: ProtocolEnum = ProtocolEnum.GRAPHQL_WS): + def __init__( + self, + websocket, + request_context=None, + protocol: ProtocolEnum = ProtocolEnum.GRAPHQL_WS, + ): self.protocol = protocol self.websocket = websocket self.request_context = request_context diff --git a/graphql_ws/integrations/fastapi/context.py b/graphql_ws/integrations/fastapi/context.py index 4615a6e..2708f74 100644 --- a/graphql_ws/integrations/fastapi/context.py +++ b/graphql_ws/integrations/fastapi/context.py @@ -11,8 +11,12 @@ class FastAPIConnectionContext(AsyncConnectionContext): - def __init__(self, websocket: WebSocket, request_context=None, - protocol: ProtocolEnum = ProtocolEnum.GRAPHQL_WS): + def __init__( + self, + websocket: WebSocket, + request_context=None, + protocol: ProtocolEnum = ProtocolEnum.GRAPHQL_WS, + ): self.websocket = websocket self.protocol = protocol super().__init__(websocket, request_context=request_context, protocol=protocol) diff --git a/graphql_ws/integrations/fastapi/server.py b/graphql_ws/integrations/fastapi/server.py index e0f22a6..9adbfbd 100644 --- a/graphql_ws/integrations/fastapi/server.py +++ b/graphql_ws/integrations/fastapi/server.py @@ -10,18 +10,25 @@ class FastAPISubscriptionServer(AsyncSubscriptionServer): - def __init__(self, schema, subscription_manager: AsyncSubscriptionManager | None = None): + def __init__( + self, schema, subscription_manager: AsyncSubscriptionManager | None = None + ): super().__init__(schema, subscription_manager) @staticmethod async def on_open(connection_context): await connection_context.open() - async def handle(self, websocket: WebSocket, request_context=None, - protocol: ProtocolEnum = ProtocolEnum.GRAPHQL_WS): + async def handle( + self, + websocket: WebSocket, + request_context=None, + protocol: ProtocolEnum = ProtocolEnum.GRAPHQL_WS, + ): - connection_context = FastAPIConnectionContext(websocket=websocket, request_context=request_context, - protocol=protocol) + connection_context = FastAPIConnectionContext( + websocket=websocket, request_context=request_context, protocol=protocol + ) request_context.connection_context = connection_context if protocol != ProtocolEnum(websocket.headers.get("sec-websocket-protocol")): await connection_context.close(1000) @@ -29,9 +36,14 @@ async def handle(self, websocket: WebSocket, request_context=None, await self.on_open(connection_context) try: while not connection_context.closed: - message: ClientToServerMessage = MessageParser(protocol=protocol).parse_client_message( - await connection_context.receive()) - await self.on_message(protocol=protocol, connection_context=connection_context, message=message) + message: ClientToServerMessage = MessageParser( + protocol=protocol + ).parse_client_message(await connection_context.receive()) + await self.on_message( + protocol=protocol, + connection_context=connection_context, + message=message, + ) except ConnectionClosedException: pass else: diff --git a/graphql_ws/protocols/graphql_transport_ws/message_types/connection_init.py b/graphql_ws/protocols/graphql_transport_ws/message_types/connection_init.py index 8c76b74..a1606c1 100644 --- a/graphql_ws/protocols/graphql_transport_ws/message_types/connection_init.py +++ b/graphql_ws/protocols/graphql_transport_ws/message_types/connection_init.py @@ -1,4 +1,3 @@ - class ConnectionInitGraphQLTransportWSMessage: """ Direction: Client -> Server diff --git a/graphql_ws/protocols/graphql_transport_ws/message_types/invalid.py b/graphql_ws/protocols/graphql_transport_ws/message_types/invalid.py index 870791c..a4c683b 100644 --- a/graphql_ws/protocols/graphql_transport_ws/message_types/invalid.py +++ b/graphql_ws/protocols/graphql_transport_ws/message_types/invalid.py @@ -13,5 +13,7 @@ class InvalidGraphQLTransportWSMessage(BiDirectionalMessage): does not constitute an error. It is permissable to simply ignore all unknown IDs without closing the connection. """ - def __init__(self, ): + def __init__( + self, + ): pass diff --git a/graphql_ws/protocols/graphql_transport_ws/message_types/subscribe.py b/graphql_ws/protocols/graphql_transport_ws/message_types/subscribe.py index 06fb1d1..96ae7f0 100644 --- a/graphql_ws/protocols/graphql_transport_ws/message_types/subscribe.py +++ b/graphql_ws/protocols/graphql_transport_ws/message_types/subscribe.py @@ -2,7 +2,9 @@ class SubscribeMessagePayload(object): - def __init__(self, operation_name: str | None, query: str, variables: dict, extensions: dict): + def __init__( + self, operation_name: str | None, query: str, variables: dict, extensions: dict + ): self.operation_name = operation_name self.query = query self.variables = variables diff --git a/graphql_ws/protocols/messages/bi_directional.py b/graphql_ws/protocols/messages/bi_directional.py index f5e20b3..4424861 100644 --- a/graphql_ws/protocols/messages/bi_directional.py +++ b/graphql_ws/protocols/messages/bi_directional.py @@ -1,2 +1,2 @@ -class BiDirectionalMessage(object): +class BiDirectionalMessage: pass diff --git a/graphql_ws/protocols/messages/client_to_server.py b/graphql_ws/protocols/messages/client_to_server.py index 06b4528..461a82c 100644 --- a/graphql_ws/protocols/messages/client_to_server.py +++ b/graphql_ws/protocols/messages/client_to_server.py @@ -1,5 +1,4 @@ -class ClientToServerMessage(object): - +class ClientToServerMessage: def __init__(self, type: str, payload: dict | None = None, id_: str | None = None): self.id = id_ self.type = type diff --git a/graphql_ws/protocols/messages/message_parser.py b/graphql_ws/protocols/messages/message_parser.py index 02c64dc..0f3834e 100644 --- a/graphql_ws/protocols/messages/message_parser.py +++ b/graphql_ws/protocols/messages/message_parser.py @@ -1,8 +1,11 @@ import json from graphql_ws.protocols.exceptions import UnSupportedProtocolException -from graphql_ws.protocols.graphql_ws.message_types import ConnectionInitGraphQLWSMessage, StartGraphQLWSMessage, \ - ConnectionTerminateGraphQLWSMessage +from graphql_ws.protocols.graphql_ws.message_types import ( + ConnectionInitGraphQLWSMessage, + StartGraphQLWSMessage, + ConnectionTerminateGraphQLWSMessage, +) from graphql_ws.protocols.messages import ClientToServerMessage from graphql_ws.protocols.messages.exceptions import ClientToServerMessageInvalid from graphql_ws.protocols.protocol import ProtocolEnum @@ -15,17 +18,21 @@ def __init__(self, protocol: ProtocolEnum): def parse_client_message(self, message: str) -> ClientToServerMessage: message = json.loads(message) if self.protocol == ProtocolEnum.GRAPHQL_WS: - match message.get('type'): + match message.get("type"): case ConnectionInitGraphQLWSMessage.type: - return ConnectionInitGraphQLWSMessage(payload=message.get('payload')) + return ConnectionInitGraphQLWSMessage( + payload=message.get("payload") + ) case StartGraphQLWSMessage.type: - return StartGraphQLWSMessage(id_=message.get('id'), payload=message.get('payload')) + return StartGraphQLWSMessage( + id_=message.get("id"), payload=message.get("payload") + ) case ConnectionTerminateGraphQLWSMessage.type: return ConnectionTerminateGraphQLWSMessage() case _: raise ClientToServerMessageInvalid() elif self.protocol == ProtocolEnum.GRAPHQL_TRANSPORT_WS: - match message.get('type'): + match message.get("type"): case 1: print("This is case 1") case 2: diff --git a/graphql_ws/servers/async_server.py b/graphql_ws/servers/async_server.py index d6b2d94..c889bcb 100644 --- a/graphql_ws/servers/async_server.py +++ b/graphql_ws/servers/async_server.py @@ -3,11 +3,23 @@ from .server import BaseSubscriptionServer from ..contexts import AsyncConnectionContext, BaseConnectionContext from ..protocols import ProtocolEnum -from ..protocols.graphql_transport_ws.message_types import SubscribeGraphQLTransportWSMessage -from ..protocols.graphql_ws.message_types import ConnectionInitGraphQLWSMessage, ConnectionAckGraphQLWSMessage, \ - StartGraphQLWSMessage, DataGraphQLWSMessage, \ - CompleteGraphQLWSMessage, ConnectionTerminateGraphQLWSMessage, StopGraphQLWSMessage -from ..protocols.messages import ClientToServerMessage, BiDirectionalMessage, ServerToClientMessage +from ..protocols.graphql_transport_ws.message_types import ( + SubscribeGraphQLTransportWSMessage, +) +from ..protocols.graphql_ws.message_types import ( + ConnectionInitGraphQLWSMessage, + ConnectionAckGraphQLWSMessage, + StartGraphQLWSMessage, + DataGraphQLWSMessage, + CompleteGraphQLWSMessage, + ConnectionTerminateGraphQLWSMessage, + StopGraphQLWSMessage, +) +from ..protocols.messages import ( + ClientToServerMessage, + BiDirectionalMessage, + ServerToClientMessage, +) from ..protocols.messages.exceptions import ClientToServerOrBiDirectionalRequired from ..subscription_managers import AsyncSubscriptionManager from ..subscription_managers.exceptions import SubscriberAlreadyExistException @@ -27,18 +39,28 @@ async def execute(self, params): context_value=params.get("context_value"), ) - async def on_message(self, protocol: ProtocolEnum, connection_context: AsyncConnectionContext, - message: ClientToServerMessage | BiDirectionalMessage): + async def on_message( + self, + protocol: ProtocolEnum, + connection_context: AsyncConnectionContext, + message: ClientToServerMessage | BiDirectionalMessage, + ): try: - if issubclass(type(message), ClientToServerMessage) or issubclass(type(message), BiDirectionalMessage): + if issubclass(type(message), ClientToServerMessage) or issubclass( + type(message), BiDirectionalMessage + ): await self.process_message(protocol, connection_context, message) else: raise ClientToServerOrBiDirectionalRequired() except ClientToServerOrBiDirectionalRequired: return self.send_error(connection_context, None, e) - async def process_message(self, protocol: ProtocolEnum, connection_context: AsyncConnectionContext, - message: ClientToServerMessage | BiDirectionalMessage): + async def process_message( + self, + protocol: ProtocolEnum, + connection_context: AsyncConnectionContext, + message: ClientToServerMessage | BiDirectionalMessage, + ): if protocol.GRAPHQL_WS: if isinstance(message, ConnectionInitGraphQLWSMessage): await self.on_connection_init(connection_context, message) @@ -48,15 +70,20 @@ async def process_message(self, protocol: ProtocolEnum, connection_context: Asyn elif isinstance(message, ConnectionTerminateGraphQLWSMessage): await self.on_terminate(connection_context, message) elif isinstance(message, StopGraphQLWSMessage): - await self.subscription_manager.unsubscribe(connection_context=connection_context) + await self.subscription_manager.unsubscribe( + connection_context=connection_context + ) await self.on_terminate(connection_context, message) elif protocol.GRAPHQL_TRANSPORT_WS: pass else: pass - async def on_start(self, connection_context: AsyncConnectionContext, - message: StartGraphQLWSMessage | SubscribeGraphQLTransportWSMessage): + async def on_start( + self, + connection_context: AsyncConnectionContext, + message: StartGraphQLWSMessage | SubscribeGraphQLTransportWSMessage, + ): params = self.get_graphql_params(connection_context, message.payload) execution_result = await self.schema.subscribe( query=params.get("request_string"), @@ -68,49 +95,86 @@ async def on_start(self, connection_context: AsyncConnectionContext, if hasattr(execution_result, "__aiter__"): iterator = execution_result.__aiter__() async for result in iterator: - await self.send_message(connection_context, - message=DataGraphQLWSMessage(id_=message.id, payload=result.formatted)) + await self.send_message( + connection_context, + message=DataGraphQLWSMessage( + id_=message.id, payload=result.formatted + ), + ) else: if is_awaitable(execution_result): execution_result = await execution_result - await self.send_message(connection_context, - message=DataGraphQLWSMessage(id_=message.id, - payload=execution_result.formatted)) + await self.send_message( + connection_context, + message=DataGraphQLWSMessage( + id_=message.id, payload=execution_result.formatted + ), + ) except SubscriberAlreadyExistException as error: pass # await self.send_error(connection_context, error) else: - await self.on_complete(connection_context=connection_context, message=CompleteGraphQLWSMessage( - id_=message.id - )) + await self.on_complete( + connection_context=connection_context, + message=CompleteGraphQLWSMessage(id_=message.id), + ) - async def on_complete(self, connection_context: AsyncConnectionContext, message: CompleteGraphQLWSMessage): - await self.send_message(connection_context, - message=message) - await self.subscription_manager.unsubscribe(connection_context=connection_context) + async def on_complete( + self, + connection_context: AsyncConnectionContext, + message: CompleteGraphQLWSMessage, + ): + await self.send_message(connection_context, message=message) + await self.subscription_manager.unsubscribe( + connection_context=connection_context + ) await connection_context.close(1000) - async def on_terminate(self, connection_context: AsyncConnectionContext, - message: ConnectionTerminateGraphQLWSMessage): - await self.subscription_manager.unsubscribe(connection_context=connection_context) + async def on_terminate( + self, + connection_context: AsyncConnectionContext, + message: ConnectionTerminateGraphQLWSMessage, + ): + await self.subscription_manager.unsubscribe( + connection_context=connection_context + ) await connection_context.close(1000) - async def on_stop(self, connection_context: AsyncConnectionContext, message: StopGraphQLWSMessage): - await self.subscription_manager.unsubscribe(connection_context=connection_context) + async def on_stop( + self, connection_context: AsyncConnectionContext, message: StopGraphQLWSMessage + ): + await self.subscription_manager.unsubscribe( + connection_context=connection_context + ) await connection_context.close(1000) - async def on_connect(self, connection_context, message: ConnectionInitGraphQLWSMessage): + async def on_connect( + self, connection_context, message: ConnectionInitGraphQLWSMessage + ): pass - async def send_error(self, connection_context: AsyncConnectionContext, exception: SubscriberAlreadyExistException): + async def send_error( + self, + connection_context: AsyncConnectionContext, + exception: SubscriberAlreadyExistException, + ): if not connection_context.closed: await connection_context.close(exception.code) - async def on_connection_init(self, connection_context: AsyncConnectionContext, - message: ConnectionInitGraphQLWSMessage): + async def on_connection_init( + self, + connection_context: AsyncConnectionContext, + message: ConnectionInitGraphQLWSMessage, + ): await self.on_connect(connection_context, message) - await self.send_message(connection_context, message=ConnectionAckGraphQLWSMessage(payload=message.payload)) + await self.send_message( + connection_context, + message=ConnectionAckGraphQLWSMessage(payload=message.payload), + ) - async def send_message(self, connection_context: AsyncConnectionContext, - message: ServerToClientMessage | BiDirectionalMessage): + async def send_message( + self, + connection_context: AsyncConnectionContext, + message: ServerToClientMessage | BiDirectionalMessage, + ): await connection_context.send(data=message.data) diff --git a/graphql_ws/servers/server.py b/graphql_ws/servers/server.py index 37d89bd..7c5b0dd 100644 --- a/graphql_ws/servers/server.py +++ b/graphql_ws/servers/server.py @@ -3,41 +3,69 @@ from graphql_ws.contexts import BaseConnectionContext from graphql_ws.protocols import ProtocolEnum -from graphql_ws.protocols.graphql_transport_ws.message_types import SubscribeGraphQLTransportWSMessage -from graphql_ws.protocols.graphql_ws.message_types import StartGraphQLWSMessage, CompleteGraphQLWSMessage, \ - ConnectionInitGraphQLWSMessage, ConnectionTerminateGraphQLWSMessage, StopGraphQLWSMessage -from graphql_ws.protocols.messages import ClientToServerMessage, BiDirectionalMessage, ServerToClientMessage +from graphql_ws.protocols.graphql_transport_ws.message_types import ( + SubscribeGraphQLTransportWSMessage, +) +from graphql_ws.protocols.graphql_ws.message_types import ( + StartGraphQLWSMessage, + CompleteGraphQLWSMessage, + ConnectionInitGraphQLWSMessage, + ConnectionTerminateGraphQLWSMessage, + StopGraphQLWSMessage, +) +from graphql_ws.protocols.messages import ( + ClientToServerMessage, + BiDirectionalMessage, + ServerToClientMessage, +) from graphql_ws.subscription_managers import BaseSubscriptionManager class BaseSubscriptionServer(abc.ABC): - def __init__(self, schema, subscription_manager: BaseSubscriptionManager | None = None, keep_alive=True): + def __init__( + self, + schema, + subscription_manager: BaseSubscriptionManager | None = None, + keep_alive=True, + ): self.schema = schema self.keep_alive = keep_alive self.subscription_manager = subscription_manager @abstractmethod - def send_error(self, connection_context: BaseConnectionContext, exception: Exception): + def send_error( + self, connection_context: BaseConnectionContext, exception: Exception + ): pass @abstractmethod - def on_message(self, protocol: ProtocolEnum, connection_context: BaseConnectionContext, - message: ClientToServerMessage | BiDirectionalMessage): + def on_message( + self, + protocol: ProtocolEnum, + connection_context: BaseConnectionContext, + message: ClientToServerMessage | BiDirectionalMessage, + ): pass @staticmethod - def get_graphql_params(connection_context: BaseConnectionContext, payload: dict) -> dict: + def get_graphql_params( + connection_context: BaseConnectionContext, payload: dict + ) -> dict: context = payload.get("context", connection_context.request_context) return { "request_string": payload.get("query"), "variable_values": payload.get("variables"), "operation_name": payload.get("operationName"), - "context_value": context + "context_value": context, } @abstractmethod - def process_message(self, protocol: ProtocolEnum, connection_context: BaseConnectionContext, - message: ClientToServerMessage | BiDirectionalMessage): + def process_message( + self, + protocol: ProtocolEnum, + connection_context: BaseConnectionContext, + message: ClientToServerMessage | BiDirectionalMessage, + ): pass @abstractmethod @@ -45,27 +73,47 @@ def on_connect(self, connection_context, payload): pass @abstractmethod - def on_complete(self, connection_context: BaseConnectionContext, message: CompleteGraphQLWSMessage): + def on_complete( + self, + connection_context: BaseConnectionContext, + message: CompleteGraphQLWSMessage, + ): pass @abstractmethod - def on_terminate(self, connection_context: BaseConnectionContext, message: ConnectionTerminateGraphQLWSMessage): + def on_terminate( + self, + connection_context: BaseConnectionContext, + message: ConnectionTerminateGraphQLWSMessage, + ): pass @abstractmethod - def on_stop(self, connection_context: BaseConnectionContext, message: StopGraphQLWSMessage): + def on_stop( + self, connection_context: BaseConnectionContext, message: StopGraphQLWSMessage + ): pass @abstractmethod - def on_connection_init(self, connection_context: BaseConnectionContext, message: ConnectionInitGraphQLWSMessage): + def on_connection_init( + self, + connection_context: BaseConnectionContext, + message: ConnectionInitGraphQLWSMessage, + ): pass @abstractmethod - def send_message(self, connection_context: BaseConnectionContext, - message: ServerToClientMessage | BiDirectionalMessage): + def send_message( + self, + connection_context: BaseConnectionContext, + message: ServerToClientMessage | BiDirectionalMessage, + ): pass @abstractmethod - def on_start(self, connection_context: BaseConnectionContext, - message: StartGraphQLWSMessage | SubscribeGraphQLTransportWSMessage): + def on_start( + self, + connection_context: BaseConnectionContext, + message: StartGraphQLWSMessage | SubscribeGraphQLTransportWSMessage, + ): pass diff --git a/graphql_ws/subscription_managers/async_subscription_manager.py b/graphql_ws/subscription_managers/async_subscription_manager.py index ae15dfb..c40e427 100644 --- a/graphql_ws/subscription_managers/async_subscription_manager.py +++ b/graphql_ws/subscription_managers/async_subscription_manager.py @@ -20,21 +20,25 @@ async def subscribe(self, topic: str, connection_context: AsyncConnectionContext # if connection_context.id in self.subcribers: # raise SubscriberAlreadyExistException(unique_operation_id=connection_context.id) if connection_context.id not in self.subcribers: - self.subcribers.append( - connection_context.id - ) - connection_context_ = next((context for context in self.connection_contexts if context["topic"] == topic), None) + self.subcribers.append(connection_context.id) + connection_context_ = next( + ( + context + for context in self.connection_contexts + if context["topic"] == topic + ), + None, + ) if not connection_context_: - connection_context_ = { - "topic": topic, - "subscribers": [] - } + connection_context_ = {"topic": topic, "subscribers": []} self.connection_contexts.append(connection_context_) - connection_context_['subscribers'].append({ - "id": connection_context.id, - "connection_context": connection_context, - "queue": queue - }) + connection_context_["subscribers"].append( + { + "id": connection_context.id, + "connection_context": connection_context, + "queue": queue, + } + ) while True: item = await queue.get() yield item @@ -46,16 +50,22 @@ async def create_topic(self, name: str): async def publish(self, topic: str, payload: dict): topic = await self.create_topic(topic) - connection_context_ = next((context for context in self.connection_contexts if context["topic"] == topic), None) + connection_context_ = next( + ( + context + for context in self.connection_contexts + if context["topic"] == topic + ), + None, + ) if not connection_context_: - connection_context_ = { - "topic": topic, - "subscribers": [] - } + connection_context_ = {"topic": topic, "subscribers": []} self.connection_contexts.append(connection_context_) # noinspection PyUnresolvedReferences,PyTypeChecker - tasks = [asyncio.create_task(subscriber['queue'].put(payload)) for subscriber in - connection_context_['subscribers']] + tasks = [ + asyncio.create_task(subscriber["queue"].put(payload)) + for subscriber in connection_context_["subscribers"] + ] await asyncio.gather(*tasks) async def unsubscribe(self, connection_context: AsyncConnectionContext): diff --git a/graphql_ws/subscription_managers/integrations/google_pubsub.py b/graphql_ws/subscription_managers/integrations/google_pubsub.py index dd0fb38..3f415b3 100644 --- a/graphql_ws/subscription_managers/integrations/google_pubsub.py +++ b/graphql_ws/subscription_managers/integrations/google_pubsub.py @@ -15,9 +15,8 @@ def __init__(self): @staticmethod def topic_name(topic: str): - return 'projects/{project_id}/topics/{topic}'.format( - project_id=os.getenv('GOOGLE_CLOUD_PROJECT'), - topic=topic + return "projects/{project_id}/topics/{topic}".format( + project_id=os.getenv("GOOGLE_CLOUD_PROJECT"), topic=topic ) async def subscribe(self): @@ -31,7 +30,9 @@ async def publish(self, topic: str, payload: FormattedExecutionResult): if topic not in self.topics: self.create_topic(name=topic) # noinspection PyArgumentList - await self.publisher.publish(topic=topic, payload=json.dumps(payload).encode('utf-8')) + await self.publisher.publish( + topic=topic, payload=json.dumps(payload).encode("utf-8") + ) def unsubscribe(self): pass diff --git a/graphql_ws/subscription_managers/subscription_manager.py b/graphql_ws/subscription_managers/subscription_manager.py index e6ae7a9..d3f71b1 100644 --- a/graphql_ws/subscription_managers/subscription_manager.py +++ b/graphql_ws/subscription_managers/subscription_manager.py @@ -7,8 +7,12 @@ class BaseSubscriptionManager(abc.ABC): - def __init__(self, topics: list[str] | None = None, subscribers: list[str] | None = None, - connection_contexts: list[dict] | None = None): + def __init__( + self, + topics: list[str] | None = None, + subscribers: list[str] | None = None, + connection_contexts: list[dict] | None = None, + ): if topics is None: topics = [] if subscribers is None: diff --git a/graphql_ws/subscription_managers/sync_subscription_manager.py b/graphql_ws/subscription_managers/sync_subscription_manager.py index 189e25f..80bd69d 100644 --- a/graphql_ws/subscription_managers/sync_subscription_manager.py +++ b/graphql_ws/subscription_managers/sync_subscription_manager.py @@ -5,13 +5,20 @@ from graphql_ws.contexts import SyncConnectionContext from graphql_ws.subscription_managers import BaseSubscriptionManager -from graphql_ws.subscription_managers.exceptions import SubscriberAlreadyExistException, TopicNotFoundException +from graphql_ws.subscription_managers.exceptions import ( + SubscriberAlreadyExistException, + TopicNotFoundException, +) class SyncSubscriptionManager(BaseSubscriptionManager): - def __init__(self, topics: list[str] | None = None, subscribers: list[str] | None = None, - connection_contexts: list[dict] | None = None): + def __init__( + self, + topics: list[str] | None = None, + subscribers: list[str] | None = None, + connection_contexts: list[dict] | None = None, + ): if topics is None: topics = [] if subscribers is None: @@ -24,19 +31,20 @@ def __init__(self, topics: list[str] | None = None, subscribers: list[str] | Non super().__init__(self.topics, self.subscribers, self.connection_contexts) @abstractmethod - def subscribe(self, id_: str, topic: str, connection_context: SyncConnectionContext): + def subscribe( + self, id_: str, topic: str, connection_context: SyncConnectionContext + ): if topic not in self.topics: raise TopicNotFoundException() if id_ in self.subcribers: raise SubscriberAlreadyExistException() - self.subcribers.append( - id_ + self.subcribers.append(id_) + connection_context_ = next( + context for context in self.connection_contexts if context["topic"] == topic + ) + connection_context_["subcribers"].append( + {"id": id_, "connection_context": connection_context} ) - connection_context_ = next(context for context in self.connection_contexts if context["topic"] == topic) - connection_context_['subcribers'].append({ - "id": id_, - "connection_context": connection_context - }) @abstractmethod def create_topic(self, name: str): @@ -47,7 +55,9 @@ def create_topic(self, name: str): @abstractmethod def publish(self, topic: str, payload: FormattedExecutionResult): topic = self.create_topic(topic) - connection_context = next(context for context in self.connection_contexts if context["topic"] == topic) + connection_context = next( + context for context in self.connection_contexts if context["topic"] == topic + ) for subcriber in connection_context.get("subcribers"): subcriber.get("connection_context") diff --git a/setup.py b/setup.py index 39ddb69..02509de 100644 --- a/setup.py +++ b/setup.py @@ -15,10 +15,10 @@ def read(*rnames): ] dev_require = [ - "black==23.12.1", - "flake8==4.0.1", - "mypy==0.961", - ] + tests_require + "black==23.12.1", + "flake8==4.0.1", + "mypy==0.961", +] + tests_require setup( name="graphql-ws", @@ -33,10 +33,7 @@ def read(*rnames): url="https://github.com/graphql-python/graphql-ws", download_url=f"https://github.com/graphql-python/graphql-ws/archive/{version}.tar.gz", keywords=["graphene", "graphql", "gql", "subscription"], - install_requires=[ - "graphene>=3.1", - "graphql-core>=3.1" - ], + install_requires=["graphene>=3.1", "graphql-core>=3.1"], classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers",