diff --git a/.gitignore b/.gitignore index f2320ca3..cc9bb0a0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,8 +2,8 @@ __pycache__ fixtures/__pycache__ panels/__pycache__ lib/ +assets/mvrs/* assets/profiles/* assets/models/*/* !assets/profiles/BlenderDMX* - *.zip diff --git a/DEPENDENCIES.md b/DEPENDENCIES.md index 96580685..112b1654 100644 --- a/DEPENDENCIES.md +++ b/DEPENDENCIES.md @@ -9,3 +9,4 @@ is a list of included libraries. * 3DS Importer: https://projects.blender.org/blender/blender-addons-contrib.git * ifaddr: https://github.com/pydron/ifaddr * oscpy: https://github.com/kivy/oscpy +* zeroconf: https://github.com/python-zeroconf/python-zeroconf diff --git a/__init__.py b/__init__.py index e8f6dd77..5c9f4545 100644 --- a/__init__.py +++ b/__init__.py @@ -20,6 +20,8 @@ import json import uuid as py_uuid import re +from datetime import datetime +import pathlib from dmx.pymvr import GeneralSceneDescription from dmx.mvr import extract_mvr_textures, process_mvr_child_list @@ -41,13 +43,16 @@ from dmx.panels.programmer import * import dmx.panels.profiles as Profiles -from dmx.preferences import DMX_Preferences +from dmx.preferences import DMX_Preferences, DMX_Regenrate_UUID from dmx.group import FixtureGroup from dmx.osc_utils import DMX_OSC_Templates from dmx.osc import DMX_OSC +from dmx.mdns import DMX_Zeroconf from dmx.util import rgb_to_cmy, xyY2rgbaa, ShowMessageBox from dmx.mvr_objects import DMX_MVR_Object +from dmx.mvr_xchange import * +from dmx.mvrx_protocol import DMX_MVR_X_Protocol from bpy.props import (BoolProperty, StringProperty, @@ -55,6 +60,7 @@ FloatProperty, FloatVectorProperty, PointerProperty, + EnumProperty, CollectionProperty) from bpy.types import (PropertyGroup, @@ -85,6 +91,12 @@ class DMX_TempData(PropertyGroup): description="When selecting a group, add to existing selection", default = True) + + mvr_xchange: PointerProperty( + name = "MVR-xchange", + type=DMX_MVR_Xchange + ) + class DMX(PropertyGroup): # Base classes to be registered @@ -101,6 +113,12 @@ class DMX(PropertyGroup): DMX_Universe, DMX_Value, DMX_PT_Setup, + DMX_OP_MVR_Download, + DMX_OP_MVR_Import, + DMX_MVR_Xchange_Commit, + DMX_MVR_Xchange_Client, + DMX_MVR_Xchange, + DMX_Regenrate_UUID, DMX_Preferences) # Classes to be registered @@ -153,6 +171,9 @@ class DMX(PropertyGroup): DMX_OT_Programmer_Set_Ignore_Movement, DMX_OT_Programmer_Unset_Ignore_Movement, DMX_PT_DMX_OSC, + DMX_PT_DMX_MVR_X, + DMX_UL_MVR_Commit, + DMX_OP_MVR_Test, DMX_OT_Fixture_ForceRemove, DMX_OT_Fixture_SelectNext, DMX_OT_Fixture_SelectPrevious, @@ -385,6 +406,8 @@ def linkFile(self): # group.rebuild() self.migrations() + self.ensure_application_uuid() + # Unlink Add-on from file # This is only called when the DMX collection is externally removed @@ -407,6 +430,13 @@ def unlinkFile(self): # # Setup > Background > Color + def ensure_application_uuid(self): + addon_name = pathlib.Path(__file__).parent.parts[-1] + prefs = bpy.context.preferences.addons[addon_name].preferences + application_uuid = prefs.get("application_uuid", 0) + if application_uuid == 0: + prefs["application_uuid"] = str(py_uuid.uuid4()) # must never be 0 + def migrations(self): """Provide migration scripts when bumping the data_version""" file_data_version = 1 # default data version before we started setting it up @@ -580,7 +610,7 @@ def onSelectGeometries(self, context): def onLoggingLevel(self, context): DMX_Log.log.setLevel(self.logging_level) - logging_level: bpy.props.EnumProperty( + logging_level: EnumProperty( name= "Logging Level", description= "logging 
level", default = "ERROR", @@ -604,7 +634,7 @@ def onVolumePreview(self, context): # default = False, # update = onVolumePreview) - volume_preview: bpy.props.EnumProperty( + volume_preview: EnumProperty( name= "Simple beam", description= "Display 'fake' beam cone", default = "NONE", @@ -699,6 +729,34 @@ def onUniverseN(self, context): description="The network card/interface to listen for ArtNet DMX data", items = DMX_Network.cards ) + #zeroconf - mvr-xchange + + def onZeroconfEnable(self, context): + if self.zeroconf_enabled: + clients = bpy.context.window_manager.dmx.mvr_xchange.mvr_xchange_clients + clients.clear() + DMX_Zeroconf.enable() + else: + DMX_Zeroconf.disable() + + def onMVR_xchange_enable(self, context): + if self.mvrx_enabled: + clients = context.window_manager.dmx.mvr_xchange + all_clients = context.window_manager.dmx.mvr_xchange.mvr_xchange_clients + selected = clients.selected_mvr_client + selected_client = None + for selected_client in all_clients: + if selected_client.station_uuid == selected: + break + if not selected_client: + return + print(selected_client.ip_address, selected_client.station_name) + DMX_MVR_X_Protocol.enable(selected_client) + else: + DMX_MVR_X_Protocol.disable() + + + # OSC functionality @@ -760,6 +818,20 @@ def onArtNetEnable(self, context): description="Port number of the host where you want to send the OSC signal", default=42000 ) + + zeroconf_enabled : BoolProperty( + name = "Enable MVR-xchange discovery", + description="Enables MVR-xchange discovery", + default = False, + update = onZeroconfEnable + ) + + mvrx_enabled : BoolProperty( + name = "Enable MVR-xchange connection", + description="Connects to MVR-xchange client", + default = False, + update = onMVR_xchange_enable + ) # # DMX > ArtNet > Status artnet_status : EnumProperty( @@ -947,7 +1019,7 @@ def syncProgrammer(self): self.programmer_tilt = data['Tilt']/127.0-1 - fixtures_sorting_order: bpy.props.EnumProperty( + fixtures_sorting_order: EnumProperty( name= "Sort by", description= "Fixture sorting order", default = "ADDRESS", @@ -1095,6 +1167,77 @@ def ensureUniverseExists(self, universe): self.addUniverse() self.universes_n = len(self.universes) + def createMVR_Client(self, station_name, station_uuid, ip_address, port): + clients = bpy.context.window_manager.dmx.mvr_xchange.mvr_xchange_clients + for client in clients: + if client.station_uuid == station_uuid: + return # client already in the list + + client = clients.add() + client.station_name = station_name + client.station_uuid = station_uuid + now = int(datetime.now().timestamp()) + client.last_seen = now + client.ip_address = ip_address + client.port = port + + def removeMVR_Client(self, station_name, station_uuid, ip_addres, port): + clients = bpy.context.window_manager.dmx.mvr_xchange.mvr_xchange_clients + for client in clients: + if client.station_uuid == station_uuid: + clients.remove(client) + break + + def updateMVR_Client(self, station_name, station_uuid, ip_address, port): + clients = bpy.context.window_manager.dmx.mvr_xchange.mvr_xchange_clients + updated = False + for client in clients: + if client.station_uuid == station_uuid: + client.station_name = station_name + now = int(datetime.now().timestamp()) + client.last_seen = now + client.ip_address = ip_address + client.port = port + updated = True + break + if not updated: + self.createMVR_Client(station_name, station_uuid, ip_address, port) + + def createMVR_Commits(self, commits, station_uuid): + clients = bpy.context.window_manager.dmx.mvr_xchange.mvr_xchange_clients + 
for client in clients: + if client.station_uuid == station_uuid: + client.commits.clear() + + for commit in commits: + + if "FileName" in commit: + filename = commit["FileName"] + else: + filename = commit["Comment"] + if not len(filename): + filename = commit["FileUUID"] + + now = int(datetime.now().timestamp()) + client.last_seen = now + new_commit = client.commits.add() + new_commit.station_uuid = station_uuid + new_commit.comment = commit["Comment"] + new_commit.commit_uuid = commit["FileUUID"] + new_commit.file_size = commit["FileSize"] + new_commit.file_name = filename + new_commit.timestamp = now + + def fetched_mvr_downloaded_file(self, commit): + clients = bpy.context.window_manager.dmx.mvr_xchange.mvr_xchange_clients + now = int(datetime.now().timestamp()) + for client in clients: + if client.station_uuid == commit.station_uuid: + for c_commit in client.commits: + if c_commit.commit_uuid == commit.commit_uuid: + c_commit.timestamp_saved = now + + # # Groups def createGroup(self, name): @@ -1174,6 +1317,8 @@ def render(self): fixture.render() + + # Handlers # @@ -1224,6 +1369,8 @@ def onLoadFile(scene): DMX_ArtNet.disable() DMX_sACN.disable() DMX_OSC.disable() + DMX_MVR_X_Protocol.disable() + DMX_Zeroconf.disable() @bpy.app.handlers.persistent def onUndo(scene): @@ -1305,6 +1452,8 @@ def unregister(): DMX_ArtNet.disable() DMX_sACN.disable() DMX_OSC.disable() + DMX_MVR_X_Protocol.disable() + DMX_Zeroconf.disable() try: for cls in Profiles.classes: diff --git a/async_timeout/__init__.py b/async_timeout/__init__.py new file mode 100755 index 00000000..1ffb069f --- /dev/null +++ b/async_timeout/__init__.py @@ -0,0 +1,239 @@ +import asyncio +import enum +import sys +import warnings +from types import TracebackType +from typing import Optional, Type + + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + + +if sys.version_info >= (3, 11): + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + task.uncancel() + +else: + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + pass + + +__version__ = "4.0.3" + + +__all__ = ("timeout", "timeout_at", "Timeout") + + +def timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + if delay is not None: + deadline = loop.time() + delay # type: Optional[float] + else: + deadline = None + return Timeout(deadline, loop) + + +def timeout_at(deadline: Optional[float]) -> "Timeout": + """Schedule the timeout at absolute time. + + deadline argument points on the time in the same clock system + as loop.time(). + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + + >>> async with timeout_at(loop.time() + 10): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + """ + loop = asyncio.get_running_loop() + return Timeout(deadline, loop) + + +class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + +@final +class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. 
+ # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task") + + def __init__( + self, deadline: Optional[float], loop: asyncio.AbstractEventLoop + ) -> None: + self._loop = loop + self._state = _State.INIT + + self._task: Optional["asyncio.Task[object]"] = None + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout() instead", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + self._task = None + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + + The delay can be negative. + + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError("cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + + deadline argument points on the time in the same clock system + as loop.time(). + + If new deadline is in the past the timeout is raised immediately. + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. 
+ """ + if self._state == _State.EXIT: + raise RuntimeError("cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + self._task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon(self._on_timeout) + else: + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + assert self._task is not None + _uncancel_task(self._task) + self._timeout_handler = None + self._task = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self) -> None: + assert self._task is not None + self._task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None diff --git a/async_timeout/py.typed b/async_timeout/py.typed new file mode 100755 index 00000000..3b94f915 --- /dev/null +++ b/async_timeout/py.typed @@ -0,0 +1 @@ +Placeholder diff --git a/mdns.py b/mdns.py new file mode 100644 index 00000000..d02ae5fb --- /dev/null +++ b/mdns.py @@ -0,0 +1,72 @@ +from types import DynamicClassAttribute +import bpy +from dmx.zeroconf import ( + IPVersion, + ServiceBrowser, + ServiceStateChange, + Zeroconf, +) +from dmx.logging import DMX_Log +from typing import cast + + +class DMX_Zeroconf: + _instance = None + + def __init__(self): + super(DMX_Zeroconf, self).__init__() + self.data = None + self.zeroconf = None + self.browser = None + self._dmx = bpy.context.scene.dmx + + def callback(zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange) -> None: + #print(f"Service {name} of type {service_type} state changed: {state_change}") + + info = zeroconf.get_service_info(service_type, name) + station_name = "" + station_uuid = "" + ip_address = "" + port = 0 + + if info: + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_scoped_addresses()] + for address in addresses: + if "::" in address: + continue + ip_a, ip_port = address.split(":") + ip_address = ip_a + port = ip_port + + if info.properties: + if b"StationName" in info.properties: + station_name = info.properties[b"StationName"].decode("utf-8") + if b"StationUUID" in info.properties: + station_uuid = info.properties[b"StationUUID"].decode("utf-") + + if state_change is ServiceStateChange.Added: + DMX_Zeroconf._instance._dmx.createMVR_Client(station_name, station_uuid, ip_address, int(port)) + elif state_change is ServiceStateChange.Updated: + DMX_Zeroconf._instance._dmx.updateMVR_Client(station_name, station_uuid, ip_address, int(port)) + else: # removed + DMX_Zeroconf._instance._dmx.removeMVR_Client(station_name, station_uuid, ip_address, int(port)) + + @staticmethod + def enable(): + if DMX_Zeroconf._instance: + return + DMX_Zeroconf._instance = 
DMX_Zeroconf() + + services = ["_mvrxchange._tcp.local."] + DMX_Zeroconf._instance.zeroconf = Zeroconf(ip_version=IPVersion.V4Only) + DMX_Zeroconf._instance.browser = ServiceBrowser(DMX_Zeroconf._instance.zeroconf, services, handlers=[DMX_Zeroconf.callback]) + DMX_Log.log.info("Enabling Zeroconf") + print("starting mvrx discovery") + + @staticmethod + def disable(): + if DMX_Zeroconf._instance: + DMX_Zeroconf._instance.zeroconf.close() + DMX_Zeroconf._instance = None + print("closing mvrx discovery") + DMX_Log.log.info("Disabling Zeroconf") diff --git a/mvr_xchange.py b/mvr_xchange.py new file mode 100644 index 00000000..4002d580 --- /dev/null +++ b/mvr_xchange.py @@ -0,0 +1,46 @@ +import bpy +from bpy.props import BoolProperty, CollectionProperty, EnumProperty, IntProperty, StringProperty +from bpy.types import PropertyGroup + + +class DMX_MVR_Xchange_Commit(PropertyGroup): + commit_uuid: StringProperty(name="File UUID") + comment: StringProperty(name="Comment") + file_name: StringProperty(name="File Name") + station_uuid: StringProperty(name="Station UUID") + file_size: IntProperty(name="File Size") + timestamp: IntProperty(name="Time of info") + timestamp_saved: IntProperty(name="Time of saving") + subscribed: BoolProperty(name="Subscribed to") + + +class DMX_MVR_Xchange_Client(PropertyGroup): + ip_address: StringProperty(name="IP Address") + port: IntProperty(name="Port") + subscribed: BoolProperty(name="Subscribed to") + last_seen: IntProperty(name="Last Seen Time") + station_name: StringProperty(name="Station Name") + station_uuid: StringProperty(name="Station UUID") + provider: StringProperty(name="Provider") + commits: CollectionProperty(name="Commits", type=DMX_MVR_Xchange_Commit) + + def get_clients(self, context): + #print(self, context) + clients = bpy.context.window_manager.dmx.mvr_xchange.mvr_xchange_clients + data = [] + for client in clients: + data.append((client.station_uuid, client.station_name, client.station_uuid)) + return data + +class DMX_MVR_Xchange(PropertyGroup): + selected_commit: IntProperty(default=0) + mvr_xchange_clients: CollectionProperty( + name = "MVR-xchange Clients", + type=DMX_MVR_Xchange_Client + ) + + selected_mvr_client: EnumProperty( + name = "Client", + description="", + items = DMX_MVR_Xchange_Client.get_clients + ) diff --git a/mvrx_protocol.py b/mvrx_protocol.py new file mode 100644 index 00000000..2651523d --- /dev/null +++ b/mvrx_protocol.py @@ -0,0 +1,76 @@ +import bpy +import dmx.mvrxchange_protocol as mvrx_protocol +from dmx.logging import DMX_Log +import os +import time +import pathlib +from dmx import bl_info as application_info +import uuid as py_uuid + + +class DMX_MVR_X_Protocol: + _instance = None + + def __init__(self): + super(DMX_MVR_X_Protocol, self).__init__() + self._dmx = bpy.context.scene.dmx + self.client = None + + addon_name = pathlib.Path(__file__).parent.parts[-1] + prefs = bpy.context.preferences.addons[addon_name].preferences + application_uuid = prefs.get("application_uuid", str(py_uuid.uuid4())) # must never be 0 + self.application_uuid = application_uuid + # print("bl info", application_info) # TODO: use this in the future + + def callback(data): + if "StationUUID" not in data: + print("Bad response", data) + return + uuid = data["StationUUID"] + if "Files" in data: + DMX_MVR_X_Protocol._instance._dmx.createMVR_Commits(data["Files"], uuid) + + if "file_downloaded" in data: + DMX_MVR_X_Protocol._instance._dmx.fetched_mvr_downloaded_file(data["file_downloaded"]) + + @staticmethod + def request_file(commit): + if 
DMX_MVR_X_Protocol._instance: + if DMX_MVR_X_Protocol._instance.client: + ADDON_PATH = os.path.dirname(os.path.abspath(__file__)) + path = os.path.join(ADDON_PATH, "assets", "mvrs", f"{commit.commit_uuid}.mvr") + try: + DMX_MVR_X_Protocol._instance.client.request_file(commit, path) + except: + print("problem requesting file") + return + DMX_Log.log.info("Requesting file") + + @staticmethod + def enable(client): + if DMX_MVR_X_Protocol._instance: + return + DMX_MVR_X_Protocol._instance = DMX_MVR_X_Protocol() + print("Connecting to MVR-xchange client", client.ip_address, client.port) + try: + DMX_MVR_X_Protocol._instance.client = mvrx_protocol.client( + client.ip_address, client.port, timeout=0, + callback=DMX_MVR_X_Protocol.callback, + application_uuid=DMX_MVR_X_Protocol._instance.application_uuid) + + except Exception as e: + print("Cannot connect to host", e) + return + DMX_MVR_X_Protocol._instance.client.start() + DMX_MVR_X_Protocol._instance.client.join_mvr() + DMX_Log.log.info("Joining") + + @staticmethod + def disable(): + if DMX_MVR_X_Protocol._instance: + if DMX_MVR_X_Protocol._instance.client: + DMX_MVR_X_Protocol._instance.client.leave_mvr() + time.sleep(0.3) + DMX_MVR_X_Protocol._instance.client.stop() + DMX_MVR_X_Protocol._instance = None + DMX_Log.log.info("Disabling MVR") diff --git a/mvrxchange_protocol/__init__.py b/mvrxchange_protocol/__init__.py new file mode 100755 index 00000000..49a6622e --- /dev/null +++ b/mvrxchange_protocol/__init__.py @@ -0,0 +1,158 @@ +#!/bin/env python3 + +import socket +import json +from threading import Thread +import struct +from queue import Queue +import time + +# A very rudimentary MVR-xchange client +# The socket handling is lame at this point, reconnecting socket on every send (which is not often though) + + +class client(Thread): + def __init__(self, ip_address, port, callback, timeout=None, application_uuid=0): + Thread.__init__(self) + self.callback = callback + self.running = True + self.queue = Queue() + self.ip_address = ip_address + self.application_uuid = application_uuid + self.port = port + self.filepath = "" + self.commit = "" + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.connect((ip_address, port)) + if timeout is not None and self.socket is not None: + self.socket.settimeout(timeout) + + def reconnect(self): + print("reconnecting") + self.socket.close() + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.connect((self.ip_address, self.port)) + + def join_mvr(self): + self.send(self.join_message()) + + def leave_mvr(self): + self.send(self.leave_message()) + + def request_file(self, commit, path): + self.filepath = path + self.commit = commit + self.send(self.request_message(commit.commit_uuid)) + + def stop(self): + self.running = False + if self.socket is not None: + self.socket.close() + self.join() + + def send(self, message): + self.queue.put(message) + + def run(self): + data = b"" + if self.socket is None: + return + while self.running: + message = None + try: + new_data = self.socket.recv(1024) + data += new_data + except Exception as e: + ... 
+ else: + if data: + header = self.parse_header(data) + if header["error"]: + data = b"" + continue + + if len(data) >= header["total_len"]: + self.parse_data(data, self.callback) + data = b"" + + if not self.queue.empty(): + message = self.queue.get() + if message: + self.reconnect() + self.socket.sendall(message) + + if not data: + time.sleep(0.2) + + def parse_header(self, data): + header = {"error": True} + if len(data) > 4: + if struct.unpack("!l", data[0:4])[0] == 778682: + msg_version = struct.unpack("!l", data[4:8])[0] + msg_number = struct.unpack("!l", data[8:12])[0] + msg_count = struct.unpack("!l", data[12:16])[0] + msg_type = struct.unpack("!l", data[16:20])[0] + msg_len = struct.unpack("!q", data[20:28])[0] + header = {"version": msg_version, "number": msg_number, "count": msg_count, "type": msg_type, "data_len": msg_len, "total_len": msg_len + 28, "error": False} + return header + + def parse_data(self, data, callback): + header = self.parse_header(data) + if header["type"] == 0: # json + json_data = json.loads(data[28:].decode("utf-8")) + callback(json_data) + else: # file + with open(self.filepath, "bw") as f: + f.write(data[28:]) + callback({"file_downloaded": self.commit, "StationUUID": self.commit.station_uuid}) + + def join_message(self): + return self.craft_packet( + { + "Type": "MVR_JOIN", + "Provider": "Blender DMX", + "verMajor": 1, + "verMinor": 6, + "StationUUID": self.application_uuid, + "StationName": "Blender DMX Station", + "Files": [], + } + ) + + def leave_message(self): + return self.craft_packet( + { + "Type": "MVR_LEAVE", + "FromStationUUID": self.application_uuid, + } + ) + + def request_message(self, uuid): + return self.craft_packet( + { + "Type": "MVR_REQUEST", + "FileUUID": f"{uuid}", + "FromStationUUID": self.application_uuid, + } + ) + + def craft_packet(self, message): + MVR_PACKAGE_HEADER = 778682 + MVR_PACKAGE_VERSION = 1 + MVR_PACKAGE_NUMBER = 0 + MVR_PACKAGE_COUNT = 1 + MVR_PACKAGE_TYPE = 0 + MVR_PAYLOAD_BUFFER = json.dumps(message).encode("utf-8") + MVR_PAYLOAD_LENGTH = len(MVR_PAYLOAD_BUFFER) + + output = ( + struct.pack("!l", MVR_PACKAGE_HEADER) + + struct.pack("!l", MVR_PACKAGE_VERSION) + + struct.pack("!l", MVR_PACKAGE_NUMBER) + + struct.pack("!l", MVR_PACKAGE_COUNT) + + struct.pack("!l", MVR_PACKAGE_TYPE) + + struct.pack("!q", MVR_PAYLOAD_LENGTH) + + MVR_PAYLOAD_BUFFER + ) + return output diff --git a/panels/dmx.py b/panels/dmx.py index a1947e6f..8cbddeea 100644 --- a/panels/dmx.py +++ b/panels/dmx.py @@ -8,10 +8,13 @@ # import bpy +import os +from datetime import datetime from dmx.data import DMX_Data from bpy.props import (PointerProperty, EnumProperty, - StringProperty) + StringProperty, + IntProperty) from bpy.types import (Panel, Menu, @@ -23,9 +26,91 @@ from dmx.util import getSceneRect from dmx.osc import DMX_OSC from dmx.osc_utils import DMX_OSC_Templates +from dmx.mvrx_protocol import DMX_MVR_X_Protocol +class DMX_OP_MVR_Test(Operator): + bl_label = "Test" + bl_description = "Test operator" + bl_idname = "dmx.mvr_test" + bl_options = {"UNDO"} + def execute(self, context): + print("Test") + DMX_MVR_X_Protocol._instance.client.join_mvr() + + return {"FINISHED"} + + + +class DMX_OP_MVR_Import(Operator): + bl_label = "Import" + bl_description = "Import commit" + bl_idname = "dmx.mvr_import" + bl_options = {"UNDO"} + + uuid: StringProperty() + + def execute(self, context): + scene = context.scene + dmx = scene.dmx + ADDON_PATH = os.path.dirname(os.path.abspath(__file__)) + clients = context.window_manager.dmx.mvr_xchange + 
all_clients = context.window_manager.dmx.mvr_xchange.mvr_xchange_clients + selected = clients.selected_mvr_client + for client in all_clients: + if client.station_uuid == selected: + break + for commit in client.commits: + if commit.commit_uuid == self.uuid: + print("import", commit) + path = os.path.join(ADDON_PATH, "..", "assets", "mvrs", f"{commit.commit_uuid}.mvr") + print(path) + dmx.addMVR(path) + break + return {"FINISHED"} + +class DMX_OP_MVR_Download(Operator): + bl_label = "Download" + bl_description = "Download commit" + bl_idname = "dmx.mvr_download" + bl_options = {"UNDO"} + + uuid: StringProperty() + + def execute(self, context): + print("downloading") + + clients = context.window_manager.dmx.mvr_xchange + all_clients = clients.mvr_xchange_clients + selected = clients.selected_mvr_client + for client in all_clients: + if client.station_uuid == selected: + break + print("got client", client.station_name) + for commit in client.commits: + print(commit.commit_uuid) + if commit.commit_uuid == self.uuid: + print("downloading", commit) + DMX_MVR_X_Protocol.request_file(commit) + break + + return {"FINISHED"} + +class DMX_UL_MVR_Commit(UIList): + def draw_item(self, context, layout, data, item, icon, active_data, active_propname, index): + scene = context.scene + dmx = scene.dmx + icon = "GROUP_VERTEX" + #layout.context_pointer_set("mvr_xchange_clients", item) + col = layout.column() + col.label(text = f"{item.comment}", icon="CHECKBOX_HLT" if item.timestamp_saved else "CHECKBOX_DEHLT") + col = layout.column() + col.operator("dmx.mvr_download", text="", icon="IMPORT").uuid = item.commit_uuid + col.enabled = dmx.mvrx_enabled + col = layout.column() + col.operator("dmx.mvr_import", text="", icon="CHECKBOX_HLT").uuid = item.commit_uuid + col.enabled = item.timestamp_saved > 0 # List # @@ -81,6 +166,56 @@ def draw(self, context): row.prop(dmx, "osc_target_port") row.enabled = not dmx.osc_enabled +class DMX_PT_DMX_MVR_X(Panel): + bl_label = "MVR-xchange" + bl_idname = "DMX_PT_DMX_MVR_Xchange" + bl_parent_id = "DMX_PT_DMX" + bl_space_type = "VIEW_3D" + bl_region_type = "UI" + bl_category = "DMX" + bl_context = "objectmode" + bl_options = {'DEFAULT_CLOSED'} + + def draw(self, context): + layout = self.layout + dmx = context.scene.dmx + + row = layout.row() + row.prop(dmx, "zeroconf_enabled") + row = layout.row() + + clients = context.window_manager.dmx.mvr_xchange + all_clients = clients.mvr_xchange_clients + if not all_clients: + selected = None + else: + selected = clients.selected_mvr_client + + client = None + for client in all_clients: + if client.station_uuid ==selected: + break + + row.prop(clients, "selected_mvr_client", text="") + row.enabled = not dmx.mvrx_enabled + row = layout.row() + row.prop(dmx, "mvrx_enabled") + row.enabled = len(all_clients) > 0 + if not client: + return + row = layout.row() + row.label(text = f"{client.station_name}", icon = "LINKED" if dmx.mvrx_enabled else "UNLINKED") + #row.operator("dmx.mvr_test", text="Test", icon="IMPORT") + layout.template_list( + "DMX_UL_MVR_Commit", + "", + client, + "commits", + clients, + "selected_commit", + rows=4, + ) + class DMX_PT_DMX_Universes(Panel): bl_label = "Universes" bl_idname = "DMX_PT_DMX_Universes" diff --git a/preferences/__init__.py b/preferences/__init__.py index 716736c2..3f42c12b 100644 --- a/preferences/__init__.py +++ b/preferences/__init__.py @@ -1,7 +1,22 @@ import pathlib +import uuid as py_uuid -from bpy.types import AddonPreferences +import bpy from bpy.props import StringProperty +from bpy.types 
import AddonPreferences, Operator + + +class DMX_Regenrate_UUID(Operator): + bl_label = "Regenerate UUID" + bl_idname = "dmx.regenerate_uuid" + bl_options = {"UNDO"} + + def execute(self, context): + addon_name = pathlib.Path(__file__).parent.parts[-2] + prefs = bpy.context.preferences.addons[addon_name].preferences + uuid = str(py_uuid.uuid4()) + prefs["application_uuid"] = uuid + return {"FINISHED"} class DMX_Preferences(AddonPreferences): @@ -16,14 +31,29 @@ class DMX_Preferences(AddonPreferences): share_api_password: StringProperty( default="", name="GDTF Share Password", - subtype='PASSWORD', + subtype="PASSWORD", description="Password for GDTF Share", ) + application_uuid: StringProperty( + default=str(py_uuid.uuid4()), + name="Application UUID", + description="Used for example for MVR xchange", + ) + def draw(self, context): layout = self.layout layout.use_property_split = True layout.label(text="Username and Password for GDTF Share. Get a free account at gdtf-share.com") layout.prop(self, "share_api_username") layout.prop(self, "share_api_password") - + layout.separator() + layout.label(text="Application settings") + row = layout.row() + col = row.column() + col.prop(self, "application_uuid") + col.enabled = False + col = row.column() + col.operator("dmx.regenerate_uuid", text="", icon="FILE_REFRESH") + layout.separator() + layout.label(text="Make sure to save the preferences after editing.") diff --git a/scripts/build_release.py b/scripts/build_release.py index fc9695d0..0ca04f39 100644 --- a/scripts/build_release.py +++ b/scripts/build_release.py @@ -71,6 +71,9 @@ def read_version(): copytree("ifaddr", BUILD_DIR + "/dmx/ifaddr", ignore=ignore) copytree("oscpy", BUILD_DIR + "/dmx/oscpy", ignore=ignore) copytree("share_api_client", BUILD_DIR + "/dmx/share_api_client", ignore=ignore) +copytree("mvrxchange_protocol", BUILD_DIR + "/dmx/mvrxchange_protocol", ignore=ignore) +copytree("zeroconf", BUILD_DIR + "/dmx/zeroconf", ignore=ignore) +copytree("async_timeout", BUILD_DIR + "/dmx/async_timeout", ignore=ignore) copytree("preferences", BUILD_DIR + "/dmx/preferences", ignore=ignore) print("Copying source to build directory...") diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py new file mode 100644 index 00000000..e6b8e481 --- /dev/null +++ b/zeroconf/__init__.py @@ -0,0 +1,125 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+ + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import sys + +from ._cache import DNSCache # noqa # import needed for backwards compat +from ._core import Zeroconf +from ._dns import ( # noqa # import needed for backwards compat + DNSAddress, + DNSEntry, + DNSHinfo, + DNSNsec, + DNSPointer, + DNSQuestion, + DNSQuestionType, + DNSRecord, + DNSService, + DNSText, +) +from ._exceptions import ( + AbstractMethodException, + BadTypeInNameException, + Error, + EventLoopBlocked, + IncomingDecodeError, + NamePartTooLongException, + NonUniqueNameException, + NotRunningException, + ServiceNameAlreadyRegistered, +) +from ._logger import QuietLogger, log # noqa # import needed for backwards compat +from ._protocol.incoming import DNSIncoming # noqa # import needed for backwards compat +from ._protocol.outgoing import DNSOutgoing # noqa # import needed for backwards compat +from ._record_update import RecordUpdate +from ._services import ( # noqa # import needed for backwards compat + ServiceListener, + ServiceStateChange, + Signal, + SignalRegistrationInterface, +) +from ._services.browser import ServiceBrowser +from ._services.info import ( # noqa # import needed for backwards compat + ServiceInfo, + instance_name_from_service_info, +) +from ._services.registry import ( # noqa # import needed for backwards compat + ServiceRegistry, +) +from ._services.types import ZeroconfServiceTypes +from ._updates import RecordUpdateListener +from ._utils.name import service_type_name # noqa # import needed for backwards compat +from ._utils.net import ( # noqa # import needed for backwards compat + InterfaceChoice, + InterfacesType, + IPVersion, + add_multicast_member, + autodetect_ip_version, + create_sockets, + get_all_addresses, + get_all_addresses_v6, +) +from ._utils.time import ( # noqa # import needed for backwards compat + current_time_millis, + millis_to_seconds, +) + +__author__ = 'Paul Scott-Murphy, William McBrine' +__maintainer__ = 'Jakub Stasiak ' +__version__ = '0.131.0' +__license__ = 'LGPL' + + +__all__ = [ + "__version__", + "Zeroconf", + "ServiceInfo", + "ServiceBrowser", + "ServiceListener", + "DNSQuestionType", + "InterfaceChoice", + "ServiceStateChange", + "IPVersion", + "ZeroconfServiceTypes", + "RecordUpdate", + "RecordUpdateListener", + "current_time_millis", + # Exceptions + "Error", + "AbstractMethodException", + "BadTypeInNameException", + "EventLoopBlocked", + "IncomingDecodeError", + "NamePartTooLongException", + "NonUniqueNameException", + "NotRunningException", + "ServiceNameAlreadyRegistered", +] + +if sys.version_info <= (3, 6): # pragma: no cover + raise ImportError( # pragma: no cover + ''' +Python version > 3.6 required for python-zeroconf. +If you need support for Python 2 or Python 3.3-3.4 please use version 19.1 +If you need support for Python 3.5 please use version 0.28.0 + ''' + ) diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py new file mode 100644 index 00000000..35a13cf6 --- /dev/null +++ b/zeroconf/_cache.py @@ -0,0 +1,249 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. 
+ + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast + +from ._dns import ( + DNSAddress, + DNSEntry, + DNSHinfo, + DNSNsec, + DNSPointer, + DNSRecord, + DNSService, + DNSText, +) +from ._utils.time import current_time_millis +from .const import _ONE_SECOND, _TYPE_PTR + +_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService) +_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] +_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] +_DNSRecord = DNSRecord +_str = str +_float = float +_int = int + + +def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None: + """Remove a key from a DNSRecord cache + + This function must be run in from event loop. + """ + del cache[key][record] + if not cache[key]: + del cache[key] + + +class DNSCache: + """A cache of DNS entries.""" + + def __init__(self) -> None: + self.cache: _DNSRecordCacheType = {} + self.service_cache: _DNSRecordCacheType = {} + + # Functions prefixed with async_ are NOT threadsafe and must + # be run in the event loop. + + def _async_add(self, record: _DNSRecord) -> bool: + """Adds an entry. + + Returns true if the entry was not already in the cache. + + This function must be run in from event loop. + """ + # Previously storage of records was implemented as a list + # instead a dict. Since DNSRecords are now hashable, the implementation + # uses a dict to ensure that adding a new record to the cache + # replaces any existing records that are __eq__ to each other which + # removes the risk that accessing the cache from the wrong + # direction would return the old incorrect entry. + store = self.cache.setdefault(record.key, {}) + new = record not in store and not isinstance(record, DNSNsec) + store[record] = record + if isinstance(record, DNSService): + self.service_cache.setdefault(record.server_key, {})[record] = record + return new + + def async_add_records(self, entries: Iterable[DNSRecord]) -> bool: + """Add multiple records. + + Returns true if any of the records were not in the cache. + + This function must be run in from event loop. + """ + new = False + for entry in entries: + if self._async_add(entry): + new = True + return new + + def _async_remove(self, record: _DNSRecord) -> None: + """Removes an entry. + + This function must be run in from event loop. + """ + if isinstance(record, DNSService): + _remove_key(self.service_cache, record.server_key, record) + _remove_key(self.cache, record.key, record) + + def async_remove_records(self, entries: Iterable[DNSRecord]) -> None: + """Remove multiple records. + + This function must be run in from event loop. 
+ """ + for entry in entries: + self._async_remove(entry) + + def async_expire(self, now: _float) -> List[DNSRecord]: + """Purge expired entries from the cache. + + This function must be run in from event loop. + """ + expired = [record for records in self.cache.values() for record in records if record.is_expired(now)] + self.async_remove_records(expired) + return expired + + def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: + """Gets a unique entry by key. Will return None if there is no + matching entry. + + This function is not threadsafe and must be called from + the event loop. + """ + store = self.cache.get(entry.key) + if store is None: + return None + return store.get(entry) + + def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DNSRecord]: + """Gets all matching entries by details. + + This function is not thread-safe and must be called from + the event loop. + """ + key = name.lower() + records = self.cache.get(key) + matches: List[DNSRecord] = [] + if records is None: + return matches + for record in records: + if type_ == record.type and class_ == record.class_: + matches.append(record) + return matches + + def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the name. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.cache.get(name.lower()) or {} + + def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the server. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.service_cache.get(name.lower()) or {} + + # The below functions are threadsafe and do not need to be run in the + # event loop, however they all make copies so they significantly + # inefficent + + def get(self, entry: DNSEntry) -> Optional[DNSRecord]: + """Gets an entry by key. Will return None if there is no + matching entry.""" + if isinstance(entry, _UNIQUE_RECORD_TYPES): + return self.cache.get(entry.key, {}).get(entry) + for cached_entry in reversed(list(self.cache.get(entry.key, []))): + if entry.__eq__(cached_entry): + return cached_entry + return None + + def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRecord]: + """Gets the first matching entry by details. Returns None if no entries match. + + Calling this function is not recommended as it will only + return one record even if there are multiple entries. + + For example if there are multiple A or AAAA addresses this + function will return the last one that was added to the cache + which may not be the one you expect. + + Use get_all_by_details instead. 
+ """ + key = name.lower() + records = self.cache.get(key) + if records is None: + return None + for cached_entry in reversed(list(records)): + if type_ == cached_entry.type and class_ == cached_entry.class_: + return cached_entry + return None + + def get_all_by_details(self, name: str, type_: _int, class_: _int) -> List[DNSRecord]: + """Gets all matching entries by details.""" + key = name.lower() + records = self.cache.get(key) + if records is None: + return [] + return [entry for entry in list(records) if type_ == entry.type and class_ == entry.class_] + + def entries_with_server(self, server: str) -> List[DNSRecord]: + """Returns a list of entries whose server matches the name.""" + return list(self.service_cache.get(server.lower(), [])) + + def entries_with_name(self, name: str) -> List[DNSRecord]: + """Returns a list of entries whose key matches the name.""" + return list(self.cache.get(name.lower(), [])) + + def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: + now = current_time_millis() + for record in reversed(self.entries_with_name(name)): + if ( + record.type == _TYPE_PTR + and not record.is_expired(now) + and cast(DNSPointer, record).alias == alias + ): + return record + return None + + def names(self) -> List[str]: + """Return a copy of the list of current cache names.""" + return list(self.cache) + + def async_mark_unique_records_older_than_1s_to_expire( + self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float + ) -> None: + # rfc6762#section-10.2 para 2 + # Since unique is set, all old records with that name, rrtype, + # and rrclass that were received more than one second ago are declared + # invalid, and marked to expire from the cache in one second. + answers_rrset = set(answers) + for name, type_, class_ in unique_types: + for record in self.async_all_by_details(name, type_, class_): + created_double = record.created + if (now - created_double > _ONE_SECOND) and record not in answers_rrset: + # Expire in 1s + record.set_created_ttl(now, 1) diff --git a/zeroconf/_core.py b/zeroconf/_core.py new file mode 100644 index 00000000..4b29717a --- /dev/null +++ b/zeroconf/_core.py @@ -0,0 +1,658 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+ + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import logging +import sys +import threading +from types import TracebackType +from typing import Awaitable, Dict, List, Optional, Set, Tuple, Type, Union + +from ._cache import DNSCache +from ._dns import DNSQuestion, DNSQuestionType +from ._engine import AsyncEngine +from ._exceptions import NonUniqueNameException, NotRunningException +from ._handlers.multicast_outgoing_queue import MulticastOutgoingQueue +from ._handlers.query_handler import QueryHandler +from ._handlers.record_manager import RecordManager +from ._history import QuestionHistory +from ._logger import QuietLogger, log +from ._protocol.outgoing import DNSOutgoing +from ._services import ServiceListener +from ._services.browser import ServiceBrowser +from ._services.info import ServiceInfo, instance_name_from_service_info +from ._services.registry import ServiceRegistry +from ._transport import _WrappedTransport +from ._updates import RecordUpdateListener +from ._utils.asyncio import ( + _resolve_all_futures_to_none, + await_awaitable, + get_running_loop, + run_coro_with_timeout, + shutdown_loop, + wait_event_or_timeout, + wait_for_future_set_or_timeout, +) +from ._utils.name import service_type_name +from ._utils.net import ( + InterfaceChoice, + InterfacesType, + IPVersion, + autodetect_ip_version, + can_send_to, + create_sockets, +) +from ._utils.time import current_time_millis, millis_to_seconds +from .const import ( + _CHECK_TIME, + _CLASS_IN, + _CLASS_UNIQUE, + _FLAGS_AA, + _FLAGS_QR_QUERY, + _FLAGS_QR_RESPONSE, + _MAX_MSG_ABSOLUTE, + _MDNS_ADDR, + _MDNS_ADDR6, + _MDNS_PORT, + _ONE_SECOND, + _REGISTER_TIME, + _STARTUP_TIMEOUT, + _TYPE_PTR, + _UNREGISTER_TIME, +) + +# The maximum amont of time to delay a multicast +# response in order to aggregate answers +_AGGREGATION_DELAY = 500 # ms +# The maximum amont of time to delay a multicast +# response in order to aggregate answers after +# it has already been delayed to protect the network +# from excessive traffic. 
We use a shorter time +# window here as we want to _try_ to answer all +# queries in under 1350ms while protecting +# the network from excessive traffic to ensure +# a service info request with two questions +# can be answered in the default timeout of +# 3000ms +_PROTECTED_AGGREGATION_DELAY = 200 # ms + +_REGISTER_BROADCASTS = 3 + + +def async_send_with_transport( + log_debug: bool, + transport: _WrappedTransport, + packet: bytes, + packet_num: int, + out: DNSOutgoing, + addr: Optional[str], + port: int, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), +) -> None: + ipv6_socket = transport.is_ipv6 + if addr is None: + real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR + else: + real_addr = addr + if not can_send_to(ipv6_socket, real_addr): + return + if log_debug: + log.debug( + 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...', + real_addr, + port or _MDNS_PORT, + transport.fileno, + transport.sock_name, + len(packet), + packet_num + 1, + out, + packet, + ) + # Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6 + # address tuple: https://docs.python.org/3.6/library/socket.html#socket-families + if ipv6_socket and not v6_flow_scope: + _, _, sock_flowinfo, sock_scopeid = transport.sock_name + v6_flow_scope = (sock_flowinfo, sock_scopeid) + transport.transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) + + +class Zeroconf(QuietLogger): + + """Implementation of Zeroconf Multicast DNS Service Discovery + + Supports registration, unregistration, queries and browsing. + """ + + def __init__( + self, + interfaces: InterfacesType = InterfaceChoice.All, + unicast: bool = False, + ip_version: Optional[IPVersion] = None, + apple_p2p: bool = False, + ) -> None: + """Creates an instance of the Zeroconf class, establishing + multicast communications, listening and reaping threads. + + :param interfaces: :class:`InterfaceChoice` or a list of IP addresses + (IPv4 and IPv6) and interface indexes (IPv6 only). + + IPv6 notes for non-POSIX systems: + * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` + on Python versions before 3.8. + + Also listening on loopback (``::1``) doesn't work, use a real address. + :param ip_version: IP versions to support. If `choice` is a list, the default is detected + from it. Otherwise defaults to V4 only for backward compatibility. 
+ :param apple_p2p: use AWDL interface (only macOS) + """ + if ip_version is None: + ip_version = autodetect_ip_version(interfaces) + + self.done = False + + if apple_p2p and sys.platform != 'darwin': + raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.') + + self.unicast = unicast + listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p) + log.debug('Listen socket %s, respond sockets %s', listen_socket, respond_sockets) + + self.engine = AsyncEngine(self, listen_socket, respond_sockets) + + self.browsers: Dict[ServiceListener, ServiceBrowser] = {} + self.registry = ServiceRegistry() + self.cache = DNSCache() + self.question_history = QuestionHistory() + + self.out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY) + self.out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY) + + self.query_handler = QueryHandler(self) + self.record_manager = RecordManager(self) + + self._notify_futures: Set[asyncio.Future] = set() + self.loop: Optional[asyncio.AbstractEventLoop] = None + self._loop_thread: Optional[threading.Thread] = None + + self.start() + + @property + def started(self) -> bool: + """Check if the instance has started.""" + return bool(not self.done and self.engine.running_event and self.engine.running_event.is_set()) + + def start(self) -> None: + """Start Zeroconf.""" + self.loop = get_running_loop() + if self.loop: + self.engine.setup(self.loop, None) + return + self._start_thread() + + def _start_thread(self) -> None: + """Start a thread with a running event loop.""" + loop_thread_ready = threading.Event() + + def _run_loop() -> None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.engine.setup(self.loop, loop_thread_ready) + self.loop.run_forever() + + self._loop_thread = threading.Thread(target=_run_loop, daemon=True) + self._loop_thread.start() + loop_thread_ready.wait() + + async def async_wait_for_start(self) -> None: + """Wait for start up for actions that require a running Zeroconf instance. + + Throws NotRunningException if the instance is not running or could + not be started. 
+ """ + if self.done: # If the instance was shutdown from under us, raise immediately + raise NotRunningException + assert self.engine.running_event is not None + await wait_event_or_timeout(self.engine.running_event, timeout=_STARTUP_TIMEOUT) + if not self.engine.running_event.is_set() or self.done: + raise NotRunningException + + @property + def listeners(self) -> Set[RecordUpdateListener]: + return self.record_manager.listeners + + async def async_wait(self, timeout: float) -> None: + """Calling task waits for a given number of milliseconds or until notified.""" + loop = self.loop + assert loop is not None + await wait_for_future_set_or_timeout(loop, self._notify_futures, timeout) + + def notify_all(self) -> None: + """Notifies all waiting threads and notify listeners.""" + assert self.loop is not None + self.loop.call_soon_threadsafe(self.async_notify_all) + + def async_notify_all(self) -> None: + """Schedule an async_notify_all.""" + notify_futures = self._notify_futures + if notify_futures: + _resolve_all_futures_to_none(notify_futures) + + def get_service_info( + self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None + ) -> Optional[ServiceInfo]: + """Returns network's service information for a particular + name and type, or None if no service matches by the timeout, + which defaults to 3 seconds.""" + info = ServiceInfo(type_, name) + if info.request(self, timeout, question_type): + return info + return None + + def add_service_listener(self, type_: str, listener: ServiceListener) -> None: + """Adds a listener for a particular service type. This object + will then have its add_service and remove_service methods called when + services of that type become available and unavailable.""" + self.remove_service_listener(listener) + self.browsers[listener] = ServiceBrowser(self, type_, listener) + + def remove_service_listener(self, listener: ServiceListener) -> None: + """Removes a listener from the set that is currently listening.""" + if listener in self.browsers: + self.browsers[listener].cancel() + del self.browsers[listener] + + def remove_all_service_listeners(self) -> None: + """Removes a listener from the set that is currently listening.""" + for listener in list(self.browsers): + self.remove_service_listener(listener) + + def register_service( + self, + info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, + cooperating_responders: bool = False, + strict: bool = True, + ) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`). + + While it is not expected during normal operation, + this function may raise EventLoopBlocked if the underlying + call to `register_service` cannot be completed. 
+ """ + assert self.loop is not None + run_coro_with_timeout( + await_awaitable( + self.async_register_service(info, ttl, allow_name_change, cooperating_responders, strict) + ), + self.loop, + _REGISTER_TIME * _REGISTER_BROADCASTS, + ) + + async def async_register_service( + self, + info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, + cooperating_responders: bool = False, + strict: bool = True, + ) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`).""" + if ttl is not None: + # ttl argument is used to maintain backward compatibility + # Setting TTLs via ServiceInfo is preferred + info.host_ttl = ttl + info.other_ttl = ttl + + info.set_server_if_missing() + await self.async_wait_for_start() + await self.async_check_service(info, allow_name_change, cooperating_responders, strict) + self.registry.async_add(info) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + + def update_service(self, info: ServiceInfo) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. + + While it is not expected during normal operation, + this function may raise EventLoopBlocked if the underlying + call to `async_update_service` cannot be completed. + """ + assert self.loop is not None + run_coro_with_timeout( + await_awaitable(self.async_update_service(info)), self.loop, _REGISTER_TIME * _REGISTER_BROADCASTS + ) + + async def async_update_service(self, info: ServiceInfo) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service.""" + self.registry.async_update(info) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + + async def _async_broadcast_service( + self, + info: ServiceInfo, + interval: int, + ttl: Optional[int], + broadcast_addresses: bool = True, + ) -> None: + """Send a broadcasts to announce a service at intervals.""" + for i in range(_REGISTER_BROADCASTS): + if i != 0: + await asyncio.sleep(millis_to_seconds(interval)) + self.async_send(self.generate_service_broadcast(info, ttl, broadcast_addresses)) + + def generate_service_broadcast( + self, + info: ServiceInfo, + ttl: Optional[int], + broadcast_addresses: bool = True, + ) -> DNSOutgoing: + """Generate a broadcast to announce a service.""" + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + self._add_broadcast_answer(out, info, ttl, broadcast_addresses) + return out + + def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use + """Generate a query to lookup a service.""" + out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) + # https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 + # Because of the mDNS multicast rate-limiting + # rules, the probes SHOULD be sent as "QU" questions with the unicast- + # response bit set, to allow a defending host to respond immediately + # via unicast, instead of potentially having to wait before replying + # via multicast. 
+        #
+        # _CLASS_UNIQUE is the "QU" bit
+        out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE))
+        out.add_authorative_answer(info.dns_pointer())
+        return out
+
+    def _add_broadcast_answer(  # pylint: disable=no-self-use
+        self,
+        out: DNSOutgoing,
+        info: ServiceInfo,
+        override_ttl: Optional[int],
+        broadcast_addresses: bool = True,
+    ) -> None:
+        """Add answers to broadcast a service."""
+        # A ttl of None lets the dns_* accessors fall back to the
+        # default TTLs stored on the ServiceInfo
+        other_ttl = None if override_ttl is None else override_ttl
+        host_ttl = None if override_ttl is None else override_ttl
+        out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0)
+        out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0)
+        out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0)
+        if broadcast_addresses:
+            for record in info.get_address_and_nsec_records(override_ttl=host_ttl):
+                out.add_answer_at_time(record, 0)
+
+    def unregister_service(self, info: ServiceInfo) -> None:
+        """Unregister a service.
+
+        While it is not expected during normal operation,
+        this function may raise EventLoopBlocked if the underlying
+        call to `async_unregister_service` cannot be completed.
+        """
+        assert self.loop is not None
+        run_coro_with_timeout(
+            self.async_unregister_service(info), self.loop, _UNREGISTER_TIME * _REGISTER_BROADCASTS
+        )
+
+    async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
+        """Unregister a service."""
+        info.set_server_if_missing()
+        self.registry.async_remove(info)
+        # If another server uses the same addresses, we do not want to send
+        # goodbye packets for the address records
+
+        assert info.server_key is not None
+        entries = self.registry.async_get_infos_server(info.server_key)
+        broadcast_addresses = not bool(entries)
+        return asyncio.ensure_future(
+            self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses)
+        )
+
+    def generate_unregister_all_services(self) -> Optional[DNSOutgoing]:
+        """Generate a DNSOutgoing goodbye for all services and remove them from the registry."""
+        service_infos = self.registry.async_get_service_infos()
+        if not service_infos:
+            return None
+        out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
+        for info in service_infos:
+            self._add_broadcast_answer(out, info, 0)
+        self.registry.async_remove(service_infos)
+        return out
+
+    async def async_unregister_all_services(self) -> None:
+        """Unregister all registered services.
+
+        Unlike async_register_service and async_unregister_service, this
+        method does not return a future and is always expected to be
+        awaited since it is only called at shutdown.
+        """
+        # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
+        out = self.generate_unregister_all_services()
+        if not out:
+            return
+        for i in range(_REGISTER_BROADCASTS):
+            if i != 0:
+                await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME))
+            self.async_send(out)
+
+    def unregister_all_services(self) -> None:
+        """Unregister all registered services.
+
+        While it is not expected during normal operation,
+        this function may raise EventLoopBlocked if the underlying
+        call to `async_unregister_all_services` cannot be completed.
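+
+        Example (an illustrative sketch of a typical shutdown sequence):
+
+            zc = Zeroconf()
+            ...  # register services, browse, etc.
+            zc.unregister_all_services()
+            zc.close()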
+ """ + assert self.loop is not None + run_coro_with_timeout( + self.async_unregister_all_services(), self.loop, _UNREGISTER_TIME * _REGISTER_BROADCASTS + ) + + async def async_check_service( + self, + info: ServiceInfo, + allow_name_change: bool, + cooperating_responders: bool = False, + strict: bool = True, + ) -> None: + """Checks the network for a unique service name, modifying the + ServiceInfo passed in if it is not unique.""" + instance_name = instance_name_from_service_info(info, strict=strict) + if cooperating_responders: + return + next_instance_number = 2 + next_time = now = current_time_millis() + i = 0 + while i < _REGISTER_BROADCASTS: + # check for a name conflict + while self.cache.current_entry_with_name_and_alias(info.type, info.name): + if not allow_name_change: + raise NonUniqueNameException + + # change the name and look for a conflict + info.name = f'{instance_name}-{next_instance_number}.{info.type}' + next_instance_number += 1 + service_type_name(info.name, strict=strict) + next_time = now + i = 0 + + if now < next_time: + await self.async_wait(next_time - now) + now = current_time_millis() + continue + + self.async_send(self.generate_service_query(info)) + i += 1 + next_time += _CHECK_TIME + + def add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s). + + This function is threadsafe + """ + assert self.loop is not None + self.loop.call_soon_threadsafe(self.record_manager.async_add_listener, listener, question) + + def remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener. + + This function is threadsafe + """ + assert self.loop is not None + self.loop.call_soon_threadsafe(self.record_manager.async_remove_listener, listener) + + def async_add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s). + + This function is not threadsafe and must be called in the eventloop. + """ + self.record_manager.async_add_listener(listener, question) + + def async_remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener. + + This function is not threadsafe and must be called in the eventloop. 
+ """ + self.record_manager.async_remove_listener(listener) + + def send( + self, + out: DNSOutgoing, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + transport: Optional[_WrappedTransport] = None, + ) -> None: + """Sends an outgoing packet threadsafe.""" + assert self.loop is not None + self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope, transport) + + def async_send( + self, + out: DNSOutgoing, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + transport: Optional[_WrappedTransport] = None, + ) -> None: + """Sends an outgoing packet.""" + if self.done: + return + + # If no transport is specified, we send to all the ones + # with the same address family + transports = [transport] if transport else self.engine.senders + log_debug = log.isEnabledFor(logging.DEBUG) + + for packet_num, packet in enumerate(out.packets()): + if len(packet) > _MAX_MSG_ABSOLUTE: + self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) + return + for send_transport in transports: + async_send_with_transport( + log_debug, send_transport, packet, packet_num, out, addr, port, v6_flow_scope + ) + + def _close(self) -> None: + """Set global done and remove all service listeners.""" + if self.done: + return + self.remove_all_service_listeners() + self.done = True + + def _shutdown_threads(self) -> None: + """Shutdown any threads.""" + self.notify_all() + if not self._loop_thread: + return + assert self.loop is not None + shutdown_loop(self.loop) + self._loop_thread.join() + self._loop_thread = None + + def close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries. + + This method is idempotent and irreversible. + """ + assert self.loop is not None + if self.loop.is_running(): + if self.loop == get_running_loop(): + log.warning( + "unregister_all_services skipped as it does blocking i/o; use AsyncZeroconf with asyncio" + ) + else: + self.unregister_all_services() + self._close() + self.engine.close() + self._shutdown_threads() + + async def _async_close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries. + + This method is idempotent and irreversible. + + This call only intended to be used by AsyncZeroconf + + Callers are responsible for unregistering all services + before calling this function + """ + self._close() + await self.engine._async_close() # pylint: disable=protected-access + self._shutdown_threads() + + def __enter__(self) -> 'Zeroconf': + return self + + def __exit__( # pylint: disable=useless-return + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self.close() + return None diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py new file mode 100644 index 00000000..66fb5b86 --- /dev/null +++ b/zeroconf/_dns.py @@ -0,0 +1,559 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. 
+ + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +import socket +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast + +from ._exceptions import AbstractMethodException +from ._utils.net import _is_v6_address +from ._utils.time import current_time_millis +from .const import _CLASS_MASK, _CLASS_UNIQUE, _CLASSES, _TYPE_ANY, _TYPES + +_LEN_BYTE = 1 +_LEN_SHORT = 2 +_LEN_INT = 4 + +_BASE_MAX_SIZE = _LEN_SHORT + _LEN_SHORT + _LEN_INT + _LEN_SHORT # type # class # ttl # length +_NAME_COMPRESSION_MIN_SIZE = _LEN_BYTE * 2 + +_EXPIRE_FULL_TIME_MS = 1000 +_EXPIRE_STALE_TIME_MS = 500 +_RECENT_TIME_MS = 250 + +_float = float +_int = int + +if TYPE_CHECKING: + from ._protocol.incoming import DNSIncoming + from ._protocol.outgoing import DNSOutgoing + + +@enum.unique +class DNSQuestionType(enum.Enum): + """An MDNS question type. + + "QU" - questions requesting unicast responses + "QM" - questions requesting multicast responses + https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 + """ + + QU = 1 + QM = 2 + + +class DNSEntry: + + """A DNS entry""" + + __slots__ = ('key', 'name', 'type', 'class_', 'unique') + + def __init__(self, name: str, type_: int, class_: int) -> None: + self.name = name + self.key = name.lower() + self.type = type_ + self._set_class(class_) + + def _set_class(self, class_: _int) -> None: + self.class_ = class_ & _CLASS_MASK + self.unique = (class_ & _CLASS_UNIQUE) != 0 + + def _dns_entry_matches(self, other) -> bool: # type: ignore[no-untyped-def] + return self.key == other.key and self.type == other.type and self.class_ == other.class_ + + def __eq__(self, other: Any) -> bool: + """Equality test on key (lowercase name), type, and class""" + return isinstance(other, DNSEntry) and self._dns_entry_matches(other) + + @staticmethod + def get_class_(class_: int) -> str: + """Class accessor""" + return _CLASSES.get(class_, f"?({class_})") + + @staticmethod + def get_type(t: int) -> str: + """Type accessor""" + return _TYPES.get(t, f"?({t})") + + def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: + """String representation with additional information""" + return "{}[{},{}{},{}]{}".format( + hdr, + self.get_type(self.type), + self.get_class_(self.class_), + "-unique" if self.unique else "", + self.name, + "=%s" % cast(Any, other) if other is not None else "", + ) + + +class DNSQuestion(DNSEntry): + + """A DNS question entry""" + + __slots__ = ('_hash',) + + def __init__(self, name: str, type_: int, class_: int) -> None: + super().__init__(name, type_, class_) + self._hash = hash((self.key, type_, self.class_)) + + def answered_by(self, rec: 'DNSRecord') -> bool: + """Returns true if the question is answered by the record""" + return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: Any) -> bool: + """Tests equality on dns question.""" + return isinstance(other, DNSQuestion) and self._dns_entry_matches(other) + + @property + def max_size(self) -> int: + """Maximum size of the question 
in the packet."""
+        return len(self.name.encode('utf-8')) + _LEN_BYTE + _LEN_SHORT + _LEN_SHORT  # type # class
+
+    @property
+    def unicast(self) -> bool:
+        """Returns true if the QU bit (not QM) is set.
+
+        unique shares the same mask as the one
+        used for unicast.
+        """
+        return self.unique
+
+    @unicast.setter
+    def unicast(self, value: bool) -> None:
+        """Sets the QU bit (not QM)."""
+        self.unique = value
+
+    def __repr__(self) -> str:
+        """String representation"""
+        return "{}[question,{},{},{}]".format(
+            self.get_type(self.type),
+            "QU" if self.unicast else "QM",
+            self.get_class_(self.class_),
+            self.name,
+        )
+
+
+class DNSRecord(DNSEntry):
+
+    """A DNS record - like a DNS entry, but has a TTL"""
+
+    __slots__ = ('ttl', 'created')
+
+    # TODO: Switch to just int ttl
+    def __init__(
+        self, name: str, type_: int, class_: int, ttl: Union[float, int], created: Optional[float] = None
+    ) -> None:
+        super().__init__(name, type_, class_)
+        self.ttl = ttl
+        self.created = created or current_time_millis()
+
+    def __eq__(self, other: Any) -> bool:  # pylint: disable=no-self-use
+        """Abstract method"""
+        raise AbstractMethodException
+
+    def suppressed_by(self, msg: 'DNSIncoming') -> bool:
+        """Returns true if any answer in a message can suffice for the
+        information held in this record."""
+        answers = msg.answers()
+        for record in answers:
+            if self._suppressed_by_answer(record):
+                return True
+        return False
+
+    def _suppressed_by_answer(self, other) -> bool:  # type: ignore[no-untyped-def]
+        """Returns true if another record has the same name, type and class,
+        and if its TTL is more than half of this record's."""
+        return self == other and other.ttl > (self.ttl / 2)
+
+    def get_expiration_time(self, percent: _int) -> float:
+        """Returns the time at which this record will have expired
+        by a certain percentage."""
+        return self.created + (percent * self.ttl * 10)
+
+    # TODO: Switch to just int here
+    def get_remaining_ttl(self, now: _float) -> Union[int, float]:
+        """Returns the remaining TTL in seconds."""
+        remain = (self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now) / 1000.0
+        return 0 if remain < 0 else remain
+
+    def is_expired(self, now: _float) -> bool:
+        """Returns true if this record has expired."""
+        return self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) <= now
+
+    def is_stale(self, now: _float) -> bool:
+        """Returns true if this record is at least half way expired."""
+        return self.created + (_EXPIRE_STALE_TIME_MS * self.ttl) <= now
+
+    def is_recent(self, now: _float) -> bool:
+        """Returns true if the record has more than one quarter of its TTL remaining."""
+        return self.created + (_RECENT_TIME_MS * self.ttl) > now
+
+    def reset_ttl(self, other) -> None:  # type: ignore[no-untyped-def]
+        """Sets this record's TTL and created time to that of
+        another record."""
+        self.set_created_ttl(other.created, other.ttl)
+
+    def set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None:
+        """Set the created and ttl of a record."""
+        self.created = created
+        self.ttl = ttl
+
+    def write(self, out: 'DNSOutgoing') -> None:  # pylint: disable=no-self-use
+        """Abstract method"""
+        raise AbstractMethodException
+
+    def to_string(self, other: Union[bytes, str]) -> str:
+        """String representation with additional information"""
+        arg = f"{self.ttl}/{int(self.get_remaining_ttl(current_time_millis()))},{cast(Any, other)}"
+        return DNSEntry.entry_to_string(self, "record", arg)
+
+
+class DNSAddress(DNSRecord):
+
+    """A DNS address record"""
+
+    __slots__ = ('_hash', 'address', 'scope_id')
+
+    def
__init__( + self, + name: str, + type_: int, + class_: int, + ttl: int, + address: bytes, + scope_id: Optional[int] = None, + created: Optional[float] = None, + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.address = address + self.scope_id = scope_id + self._hash = hash((self.key, type_, self.class_, address, scope_id)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_string(self.address) + + def __eq__(self, other: Any) -> bool: + """Tests equality on address""" + return isinstance(other, DNSAddress) and self._eq(other) + + def _eq(self, other) -> bool: # type: ignore[no-untyped-def] + return ( + self.address == other.address + and self.scope_id == other.scope_id + and self._dns_entry_matches(other) + ) + + def __hash__(self) -> int: + """Hash to compare like DNSAddresses.""" + return self._hash + + def __repr__(self) -> str: + """String representation""" + try: + return self.to_string( + socket.inet_ntop( + socket.AF_INET6 if _is_v6_address(self.address) else socket.AF_INET, self.address + ) + ) + except (ValueError, OSError): + return self.to_string(str(self.address)) + + +class DNSHinfo(DNSRecord): + + """A DNS host information record""" + + __slots__ = ('_hash', 'cpu', 'os') + + def __init__( + self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.cpu = cpu + self.os = os + self._hash = hash((self.key, type_, self.class_, cpu, os)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_character_string(self.cpu.encode('utf-8')) + out.write_character_string(self.os.encode('utf-8')) + + def __eq__(self, other: Any) -> bool: + """Tests equality on cpu and os.""" + return isinstance(other, DNSHinfo) and self._eq(other) + + def _eq(self, other) -> bool: # type: ignore[no-untyped-def] + """Tests equality on cpu and os.""" + return self.cpu == other.cpu and self.os == other.os and self._dns_entry_matches(other) + + def __hash__(self) -> int: + """Hash to compare like DNSHinfo.""" + return self._hash + + def __repr__(self) -> str: + """String representation""" + return self.to_string(self.cpu + " " + self.os) + + +class DNSPointer(DNSRecord): + + """A DNS pointer record""" + + __slots__ = ('_hash', 'alias', 'alias_key') + + def __init__( + self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.alias = alias + self.alias_key = alias.lower() + self._hash = hash((self.key, type_, self.class_, self.alias_key)) + + @property + def max_size_compressed(self) -> int: + """Maximum size of the record in the packet assuming the name has been compressed.""" + return ( + _BASE_MAX_SIZE + + _NAME_COMPRESSION_MIN_SIZE + + (len(self.alias) - len(self.name)) + + _NAME_COMPRESSION_MIN_SIZE + ) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_name(self.alias) + + def __eq__(self, other: Any) -> bool: + """Tests equality on alias.""" + return isinstance(other, DNSPointer) and self._eq(other) + + def _eq(self, other) -> bool: # type: ignore[no-untyped-def] + """Tests equality on alias.""" + return self.alias_key == other.alias_key and self._dns_entry_matches(other) + + def __hash__(self) -> int: + """Hash to compare like DNSPointer.""" + return self._hash + + def __repr__(self) -> 
str: + """String representation""" + return self.to_string(self.alias) + + +class DNSText(DNSRecord): + + """A DNS text record""" + + __slots__ = ('_hash', 'text') + + def __init__( + self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.text = text + self._hash = hash((self.key, type_, self.class_, text)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_string(self.text) + + def __hash__(self) -> int: + """Hash to compare like DNSText.""" + return self._hash + + def __eq__(self, other: Any) -> bool: + """Tests equality on text.""" + return isinstance(other, DNSText) and self._eq(other) + + def _eq(self, other) -> bool: # type: ignore[no-untyped-def] + """Tests equality on text.""" + return self.text == other.text and self._dns_entry_matches(other) + + def __repr__(self) -> str: + """String representation""" + if len(self.text) > 10: + return self.to_string(self.text[:7]) + "..." + return self.to_string(self.text) + + +class DNSService(DNSRecord): + + """A DNS service record""" + + __slots__ = ('_hash', 'priority', 'weight', 'port', 'server', 'server_key') + + def __init__( + self, + name: str, + type_: int, + class_: int, + ttl: Union[float, int], + priority: int, + weight: int, + port: int, + server: str, + created: Optional[float] = None, + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.priority = priority + self.weight = weight + self.port = port + self.server = server + self.server_key = server.lower() + self._hash = hash((self.key, type_, self.class_, priority, weight, port, self.server_key)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_short(self.priority) + out.write_short(self.weight) + out.write_short(self.port) + out.write_name(self.server) + + def __eq__(self, other: Any) -> bool: + """Tests equality on priority, weight, port and server""" + return isinstance(other, DNSService) and self._eq(other) + + def _eq(self, other) -> bool: # type: ignore[no-untyped-def] + """Tests equality on priority, weight, port and server.""" + return ( + self.priority == other.priority + and self.weight == other.weight + and self.port == other.port + and self.server_key == other.server_key + and self._dns_entry_matches(other) + ) + + def __hash__(self) -> int: + """Hash to compare like DNSService.""" + return self._hash + + def __repr__(self) -> str: + """String representation""" + return self.to_string(f"{self.server}:{self.port}") + + +class DNSNsec(DNSRecord): + + """A DNS NSEC record""" + + __slots__ = ('_hash', 'next_name', 'rdtypes') + + def __init__( + self, + name: str, + type_: int, + class_: int, + ttl: int, + next_name: str, + rdtypes: List[int], + created: Optional[float] = None, + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.next_name = next_name + self.rdtypes = sorted(rdtypes) + self._hash = hash((self.key, type_, self.class_, next_name, *self.rdtypes)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet.""" + bitmap = bytearray(b'\0' * 32) + total_octets = 0 + for rdtype in self.rdtypes: + if rdtype > 255: # mDNS only supports window 0 + raise ValueError(f"rdtype {rdtype} is too large for NSEC") + byte = rdtype // 8 + total_octets = byte + 1 + bitmap[byte] |= 0x80 >> (rdtype % 8) + if total_octets == 0: + # NSEC must have at least one rdtype + # Writing 
an empty bitmap is not allowed
+            raise ValueError("NSEC must have at least one rdtype")
+        out_bytes = bytes(bitmap[0:total_octets])
+        out.write_name(self.next_name)
+        out._write_byte(0)  # Always window 0
+        out._write_byte(len(out_bytes))
+        out.write_string(out_bytes)
+
+    def __eq__(self, other: Any) -> bool:
+        """Tests equality on next_name and rdtypes."""
+        return isinstance(other, DNSNsec) and self._eq(other)
+
+    def _eq(self, other) -> bool:  # type: ignore[no-untyped-def]
+        """Tests equality on next_name and rdtypes."""
+        return (
+            self.next_name == other.next_name
+            and self.rdtypes == other.rdtypes
+            and self._dns_entry_matches(other)
+        )
+
+    def __hash__(self) -> int:
+        """Hash to compare like DNSNsec."""
+        return self._hash
+
+    def __repr__(self) -> str:
+        """String representation"""
+        return self.to_string(
+            self.next_name + "," + "|".join([self.get_type(type_) for type_ in self.rdtypes])
+        )
+
+
+_DNSRecord = DNSRecord
+
+
+class DNSRRSet:
+    """A set of DNS records with a lookup to get the TTL."""
+
+    __slots__ = ('_records', '_lookup')
+
+    def __init__(self, records: List[DNSRecord]) -> None:
+        """Create an RRSet from a list of records."""
+        self._records = records
+        self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None
+
+    @property
+    def lookup(self) -> Dict[DNSRecord, DNSRecord]:
+        """Return the lookup table."""
+        return self._get_lookup()
+
+    def lookup_set(self) -> Set[DNSRecord]:
+        """Return the lookup table as a set."""
+        return set(self._get_lookup())
+
+    def _get_lookup(self) -> Dict[DNSRecord, DNSRecord]:
+        """Return the lookup table, building it if needed."""
+        if self._lookup is None:
+            # Build the hash table so we can look up the record TTL
+            self._lookup = {record: record for record in self._records}
+        return self._lookup
+
+    def suppresses(self, record: _DNSRecord) -> bool:
+        """Returns true if any answer in the rrset can suffice for the
+        information held in this record."""
+        lookup = self._get_lookup()
+        other = lookup.get(record)
+        if other is None:
+            return False
+        return other.ttl > (record.ttl / 2)
diff --git a/zeroconf/_engine.py b/zeroconf/_engine.py
new file mode 100644
index 00000000..9e455003
--- /dev/null
+++ b/zeroconf/_engine.py
@@ -0,0 +1,156 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+    Lesser General Public License for more details.
+ + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import itertools +import socket +import threading +from typing import TYPE_CHECKING, List, Optional, cast + +from ._record_update import RecordUpdate +from ._utils.asyncio import get_running_loop, run_coro_with_timeout +from ._utils.time import current_time_millis +from .const import _CACHE_CLEANUP_INTERVAL + +if TYPE_CHECKING: + from ._core import Zeroconf + + +from ._listener import AsyncListener +from ._transport import _WrappedTransport, make_wrapped_transport + +_CLOSE_TIMEOUT = 3000 # ms + + +class AsyncEngine: + """An engine wraps sockets in the event loop.""" + + __slots__ = ( + 'loop', + 'zc', + 'protocols', + 'readers', + 'senders', + 'running_event', + '_listen_socket', + '_respond_sockets', + '_cleanup_timer', + ) + + def __init__( + self, + zeroconf: 'Zeroconf', + listen_socket: Optional[socket.socket], + respond_sockets: List[socket.socket], + ) -> None: + self.loop: Optional[asyncio.AbstractEventLoop] = None + self.zc = zeroconf + self.protocols: List[AsyncListener] = [] + self.readers: List[_WrappedTransport] = [] + self.senders: List[_WrappedTransport] = [] + self.running_event: Optional[asyncio.Event] = None + self._listen_socket = listen_socket + self._respond_sockets = respond_sockets + self._cleanup_timer: Optional[asyncio.TimerHandle] = None + + def setup(self, loop: asyncio.AbstractEventLoop, loop_thread_ready: Optional[threading.Event]) -> None: + """Set up the instance.""" + self.loop = loop + self.running_event = asyncio.Event() + self.loop.create_task(self._async_setup(loop_thread_ready)) + + async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None: + """Set up the instance.""" + self._async_schedule_next_cache_cleanup() + await self._async_create_endpoints() + assert self.running_event is not None + self.running_event.set() + if loop_thread_ready: + loop_thread_ready.set() + + async def _async_create_endpoints(self) -> None: + """Create endpoints to send and receive.""" + assert self.loop is not None + loop = self.loop + reader_sockets = [] + sender_sockets = [] + if self._listen_socket: + reader_sockets.append(self._listen_socket) + for s in self._respond_sockets: + if s not in reader_sockets: + reader_sockets.append(s) + sender_sockets.append(s) + + for s in reader_sockets: + transport, protocol = await loop.create_datagram_endpoint( + lambda: AsyncListener(self.zc), sock=s # type: ignore[arg-type, return-value] + ) + self.protocols.append(cast(AsyncListener, protocol)) + self.readers.append(make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) + if s in sender_sockets: + self.senders.append(make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) + + def _async_cache_cleanup(self) -> None: + """Periodic cache cleanup.""" + now = current_time_millis() + self.zc.question_history.async_expire(now) + self.zc.record_manager.async_updates( + now, [RecordUpdate(record, record) for record in self.zc.cache.async_expire(now)] + ) + self.zc.record_manager.async_updates_complete(False) + self._async_schedule_next_cache_cleanup() + + def _async_schedule_next_cache_cleanup(self) -> None: + """Schedule the next cache cleanup.""" + loop = self.loop + assert loop is not None + self._cleanup_timer = loop.call_at(loop.time() + _CACHE_CLEANUP_INTERVAL, self._async_cache_cleanup) + + async def 
_async_close(self) -> None: + """Cancel and wait for the cleanup task to finish.""" + self._async_shutdown() + await asyncio.sleep(0) # flush out any call soons + assert self._cleanup_timer is not None + self._cleanup_timer.cancel() + + def _async_shutdown(self) -> None: + """Shutdown transports and sockets.""" + assert self.running_event is not None + self.running_event.clear() + for wrapped_transport in itertools.chain(self.senders, self.readers): + wrapped_transport.transport.close() + + def close(self) -> None: + """Close from sync context. + + While it is not expected during normal operation, + this function may raise EventLoopBlocked if the underlying + call to `_async_close` cannot be completed. + """ + assert self.loop is not None + # Guard against Zeroconf.close() being called from the eventloop + if get_running_loop() == self.loop: + self._async_shutdown() + return + if not self.loop.is_running(): + return + run_coro_with_timeout(self._async_close(), self.loop, _CLOSE_TIMEOUT) diff --git a/zeroconf/_exceptions.py b/zeroconf/_exceptions.py new file mode 100644 index 00000000..f4fcbd55 --- /dev/null +++ b/zeroconf/_exceptions.py @@ -0,0 +1,67 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + + +class Error(Exception): + """Base class for all zeroconf exceptions.""" + + +class IncomingDecodeError(Error): + """Exception when there is invalid data in an incoming packet.""" + + +class NonUniqueNameException(Error): + """Exception when the name is already registered.""" + + +class NamePartTooLongException(Error): + """Exception when the name is too long.""" + + +class AbstractMethodException(Error): + """Exception when a required method is not implemented.""" + + +class BadTypeInNameException(Error): + """Exception when the type in a name is invalid.""" + + +class ServiceNameAlreadyRegistered(Error): + """Exception when a service name is already registered.""" + + +class EventLoopBlocked(Error): + """Exception when the event loop is blocked. + + This exception is never expected to be thrown + during normal operation. It should only happen + when the cpu is maxed out or there is something blocking + the event loop. + """ + + +class NotRunningException(Error): + """Exception when an action is called with a zeroconf instance that is not running. + + The instance may not be running because it was already shutdown + or startup has failed in some unexpected way. 
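+
+    An illustrative sketch of handling this during startup:
+
+        try:
+            await zc.async_wait_for_start()
+        except NotRunningException:
+            ...  # the instance was shut down or never started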
+ """ diff --git a/zeroconf/_handlers/__init__.py b/zeroconf/_handlers/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/zeroconf/_handlers/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/zeroconf/_handlers/answers.py b/zeroconf/_handlers/answers.py new file mode 100644 index 00000000..a2dbd66a --- /dev/null +++ b/zeroconf/_handlers/answers.py @@ -0,0 +1,114 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+    USA
+"""
+
+from operator import attrgetter
+from typing import Dict, List, Set
+
+from .._dns import DNSQuestion, DNSRecord
+from .._protocol.outgoing import DNSOutgoing
+from ..const import _FLAGS_AA, _FLAGS_QR_RESPONSE
+
+_AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]]
+
+int_ = int
+
+
+MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120)
+
+NAME_GETTER = attrgetter('name')
+
+_FLAGS_QR_RESPONSE_AA = _FLAGS_QR_RESPONSE | _FLAGS_AA
+
+float_ = float
+
+
+class QuestionAnswers:
+    """A group of answers to a question."""
+
+    __slots__ = ('ucast', 'mcast_now', 'mcast_aggregate', 'mcast_aggregate_last_second')
+
+    def __init__(
+        self,
+        ucast: _AnswerWithAdditionalsType,
+        mcast_now: _AnswerWithAdditionalsType,
+        mcast_aggregate: _AnswerWithAdditionalsType,
+        mcast_aggregate_last_second: _AnswerWithAdditionalsType,
+    ) -> None:
+        """Initialize a QuestionAnswers."""
+        self.ucast = ucast
+        self.mcast_now = mcast_now
+        self.mcast_aggregate = mcast_aggregate
+        self.mcast_aggregate_last_second = mcast_aggregate_last_second
+
+    def __repr__(self) -> str:
+        """Return a string representation of this QuestionAnswers."""
+        return (
+            f'QuestionAnswers(ucast={self.ucast}, mcast_now={self.mcast_now}, '
+            f'mcast_aggregate={self.mcast_aggregate}, '
+            f'mcast_aggregate_last_second={self.mcast_aggregate_last_second})'
+        )
+
+
+class AnswerGroup:
+    """A group of answers scheduled to be sent at the same time."""
+
+    __slots__ = ('send_after', 'send_before', 'answers')
+
+    def __init__(self, send_after: float_, send_before: float_, answers: _AnswerWithAdditionalsType) -> None:
+        self.send_after = send_after  # Must be sent after this time
+        self.send_before = send_before  # Must be sent before this time
+        self.answers = answers
+
+
+def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing:
+    """Add answers and additionals to a DNSOutgoing."""
+    out = DNSOutgoing(_FLAGS_QR_RESPONSE_AA, True)
+    _add_answers_additionals(out, answers)
+    return out
+
+
+def construct_outgoing_unicast_answers(
+    answers: _AnswerWithAdditionalsType, ucast_source: bool, questions: List[DNSQuestion], id_: int_
+) -> DNSOutgoing:
+    """Add answers and additionals to a DNSOutgoing."""
+    out = DNSOutgoing(_FLAGS_QR_RESPONSE_AA, False, id_)
+    # Adding the questions back is legacy unicast behavior, applied
+    # when the source is a legacy unicast query
+    if ucast_source:
+        for question in questions:
+            out.add_question(question)
+    _add_answers_additionals(out, answers)
+    return out
+
+
+def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsType) -> None:
+    # Find additionals and suppress any additionals that are already in answers
+    sending: Set[DNSRecord] = set(answers)
+    # Answers are sorted to group names together to increase the chance
+    # that similar names will end up in the same packet and can reduce the
+    # overall size of the outgoing response via name compression
+    for answer in sorted(answers, key=NAME_GETTER):
+        out.add_answer_at_time(answer, 0)
+        additionals = answers[answer]
+        for additional in additionals:
+            if additional not in sending:
+                out.add_additional_answer(additional)
+                sending.add(additional)
diff --git a/zeroconf/_handlers/multicast_outgoing_queue.py b/zeroconf/_handlers/multicast_outgoing_queue.py
new file mode 100644
index 00000000..23288d18
--- /dev/null
+++ b/zeroconf/_handlers/multicast_outgoing_queue.py
@@ -0,0 +1,122 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+    Lesser General Public License for more details.
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+    USA
+"""
+
+import random
+from collections import deque
+from typing import TYPE_CHECKING
+
+from .._utils.time import current_time_millis, millis_to_seconds
+from .answers import (
+    MULTICAST_DELAY_RANDOM_INTERVAL,
+    AnswerGroup,
+    _AnswerWithAdditionalsType,
+    construct_outgoing_multicast_answers,
+)
+
+RAND_INT = random.randint
+
+if TYPE_CHECKING:
+    from .._core import Zeroconf
+
+_float = float
+_int = int
+
+
+class MulticastOutgoingQueue:
+    """An outgoing queue used to aggregate multicast responses."""
+
+    __slots__ = (
+        "zc",
+        "queue",
+        "_multicast_delay_random_min",
+        "_multicast_delay_random_max",
+        "_additional_delay",
+        "_aggregation_delay",
+    )
+
+    def __init__(self, zeroconf: 'Zeroconf', additional_delay: _int, max_aggregation_delay: _int) -> None:
+        self.zc = zeroconf
+        self.queue: deque[AnswerGroup] = deque()
+        # The additional delay is used to help protect the network
+        # against excessive packet flooding
+        # https://datatracker.ietf.org/doc/html/rfc6762#section-14
+        self._multicast_delay_random_min = MULTICAST_DELAY_RANDOM_INTERVAL[0]
+        self._multicast_delay_random_max = MULTICAST_DELAY_RANDOM_INTERVAL[1]
+        self._additional_delay = additional_delay
+        self._aggregation_delay = max_aggregation_delay
+
+    def async_add(self, now: _float, answers: _AnswerWithAdditionalsType) -> None:
+        """Add a group of answers with additionals to the outgoing queue."""
+        loop = self.zc.loop
+        if TYPE_CHECKING:
+            assert loop is not None
+        random_int = RAND_INT(self._multicast_delay_random_min, self._multicast_delay_random_max)
+        random_delay = random_int + self._additional_delay
+        send_after = now + random_delay
+        send_before = now + self._aggregation_delay + self._additional_delay
+        if len(self.queue):
+            # If we calculate a random delay for the send after time
+            # that is less than the last group scheduled to go out,
+            # we instead add the answers to the last group as this
+            # allows aggregating additional responses
+            last_group = self.queue[-1]
+            if send_after <= last_group.send_after:
+                last_group.answers.update(answers)
+                return
+        else:
+            loop.call_at(loop.time() + millis_to_seconds(random_delay), self.async_ready)
+        self.queue.append(AnswerGroup(send_after, send_before, answers))
+
+    def _remove_answers_from_queue(self, answers: _AnswerWithAdditionalsType) -> None:
+        """Remove a set of answers from the outgoing queue."""
+        for pending in self.queue:
+            for record in answers:
+                pending.answers.pop(record, None)
+
+    def async_ready(self) -> None:
+        """Process anything in the queue that is ready."""
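+        # An overview of the steps below (descriptive summary only):
+        # 1. If more than one group is queued and the head of the queue
+        #    has not reached its send_before deadline, wait so that more
+        #    answers can be aggregated into a single response.
+        # 2. Pop every group whose send_after time has passed and merge
+        #    its answers into one response.
+        # 3. If groups remain, schedule another run for the next
+        #    send_after, then multicast the merged answers.
+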
zc = self.zc + loop = zc.loop + if TYPE_CHECKING: + assert loop is not None + now = current_time_millis() + + if len(self.queue) > 1 and self.queue[0].send_before > now: + # There is more than one answer in the queue, + # delay until we have to send it (first answer group reaches send_before) + loop.call_at(loop.time() + millis_to_seconds(self.queue[0].send_before - now), self.async_ready) + return + + answers: _AnswerWithAdditionalsType = {} + # Add all groups that can be sent now + while len(self.queue) and self.queue[0].send_after <= now: + answers.update(self.queue.popleft().answers) + + if len(self.queue): + # If there are still groups in the queue that are not ready to send + # be sure we schedule them to go out later + loop.call_at(loop.time() + millis_to_seconds(self.queue[0].send_after - now), self.async_ready) + + if answers: # pragma: no branch + # If we have the same answer scheduled to go out, remove them + self._remove_answers_from_queue(answers) + zc.async_send(construct_outgoing_multicast_answers(answers)) diff --git a/zeroconf/_handlers/query_handler.py b/zeroconf/_handlers/query_handler.py new file mode 100644 index 00000000..ba9c9e31 --- /dev/null +++ b/zeroconf/_handlers/query_handler.py @@ -0,0 +1,437 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+ + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast + +from .._cache import DNSCache, _UniqueRecordsType +from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet +from .._protocol.incoming import DNSIncoming +from .._services.info import ServiceInfo +from .._transport import _WrappedTransport +from .._utils.net import IPVersion +from ..const import ( + _ADDRESS_RECORD_TYPES, + _CLASS_IN, + _DNS_OTHER_TTL, + _MDNS_PORT, + _ONE_SECOND, + _SERVICE_TYPE_ENUMERATION_NAME, + _TYPE_A, + _TYPE_AAAA, + _TYPE_ANY, + _TYPE_NSEC, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) +from .answers import ( + QuestionAnswers, + _AnswerWithAdditionalsType, + construct_outgoing_multicast_answers, + construct_outgoing_unicast_answers, +) + +_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES} + +_EMPTY_SERVICES_LIST: List[ServiceInfo] = [] +_EMPTY_TYPES_LIST: List[str] = [] + +_IPVersion_ALL = IPVersion.All + +_int = int +_str = str + +_ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION = 0 +_ANSWER_STRATEGY_POINTER = 1 +_ANSWER_STRATEGY_ADDRESS = 2 +_ANSWER_STRATEGY_SERVICE = 3 +_ANSWER_STRATEGY_TEXT = 4 + +if TYPE_CHECKING: + from .._core import Zeroconf + + +class _AnswerStrategy: + + __slots__ = ("question", "strategy_type", "types", "services") + + def __init__( + self, + question: DNSQuestion, + strategy_type: _int, + types: List[str], + services: List[ServiceInfo], + ) -> None: + """Create an answer strategy.""" + self.question = question + self.strategy_type = strategy_type + self.types = types + self.services = services + + +class _QueryResponse: + """A pair for unicast and multicast DNSOutgoing responses.""" + + __slots__ = ( + "_is_probe", + "_questions", + "_now", + "_cache", + "_additionals", + "_ucast", + "_mcast_now", + "_mcast_aggregate", + "_mcast_aggregate_last_second", + ) + + def __init__(self, cache: DNSCache, questions: List[DNSQuestion], is_probe: bool, now: float) -> None: + """Build a query response.""" + self._is_probe = is_probe + self._questions = questions + self._now = now + self._cache = cache + self._additionals: _AnswerWithAdditionalsType = {} + self._ucast: Set[DNSRecord] = set() + self._mcast_now: Set[DNSRecord] = set() + self._mcast_aggregate: Set[DNSRecord] = set() + self._mcast_aggregate_last_second: Set[DNSRecord] = set() + + def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None: + """Generate a response to a multicast QU query.""" + for record, additionals in answers.items(): + self._additionals[record] = additionals + if self._is_probe: + self._ucast.add(record) + if not self._has_mcast_within_one_quarter_ttl(record): + self._mcast_now.add(record) + elif not self._is_probe: + self._ucast.add(record) + + def add_ucast_question_response(self, answers: _AnswerWithAdditionalsType) -> None: + """Generate a response to a unicast query.""" + self._additionals.update(answers) + self._ucast.update(answers) + + def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> None: + """Generate a response to a multicast query.""" + self._additionals.update(answers) + for answer in answers: + if self._is_probe: + self._mcast_now.add(answer) + continue + + if self._has_mcast_record_in_last_second(answer): + self._mcast_aggregate_last_second.add(answer) + continue + + if 
len(self._questions) == 1: + question = self._questions[0] + if question.type in _RESPOND_IMMEDIATE_TYPES: + self._mcast_now.add(answer) + continue + + self._mcast_aggregate.add(answer) + + def answers( + self, + ) -> QuestionAnswers: + """Return answer sets that will be queued.""" + ucast = {r: self._additionals[r] for r in self._ucast} + mcast_now = {r: self._additionals[r] for r in self._mcast_now} + mcast_aggregate = {r: self._additionals[r] for r in self._mcast_aggregate} + mcast_aggregate_last_second = {r: self._additionals[r] for r in self._mcast_aggregate_last_second} + return QuestionAnswers(ucast, mcast_now, mcast_aggregate, mcast_aggregate_last_second) + + def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: + """Check to see if a record has been mcasted recently. + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 + When receiving a question with the unicast-response bit set, a + responder SHOULD usually respond with a unicast packet directed back + to the querier. However, if the responder has not multicast that + record recently (within one quarter of its TTL), then the responder + SHOULD instead multicast the response so as to keep all the peer + caches up to date + """ + if TYPE_CHECKING: + record = cast(_UniqueRecordsType, record) + maybe_entry = self._cache.async_get_unique(record) + return bool(maybe_entry is not None and maybe_entry.is_recent(self._now)) + + def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: + """Check if an answer was seen in the last second. + Protect the network against excessive packet flooding + https://datatracker.ietf.org/doc/html/rfc6762#section-14 + """ + if TYPE_CHECKING: + record = cast(_UniqueRecordsType, record) + maybe_entry = self._cache.async_get_unique(record) + return bool(maybe_entry is not None and self._now - maybe_entry.created < _ONE_SECOND) + + +class QueryHandler: + """Query the ServiceRegistry.""" + + __slots__ = ("zc", "registry", "cache", "question_history", "out_queue", "out_delay_queue") + + def __init__(self, zc: 'Zeroconf') -> None: + """Init the query handler.""" + self.zc = zc + self.registry = zc.registry + self.cache = zc.cache + self.question_history = zc.question_history + self.out_queue = zc.out_queue + self.out_delay_queue = zc.out_delay_queue + + def _add_service_type_enumeration_query_answers( + self, types: List[str], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + ) -> None: + """Provide an answer to a service type enumeration query. + + https://datatracker.ietf.org/doc/html/rfc6763#section-9 + """ + for stype in types: + dns_pointer = DNSPointer( + _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, 0.0 + ) + if not known_answers.suppresses(dns_pointer): + answer_set[dns_pointer] = set() + + def _add_pointer_answers( + self, services: List[ServiceInfo], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + ) -> None: + """Answer PTR/ANY question.""" + for service in services: + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.1. 
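+            # Per section 12.1, a PTR answer should be accompanied by
+            # the SRV, TXT, and address records the querier will need
+            # next; those are attached as additionals below.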
+ dns_pointer = service._dns_pointer(None) + if known_answers.suppresses(dns_pointer): + continue + answer_set[dns_pointer] = { + service._dns_service(None), + service._dns_text(None), + *service._get_address_and_nsec_records(None), + } + + def _add_address_answers( + self, + services: List[ServiceInfo], + answer_set: _AnswerWithAdditionalsType, + known_answers: DNSRRSet, + type_: _int, + ) -> None: + """Answer A/AAAA/ANY question.""" + for service in services: + answers: List[DNSAddress] = [] + additionals: Set[DNSRecord] = set() + seen_types: Set[int] = set() + for dns_address in service._dns_addresses(None, _IPVersion_ALL): + seen_types.add(dns_address.type) + if dns_address.type != type_: + additionals.add(dns_address) + elif not known_answers.suppresses(dns_address): + answers.append(dns_address) + missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + if answers: + if missing_types: + assert service.server is not None, "Service server must be set for NSEC record." + additionals.add(service._dns_nsec(list(missing_types), None)) + for answer in answers: + answer_set[answer] = additionals + elif type_ in missing_types: + assert service.server is not None, "Service server must be set for NSEC record." + answer_set[service._dns_nsec(list(missing_types), None)] = set() + + def _answer_question( + self, + question: DNSQuestion, + strategy_type: _int, + types: List[str], + services: List[ServiceInfo], + known_answers: DNSRRSet, + ) -> _AnswerWithAdditionalsType: + """Answer a question.""" + answer_set: _AnswerWithAdditionalsType = {} + + if strategy_type == _ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION: + self._add_service_type_enumeration_query_answers(types, answer_set, known_answers) + elif strategy_type == _ANSWER_STRATEGY_POINTER: + self._add_pointer_answers(services, answer_set, known_answers) + elif strategy_type == _ANSWER_STRATEGY_ADDRESS: + self._add_address_answers(services, answer_set, known_answers, question.type) + elif strategy_type == _ANSWER_STRATEGY_SERVICE: + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.2. + service = services[0] + dns_service = service._dns_service(None) + if not known_answers.suppresses(dns_service): + answer_set[dns_service] = service._get_address_and_nsec_records(None) + elif strategy_type == _ANSWER_STRATEGY_TEXT: # pragma: no branch + service = services[0] + dns_text = service._dns_text(None) + if not known_answers.suppresses(dns_text): + answer_set[dns_text] = set() + + return answer_set + + def async_response( # pylint: disable=unused-argument + self, msgs: List[DNSIncoming], ucast_source: bool + ) -> Optional[QuestionAnswers]: + """Deal with incoming query packets. Provides a response if possible. + + This function must be run in the event loop as it is not + threadsafe. + """ + strategies: List[_AnswerStrategy] = [] + for msg in msgs: + for question in msg._questions: + strategies.extend(self._get_answer_strategies(question)) + + if not strategies: + # We have no way to answer the question because we have + # nothing in the ServiceRegistry that matches or we do not + # understand the question. 
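+            # (Returning None here tells handle_assembled_query that no
+            # response packets need to be generated for this query.)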
+            return None
+
+        is_probe = False
+        msg = msgs[0]
+        questions = msg._questions
+        # Only decode known answers if we are not a probe and we have
+        # at least one answer strategy
+        answers: List[DNSRecord] = []
+        for msg in msgs:
+            if msg.is_probe():
+                is_probe = True
+            else:
+                answers.extend(msg.answers())
+
+        query_res = _QueryResponse(self.cache, questions, is_probe, msg.now)
+        known_answers = DNSRRSet(answers)
+        known_answers_set: Optional[Set[DNSRecord]] = None
+        now = msg.now
+        for strategy in strategies:
+            question = strategy.question
+            is_unicast = question.unique  # unique and unicast are the same flag
+            if not is_unicast:
+                if known_answers_set is None:  # pragma: no branch
+                    known_answers_set = known_answers.lookup_set()
+                self.question_history.add_question_at_time(question, now, known_answers_set)
+            answer_set = self._answer_question(
+                question, strategy.strategy_type, strategy.types, strategy.services, known_answers
+            )
+            if not ucast_source and is_unicast:
+                query_res.add_qu_question_response(answer_set)
+                continue
+            if ucast_source:
+                query_res.add_ucast_question_response(answer_set)
+            # We always multicast as well even if it's a unicast
+            # source as long as we haven't done it recently (75% of ttl)
+            query_res.add_mcast_question_response(answer_set)
+
+        return query_res.answers()
+
+    def _get_answer_strategies(
+        self,
+        question: DNSQuestion,
+    ) -> List[_AnswerStrategy]:
+        """Collect strategies to answer a question."""
+        name = question.name
+        question_lower_name = name.lower()
+        type_ = question.type
+        strategies: List[_AnswerStrategy] = []
+
+        if type_ == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME:
+            types = self.registry.async_get_types()
+            if types:
+                strategies.append(
+                    _AnswerStrategy(
+                        question, _ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION, types, _EMPTY_SERVICES_LIST
+                    )
+                )
+            return strategies
+
+        if type_ in (_TYPE_PTR, _TYPE_ANY):
+            services = self.registry.async_get_infos_type(question_lower_name)
+            if services:
+                strategies.append(
+                    _AnswerStrategy(question, _ANSWER_STRATEGY_POINTER, _EMPTY_TYPES_LIST, services)
+                )
+
+        if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
+            services = self.registry.async_get_infos_server(question_lower_name)
+            if services:
+                strategies.append(
+                    _AnswerStrategy(question, _ANSWER_STRATEGY_ADDRESS, _EMPTY_TYPES_LIST, services)
+                )
+
+        if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
+            service = self.registry.async_get_info_name(question_lower_name)
+            if service is not None:
+                if type_ in (_TYPE_SRV, _TYPE_ANY):
+                    strategies.append(
+                        _AnswerStrategy(question, _ANSWER_STRATEGY_SERVICE, _EMPTY_TYPES_LIST, [service])
+                    )
+                if type_ in (_TYPE_TXT, _TYPE_ANY):
+                    strategies.append(
+                        _AnswerStrategy(question, _ANSWER_STRATEGY_TEXT, _EMPTY_TYPES_LIST, [service])
+                    )
+
+        return strategies
+
+    def handle_assembled_query(
+        self,
+        packets: List[DNSIncoming],
+        addr: _str,
+        port: _int,
+        transport: _WrappedTransport,
+        v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
+    ) -> None:
+        """Respond to a (re)assembled query.
+
+        If the protocol received packets with the TC bit set, it will
+        wait a bit for the rest of the packets and only call
+        handle_assembled_query once it has a complete set of packets
+        or the timer expires. If the TC bit is not set, a single
+        packet will be in packets.
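+
+        A summary of the dispatch below: unicast answers go straight
+        back on the transport the query arrived on, immediate multicast
+        answers are sent right away, and aggregated multicast answers
+        are handed to the outgoing queues (answers already multicast
+        within the last second are delayed by at least one second per
+        RFC 6762 section 14).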
+ """ + first_packet = packets[0] + ucast_source = port != _MDNS_PORT + question_answers = self.async_response(packets, ucast_source) + if question_answers is None: + return + if question_answers.ucast: + questions = first_packet._questions + id_ = first_packet.id + out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_) + # When sending unicast, only send back the reply + # via the same socket that it was recieved from + # as we know its reachable from that socket + self.zc.async_send(out, addr, port, v6_flow_scope, transport) + if question_answers.mcast_now: + self.zc.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now)) + if question_answers.mcast_aggregate: + self.out_queue.async_add(first_packet.now, question_answers.mcast_aggregate) + if question_answers.mcast_aggregate_last_second: + # https://datatracker.ietf.org/doc/html/rfc6762#section-14 + # If we broadcast it in the last second, we have to delay + # at least a second before we send it again + self.out_delay_queue.async_add(first_packet.now, question_answers.mcast_aggregate_last_second) diff --git a/zeroconf/_handlers/record_manager.py b/zeroconf/_handlers/record_manager.py new file mode 100644 index 00000000..0a0f6c54 --- /dev/null +++ b/zeroconf/_handlers/record_manager.py @@ -0,0 +1,215 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast + +from .._cache import _UniqueRecordsType +from .._dns import DNSQuestion, DNSRecord +from .._logger import log +from .._protocol.incoming import DNSIncoming +from .._record_update import RecordUpdate +from .._updates import RecordUpdateListener +from .._utils.time import current_time_millis +from ..const import _ADDRESS_RECORD_TYPES, _DNS_PTR_MIN_TTL, _TYPE_PTR + +if TYPE_CHECKING: + from .._core import Zeroconf + +_float = float + + +class RecordManager: + """Process records into the cache and notify listeners.""" + + __slots__ = ("zc", "cache", "listeners") + + def __init__(self, zeroconf: 'Zeroconf') -> None: + """Init the record manager.""" + self.zc = zeroconf + self.cache = zeroconf.cache + self.listeners: Set[RecordUpdateListener] = set() + + def async_updates(self, now: _float, records: List[RecordUpdate]) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called before the cache is updated. + + This method will be run in the event loop. 
+ """ + for listener in self.listeners: + listener.async_update_records(self.zc, now, records) + + def async_updates_complete(self, notify: bool) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called after the cache is updated. + + This method will be run in the event loop. + """ + for listener in self.listeners: + listener.async_update_records_complete() + if notify: + self.zc.async_notify_all() + + def async_updates_from_response(self, msg: DNSIncoming) -> None: + """Deal with incoming response packets. All answers + are held in the cache, and listeners are notified. + + This function must be run in the event loop as it is not + threadsafe. + """ + updates: List[RecordUpdate] = [] + address_adds: List[DNSRecord] = [] + other_adds: List[DNSRecord] = [] + removes: Set[DNSRecord] = set() + now = msg.now + unique_types: Set[Tuple[str, int, int]] = set() + cache = self.cache + answers = msg.answers() + + for record in answers: + # Protect zeroconf from records that can cause denial of service. + # + # We enforce a minimum TTL for PTR records to avoid + # ServiceBrowsers generating excessive queries refresh queries. + # Apple uses a 15s minimum TTL, however we do not have the same + # level of rate limit and safe guards so we use 1/4 of the recommended value. + record_type = record.type + record_ttl = record.ttl + if record_ttl and record_type == _TYPE_PTR and record_ttl < _DNS_PTR_MIN_TTL: + log.debug( + "Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.", + record, + _DNS_PTR_MIN_TTL, + ) + record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL) + + if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 + unique_types.add((record.name, record_type, record.class_)) + + if TYPE_CHECKING: + record = cast(_UniqueRecordsType, record) + + maybe_entry = cache.async_get_unique(record) + if not record.is_expired(now): + if maybe_entry is not None: + maybe_entry.reset_ttl(record) + else: + if record_type in _ADDRESS_RECORD_TYPES: + address_adds.append(record) + else: + other_adds.append(record) + updates.append(RecordUpdate(record, maybe_entry)) + # This is likely a goodbye since the record is + # expired and exists in the cache + elif maybe_entry is not None: + updates.append(RecordUpdate(record, maybe_entry)) + removes.add(record) + + if unique_types: + cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now) + + if updates: + self.async_updates(now, updates) + # The cache adds must be processed AFTER we trigger + # the updates since we compare existing data + # with the new data and updating the cache + # ahead of update_record will cause listeners + # to miss changes + # + # We must process address adds before non-addresses + # otherwise a fetch of ServiceInfo may miss an address + # because it thinks the cache is complete + # + # The cache is processed under the context manager to ensure + # that any ServiceBrowser that is going to call + # zc.get_service_info will see the cached value + # but ONLY after all the record updates have been + # processsed. + new = False + if other_adds or address_adds: + new = cache.async_add_records(address_adds) + if cache.async_add_records(other_adds): + new = True + # Removes are processed last since + # ServiceInfo could generate an un-needed query + # because the data was not yet populated. 
+        if removes:
+            cache.async_remove_records(removes)
+        if updates:
+            self.async_updates_complete(new)
+
+    def async_add_listener(
+        self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
+    ) -> None:
+        """Adds a listener for a given question.  The listener will have
+        its update_record method called when information is available to
+        answer the question(s).
+
+        This function is not thread-safe and must be called in the event loop.
+        """
+        if not isinstance(listener, RecordUpdateListener):
+            log.error(  # type: ignore[unreachable]
+                "listeners passed to async_add_listener must inherit from RecordUpdateListener;"
+                " In the future this will fail"
+            )
+
+        self.listeners.add(listener)
+
+        if question is None:
+            return
+
+        questions = [question] if isinstance(question, DNSQuestion) else question
+        self._async_update_matching_records(listener, questions)
+
+    def _async_update_matching_records(
+        self, listener: RecordUpdateListener, questions: List[DNSQuestion]
+    ) -> None:
+        """Calls back any existing entries in the cache that answer the question.
+
+        This function must be run from the event loop.
+        """
+        now = current_time_millis()
+        records: List[RecordUpdate] = [
+            RecordUpdate(record, None)
+            for question in questions
+            for record in self.cache.async_entries_with_name(question.name)
+            if not record.is_expired(now) and question.answered_by(record)
+        ]
+        if not records:
+            return
+        listener.async_update_records(self.zc, now, records)
+        listener.async_update_records_complete()
+        self.zc.async_notify_all()
+
+    def async_remove_listener(self, listener: RecordUpdateListener) -> None:
+        """Removes a listener.
+
+        This function is not threadsafe and must be called in the event loop.
+        """
+        try:
+            self.listeners.remove(listener)
+            self.zc.async_notify_all()
+        except KeyError as e:  # set.remove() raises KeyError, not ValueError
+            log.exception('Failed to remove listener: %r', e)
diff --git a/zeroconf/_history.py b/zeroconf/_history.py
new file mode 100644
index 00000000..db6a394d
--- /dev/null
+++ b/zeroconf/_history.py
@@ -0,0 +1,79 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+    Lesser General Public License for more details.
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301
+    USA
+"""
+
+from typing import Dict, List, Set, Tuple
+
+from ._dns import DNSQuestion, DNSRecord
+from .const import _DUPLICATE_QUESTION_INTERVAL
+
+# The QuestionHistory is used to implement Duplicate Question Suppression
+# https://datatracker.ietf.org/doc/html/rfc6762#section-7.3
+
+_float = float
+
+
+class QuestionHistory:
+    """Remember questions and known answers."""
+
+    def __init__(self) -> None:
+        """Init a new QuestionHistory."""
+        self._history: Dict[DNSQuestion, Tuple[float, Set[DNSRecord]]] = {}
+
+    def add_question_at_time(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> None:
+        """Remember a question with known answers."""
+        self._history[question] = (now, known_answers)
+
+    def suppresses(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> bool:
+        """Check to see if a question should be suppressed.
+
+        https://datatracker.ietf.org/doc/html/rfc6762#section-7.3
+        When multiple queriers on the network are querying
+        for the same resource records, there is no need for them to all be
+        repeatedly asking the same question.
+        """
+        previous_question = self._history.get(question)
+        # There was no previous question in the history
+        if not previous_question:
+            return False
+        than, previous_known_answers = previous_question
+        # The last question is older than the duplicate question interval
+        if now - than > _DUPLICATE_QUESTION_INTERVAL:
+            return False
+        # The last question had known answers that we do not,
+        # so we still have to ask
+        if previous_known_answers - known_answers:
+            return False
+        return True
+
+    def async_expire(self, now: _float) -> None:
+        """Expire the history of old questions."""
+        removes: List[DNSQuestion] = []
+        for question, now_known_answers in self._history.items():
+            than, _ = now_known_answers
+            if now - than > _DUPLICATE_QUESTION_INTERVAL:
+                removes.append(question)
+        for question in removes:
+            del self._history[question]
+
+    def clear(self) -> None:
+        """Clear the history."""
+        self._history.clear()
diff --git a/zeroconf/_listener.py b/zeroconf/_listener.py
new file mode 100644
index 00000000..0f8a8cac
--- /dev/null
+++ b/zeroconf/_listener.py
@@ -0,0 +1,250 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+    Lesser General Public License for more details.
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301
+    USA
+"""
+
+import asyncio
+import logging
+import random
+from functools import partial
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+
+from ._logger import QuietLogger, log
+from ._protocol.incoming import DNSIncoming
+from ._transport import _WrappedTransport, make_wrapped_transport
+from ._utils.time import current_time_millis, millis_to_seconds
+from .const import _DUPLICATE_PACKET_SUPPRESSION_INTERVAL, _MAX_MSG_ABSOLUTE
+
+if TYPE_CHECKING:
+    from ._core import Zeroconf
+
+_TC_DELAY_RANDOM_INTERVAL = (400, 500)
+
+
+_bytes = bytes
+_str = str
+_int = int
+_float = float
+
+DEBUG_ENABLED = partial(log.isEnabledFor, logging.DEBUG)
+
+
+class AsyncListener:
+
+    """A Listener is used by this module to listen on the multicast
+    group to which DNS messages are sent, allowing the implementation
+    to cache information as it arrives.
+
+    It is registered as an asyncio datagram protocol, so
+    datagram_received() is called whenever one of the sockets has a
+    packet to read."""
+
+    __slots__ = (
+        'zc',
+        '_registry',
+        '_record_manager',
+        "_query_handler",
+        'data',
+        'last_time',
+        'last_message',
+        'transport',
+        'sock_description',
+        '_deferred',
+        '_timers',
+    )
+
+    def __init__(self, zc: 'Zeroconf') -> None:
+        self.zc = zc
+        self._registry = zc.registry
+        self._record_manager = zc.record_manager
+        self._query_handler = zc.query_handler
+        self.data: Optional[bytes] = None
+        self.last_time: float = 0
+        self.last_message: Optional[DNSIncoming] = None
+        self.transport: Optional[_WrappedTransport] = None
+        self.sock_description: Optional[str] = None
+        self._deferred: Dict[str, List[DNSIncoming]] = {}
+        self._timers: Dict[str, asyncio.TimerHandle] = {}
+        super().__init__()
+
+    def datagram_received(
+        self, data: _bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]]
+    ) -> None:
+        data_len = len(data)
+        debug = DEBUG_ENABLED()
+
+        if data_len > _MAX_MSG_ABSOLUTE:
+            # Guard against oversized packets to ensure bad implementations cannot overwhelm
+            # the system.
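+            # (Packets larger than _MAX_MSG_ABSOLUTE are dropped without
+            # any attempt to parse them.)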
+ if debug: + log.debug( + "Discarding incoming packet with length %s, which is larger " + "than the absolute maximum size of %s", + data_len, + _MAX_MSG_ABSOLUTE, + ) + return + now = current_time_millis() + self._process_datagram_at_time(debug, data_len, now, data, addrs) + + def _process_datagram_at_time( + self, + debug: bool, + data_len: _int, + now: _float, + data: _bytes, + addrs: Union[Tuple[str, int], Tuple[str, int, int, int]], + ) -> None: + if ( + self.data == data + and (now - _DUPLICATE_PACKET_SUPPRESSION_INTERVAL) < self.last_time + and self.last_message is not None + and not self.last_message.has_qu_question() + ): + # Guard against duplicate packets + if debug: + log.debug( + 'Ignoring duplicate message with no unicast questions received from %s [socket %s] (%d bytes) as [%r]', + addrs, + self.sock_description, + data_len, + data, + ) + return + + if len(addrs) == 2: + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () + # https://github.com/python/mypy/issues/1178 + addr, port = addrs # type: ignore + addr_port = addrs + if TYPE_CHECKING: + addr_port = cast(Tuple[str, int], addr_port) + scope = None + else: + # https://github.com/python/mypy/issues/1178 + addr, port, flow, scope = addrs # type: ignore + if debug: # pragma: no branch + log.debug('IPv6 scope_id %d associated to the receiving interface', scope) + v6_flow_scope = (flow, scope) + addr_port = (addr, port) + + msg = DNSIncoming(data, addr_port, scope, now) + self.data = data + self.last_time = now + self.last_message = msg + if msg.valid is True: + if debug: + log.debug( + 'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]', + addr, + port, + self.sock_description, + msg, + data_len, + data, + ) + else: + if debug: + log.debug( + 'Received from %r:%r [socket %s]: (%d bytes) [%r]', + addr, + port, + self.sock_description, + data_len, + data, + ) + return + + if not msg.is_query(): + self._record_manager.async_updates_from_response(msg) + return + + if not self._registry.has_entries: + # If the registry is empty, we have no answers to give. + return + + if TYPE_CHECKING: + assert self.transport is not None + self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope) + + def handle_query_or_defer( + self, + msg: DNSIncoming, + addr: _str, + port: _int, + transport: _WrappedTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]], + ) -> None: + """Deal with incoming query packets. 
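Provides a response if possible.
+
+        Truncated (TC-bit) queries are deferred per source address and
+        answered after a random 400-500 ms delay (RFC 6762 section 7.2)
+        so that the remaining packets of the set can arrive first.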
+        """
+        if not msg.truncated:
+            self._respond_query(msg, addr, port, transport, v6_flow_scope)
+            return
+
+        deferred = self._deferred.setdefault(addr, [])
+        # If we get the same packet we ignore it
+        for incoming in reversed(deferred):
+            if incoming.data == msg.data:
+                return
+        deferred.append(msg)
+        delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL))
+        loop = self.zc.loop
+        assert loop is not None
+        self._cancel_any_timers_for_addr(addr)
+        self._timers[addr] = loop.call_at(
+            loop.time() + delay, self._respond_query, None, addr, port, transport, v6_flow_scope
+        )
+
+    def _cancel_any_timers_for_addr(self, addr: _str) -> None:
+        """Cancel any future truncated packet timers for the address."""
+        if addr in self._timers:
+            self._timers.pop(addr).cancel()
+
+    def _respond_query(
+        self,
+        msg: Optional[DNSIncoming],
+        addr: _str,
+        port: _int,
+        transport: _WrappedTransport,
+        v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
+    ) -> None:
+        """Respond to a query and reassemble any truncated deferred packets."""
+        self._cancel_any_timers_for_addr(addr)
+        packets = self._deferred.pop(addr, [])
+        if msg:
+            packets.append(msg)
+
+        self._query_handler.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)
+
+    def error_received(self, exc: Exception) -> None:
+        """Likely socket closed or IPv6."""
+        # We pre-format the message string with the socket as we want
+        # log_exception_once to log a warning message once per
+        # different socket in case there are problems with multiple
+        # sockets
+        msg_str = f"Error with socket {self.sock_description}: %s"
+        QuietLogger.log_exception_once(exc, msg_str, exc)
+
+    def connection_made(self, transport: asyncio.BaseTransport) -> None:
+        wrapped_transport = make_wrapped_transport(cast(asyncio.DatagramTransport, transport))
+        self.transport = wrapped_transport
+        self.sock_description = f"{wrapped_transport.fileno} ({wrapped_transport.sock_name})"
+
+    def connection_lost(self, exc: Optional[Exception]) -> None:
+        """Handle connection lost."""
diff --git a/zeroconf/_logger.py b/zeroconf/_logger.py
new file mode 100644
index 00000000..b0e66bc9
--- /dev/null
+++ b/zeroconf/_logger.py
@@ -0,0 +1,86 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+    Lesser General Public License for more details.
+ + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import logging +import sys +from typing import Any, Dict, Union, cast + +log = logging.getLogger(__name__.split('.', maxsplit=1)[0]) +log.addHandler(logging.NullHandler()) + + +def set_logger_level_if_unset() -> None: + if log.level == logging.NOTSET: + log.setLevel(logging.WARN) + + +set_logger_level_if_unset() + + +class QuietLogger: + _seen_logs: Dict[str, Union[int, tuple]] = {} + + @classmethod + def log_exception_warning(cls, *logger_data: Any) -> None: + exc_info = sys.exc_info() + exc_str = str(exc_info[1]) + if exc_str not in cls._seen_logs: + # log at warning level the first time this is seen + cls._seen_logs[exc_str] = exc_info + logger = log.warning + else: + logger = log.debug + logger(*(logger_data or ['Exception occurred']), exc_info=True) + + @classmethod + def log_exception_debug(cls, *logger_data: Any) -> None: + log_exc_info = False + exc_info = sys.exc_info() + exc_str = str(exc_info[1]) + if exc_str not in cls._seen_logs: + # log the trace only on the first time + cls._seen_logs[exc_str] = exc_info + log_exc_info = True + log.debug(*(logger_data or ['Exception occurred']), exc_info=log_exc_info) + + @classmethod + def log_warning_once(cls, *args: Any) -> None: + msg_str = args[0] + if msg_str not in cls._seen_logs: + cls._seen_logs[msg_str] = 0 + logger = log.warning + else: + logger = log.debug + cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 + logger(*args) + + @classmethod + def log_exception_once(cls, exc: Exception, *args: Any) -> None: + msg_str = args[0] + if msg_str not in cls._seen_logs: + cls._seen_logs[msg_str] = 0 + logger = log.warning + else: + logger = log.debug + cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 + logger(*args, exc_info=exc) diff --git a/zeroconf/_protocol/__init__.py b/zeroconf/_protocol/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/zeroconf/_protocol/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/zeroconf/_protocol/incoming.py b/zeroconf/_protocol/incoming.py new file mode 100644 index 00000000..9e208b63 --- /dev/null +++ b/zeroconf/_protocol/incoming.py @@ -0,0 +1,442 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. 
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+    Lesser General Public License for more details.
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301
+    USA
+"""
+
+import struct
+import sys
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
+
+from .._dns import (
+    DNSAddress,
+    DNSHinfo,
+    DNSNsec,
+    DNSPointer,
+    DNSQuestion,
+    DNSRecord,
+    DNSService,
+    DNSText,
+)
+from .._exceptions import IncomingDecodeError
+from .._logger import log
+from .._utils.time import current_time_millis
+from ..const import (
+    _FLAGS_QR_MASK,
+    _FLAGS_QR_QUERY,
+    _FLAGS_QR_RESPONSE,
+    _FLAGS_TC,
+    _TYPE_A,
+    _TYPE_AAAA,
+    _TYPE_CNAME,
+    _TYPE_HINFO,
+    _TYPE_NSEC,
+    _TYPE_PTR,
+    _TYPE_SRV,
+    _TYPE_TXT,
+    _TYPES,
+)
+
+DNS_COMPRESSION_HEADER_LEN = 1
+DNS_COMPRESSION_POINTER_LEN = 2
+MAX_DNS_LABELS = 128
+MAX_NAME_LENGTH = 253
+
+DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError)
+
+
+_seen_logs: Dict[str, Union[int, tuple]] = {}
+_str = str
+_int = int
+
+
+class DNSIncoming:
+    """Object representation of an incoming DNS packet"""
+
+    __slots__ = (
+        "_did_read_others",
+        'flags',
+        'offset',
+        'data',
+        'view',
+        '_data_len',
+        '_name_cache',
+        '_questions',
+        '_answers',
+        'id',
+        '_num_questions',
+        '_num_answers',
+        '_num_authorities',
+        '_num_additionals',
+        'valid',
+        'now',
+        'scope_id',
+        'source',
+        '_has_qu_question',
+    )
+
+    def __init__(
+        self,
+        data: bytes,
+        source: Optional[Tuple[str, int]] = None,
+        scope_id: Optional[int] = None,
+        now: Optional[float] = None,
+    ) -> None:
+        """Constructor from string holding bytes of packet"""
+        self.flags = 0
+        self.offset = 0
+        self.data = data
+        self.view = data
+        self._data_len = len(data)
+        self._name_cache: Dict[int, List[str]] = {}
+        self._questions: List[DNSQuestion] = []
+        self._answers: List[DNSRecord] = []
+        self.id = 0
+        self._num_questions = 0
+        self._num_answers = 0
+        self._num_authorities = 0
+        self._num_additionals = 0
+        self.valid = False
+        self._did_read_others = False
+        self.now = now or current_time_millis()
+        self.source = source
+        self.scope_id = scope_id
+        self._has_qu_question = False
+        try:
+            self._initial_parse()
+        except DECODE_EXCEPTIONS:
+            self._log_exception_debug(
+                'Received invalid packet from %s at offset %d while unpacking %r',
+                self.source,
+                self.offset,
+                self.data,
+            )
+
+    def is_query(self) -> bool:
+        """Returns true if this is a query."""
+        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY
+
+    def is_response(self) -> bool:
+        """Returns true if this is a response."""
+        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
+
+    def has_qu_question(self) -> bool:
+        """Returns true if any question is a QU question."""
+        return self._has_qu_question
+
+    @property
+    def truncated(self) -> bool:
+        """Returns true if this packet is truncated."""
+        return (self.flags & _FLAGS_TC) == _FLAGS_TC
+
+    @property
+    def questions(self) -> List[DNSQuestion]:
+        """Questions in the packet."""
+        return self._questions
+
+    @property
+    def num_questions(self) -> int:
+        """Number of questions in the packet."""
+        return self._num_questions
+
+    @property
+    def num_answers(self) -> int:
+        """Number of answers in the packet."""
+        return self._num_answers
+
+    @property
+    def num_authorities(self) -> int:
+        """Number of authorities in the packet."""
+        return self._num_authorities
+
+    @property
+    def num_additionals(self) -> int:
+        """Number of additionals in the packet."""
+        return self._num_additionals
+
+    def _initial_parse(self) -> None:
+        """Parse the data needed to initialize the packet object."""
+        self._read_header()
+        self._read_questions()
+        if not self._num_questions:
+            self._read_others()
+        self.valid = True
+
+    @classmethod
+    def _log_exception_debug(cls, *logger_data: Any) -> None:
+        log_exc_info = False
+        exc_info = sys.exc_info()
+        exc_str = str(exc_info[1])
+        if exc_str not in _seen_logs:
+            # log the trace only on the first time
+            _seen_logs[exc_str] = exc_info
+            log_exc_info = True
+        log.debug(*(logger_data or ['Exception occurred']), exc_info=log_exc_info)
+
+    def answers(self) -> List[DNSRecord]:
+        """Answers in the packet."""
+        if not self._did_read_others:
+            try:
+                self._read_others()
+            except DECODE_EXCEPTIONS:
+                self._log_exception_debug(
+                    'Received invalid packet from %s at offset %d while unpacking %r',
+                    self.source,
+                    self.offset,
+                    self.data,
+                )
+        return self._answers
+
+    def is_probe(self) -> bool:
+        """Returns true if this is a probe."""
+        return self._num_authorities > 0
+
+    def __repr__(self) -> str:
+        return '<DNSIncoming:{%s}>' % ', '.join(
+            [
+                'id=%s' % self.id,
+                'flags=%s' % self.flags,
+                'truncated=%s' % self.truncated,
+                'n_q=%s' % self._num_questions,
+                'n_ans=%s' % self._num_answers,
+                'n_auth=%s' % self._num_authorities,
+                'n_add=%s' % self._num_additionals,
+                'questions=%s' % self._questions,
+                'answers=%s' % self.answers(),
+            ]
+        )
+
+    def _read_header(self) -> None:
+        """Reads header portion of packet"""
+        view = self.view
+        offset = self.offset
+        self.offset += 12
+        # The header has 6 unsigned shorts in network order
+        self.id = view[offset] << 8 | view[offset + 1]
+        self.flags = view[offset + 2] << 8 | view[offset + 3]
+        self._num_questions = view[offset + 4] << 8 | view[offset + 5]
+        self._num_answers = view[offset + 6] << 8 | view[offset + 7]
+        self._num_authorities = view[offset + 8] << 8 | view[offset + 9]
+        self._num_additionals = view[offset + 10] << 8 | view[offset + 11]
+
+    def _read_questions(self) -> None:
+        """Reads questions section of packet"""
+        view = self.view
+        questions = self._questions
+        for _ in range(self._num_questions):
+            name = self._read_name()
+            offset = self.offset
+            self.offset += 4
+            # The question has 2 unsigned shorts in network order
+            type_ = view[offset] << 8 | view[offset + 1]
+            class_ = view[offset + 2] << 8 | view[offset + 3]
+            question = DNSQuestion(name, type_, class_)
+            if question.unique:  # QU questions use the same bit as unique
+                self._has_qu_question = True
+            questions.append(question)
+
+    def _read_character_string(self) -> str:
+        """Reads a character string from the packet"""
+        length = self.view[self.offset]
+        self.offset += 1
+        info = self.data[self.offset : self.offset + length].decode('utf-8', 'replace')
+        self.offset += length
+        return info
+
+    def _read_string(self, length: _int) -> bytes:
+        """Reads a string of a given length from the packet"""
+        info = self.data[self.offset : self.offset + length]
+        self.offset += length
+        return info
+
+    def _read_others(self) -> None:
+        """Reads the answers, authorities and additionals section of the
packet""" + self._did_read_others = True + view = self.view + n = self._num_answers + self._num_authorities + self._num_additionals + for _ in range(n): + domain = self._read_name() + offset = self.offset + self.offset += 10 + # type_, class_ and length are unsigned shorts in network order + # ttl is an unsigned long in network order https://www.rfc-editor.org/errata/eid2130 + type_ = view[offset] << 8 | view[offset + 1] + class_ = view[offset + 2] << 8 | view[offset + 3] + ttl = view[offset + 4] << 24 | view[offset + 5] << 16 | view[offset + 6] << 8 | view[offset + 7] + length = view[offset + 8] << 8 | view[offset + 9] + end = self.offset + length + rec = None + try: + rec = self._read_record(domain, type_, class_, ttl, length) + except DECODE_EXCEPTIONS: + # Skip records that fail to decode if we know the length + # If the packet is really corrupt read_name and the unpack + # above would fail and hit the exception catch in read_others + self.offset = end + log.debug( + 'Unable to parse; skipping record for %s with type %s at offset %d while unpacking %r', + domain, + _TYPES.get(type_, type_), + self.offset, + self.data, + exc_info=True, + ) + if rec is not None: + self._answers.append(rec) + + def _read_record( + self, domain: _str, type_: _int, class_: _int, ttl: _int, length: _int + ) -> Optional[DNSRecord]: + """Read known records types and skip unknown ones.""" + if type_ == _TYPE_A: + return DNSAddress(domain, type_, class_, ttl, self._read_string(4), None, self.now) + if type_ in (_TYPE_CNAME, _TYPE_PTR): + return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now) + if type_ == _TYPE_TXT: + return DNSText(domain, type_, class_, ttl, self._read_string(length), self.now) + if type_ == _TYPE_SRV: + view = self.view + offset = self.offset + self.offset += 6 + # The SRV record has 3 unsigned shorts in network order + priority = view[offset] << 8 | view[offset + 1] + weight = view[offset + 2] << 8 | view[offset + 3] + port = view[offset + 4] << 8 | view[offset + 5] + return DNSService( + domain, + type_, + class_, + ttl, + priority, + weight, + port, + self._read_name(), + self.now, + ) + if type_ == _TYPE_HINFO: + return DNSHinfo( + domain, + type_, + class_, + ttl, + self._read_character_string(), + self._read_character_string(), + self.now, + ) + if type_ == _TYPE_AAAA: + return DNSAddress(domain, type_, class_, ttl, self._read_string(16), self.scope_id, self.now) + if type_ == _TYPE_NSEC: + name_start = self.offset + return DNSNsec( + domain, + type_, + class_, + ttl, + self._read_name(), + self._read_bitmap(name_start + length), + self.now, + ) + # Try to ignore types we don't know about + # Skip the payload for the resource record so the next + # records can be parsed correctly + self.offset += length + return None + + def _read_bitmap(self, end: _int) -> List[int]: + """Reads an NSEC bitmap from the packet.""" + rdtypes = [] + view = self.view + while self.offset < end: + offset = self.offset + offset_plus_one = offset + 1 + offset_plus_two = offset + 2 + window = view[offset] + bitmap_length = view[offset_plus_one] + bitmap_end = offset_plus_two + bitmap_length + for i, byte in enumerate(self.data[offset_plus_two:bitmap_end]): + for bit in range(0, 8): + if byte & (0x80 >> bit): + rdtypes.append(bit + window * 256 + i * 8) + self.offset += 2 + bitmap_length + return rdtypes + + def _read_name(self) -> str: + """Reads a domain name from the packet.""" + labels: List[str] = [] + seen_pointers: Set[int] = set() + original_offset = self.offset + self.offset = 
self._decode_labels_at_offset(original_offset, labels, seen_pointers) + self._name_cache[original_offset] = labels + name = ".".join(labels) + "." + if len(name) > MAX_NAME_LENGTH: + raise IncomingDecodeError( + f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH} from {self.source}" + ) + return name + + def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: Set[int]) -> int: + # This is a tight loop that is called frequently, small optimizations can make a difference. + view = self.view + while off < self._data_len: + length = view[off] + if length == 0: + return off + DNS_COMPRESSION_HEADER_LEN + + if length < 0x40: + label_idx = off + DNS_COMPRESSION_HEADER_LEN + labels.append(self.data[label_idx : label_idx + length].decode('utf-8', 'replace')) + off += DNS_COMPRESSION_HEADER_LEN + length + continue + + if length < 0xC0: + raise IncomingDecodeError( + f"DNS compression type {length} is unknown at {off} from {self.source}" + ) + + # We have a DNS compression pointer + link_data = view[off + 1] + link = (length & 0x3F) * 256 + link_data + link_py_int = link + if link > self._data_len: + raise IncomingDecodeError( + f"DNS compression pointer at {off} points to {link} beyond packet from {self.source}" + ) + if link == off: + raise IncomingDecodeError( + f"DNS compression pointer at {off} points to itself from {self.source}" + ) + if link_py_int in seen_pointers: + raise IncomingDecodeError( + f"DNS compression pointer at {off} was seen again from {self.source}" + ) + linked_labels = self._name_cache.get(link_py_int) + if not linked_labels: + linked_labels = [] + seen_pointers.add(link_py_int) + self._decode_labels_at_offset(link, linked_labels, seen_pointers) + self._name_cache[link_py_int] = linked_labels + labels.extend(linked_labels) + if len(labels) > MAX_DNS_LABELS: + raise IncomingDecodeError( + f"Maximum dns labels reached while processing pointer at {off} from {self.source}" + ) + return off + DNS_COMPRESSION_POINTER_LEN + + raise IncomingDecodeError(f"Corrupt packet received while decoding name from {self.source}") diff --git a/zeroconf/_protocol/outgoing.py b/zeroconf/_protocol/outgoing.py new file mode 100644 index 00000000..f45c3935 --- /dev/null +++ b/zeroconf/_protocol/outgoing.py @@ -0,0 +1,498 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301
+    USA
+"""
+
+import enum
+import logging
+from struct import Struct
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
+
+from .._dns import DNSPointer, DNSQuestion, DNSRecord
+from .._exceptions import NamePartTooLongException
+from .._logger import log
+from ..const import (
+    _CLASS_UNIQUE,
+    _DNS_HOST_TTL,
+    _DNS_OTHER_TTL,
+    _DNS_PACKET_HEADER_LEN,
+    _FLAGS_QR_MASK,
+    _FLAGS_QR_QUERY,
+    _FLAGS_QR_RESPONSE,
+    _FLAGS_TC,
+    _MAX_MSG_ABSOLUTE,
+    _MAX_MSG_TYPICAL,
+)
+from .incoming import DNSIncoming
+
+str_ = str
+float_ = float
+int_ = int
+bytes_ = bytes
+DNSQuestion_ = DNSQuestion
+DNSRecord_ = DNSRecord
+
+
+PACK_BYTE = Struct('>B').pack
+PACK_SHORT = Struct('>H').pack
+PACK_LONG = Struct('>L').pack
+
+SHORT_CACHE_MAX = 128
+
+BYTE_TABLE = tuple(PACK_BYTE(i) for i in range(256))
+SHORT_LOOKUP = tuple(PACK_SHORT(i) for i in range(SHORT_CACHE_MAX))
+LONG_LOOKUP = {i: PACK_LONG(i) for i in (_DNS_OTHER_TTL, _DNS_HOST_TTL, 0)}
+
+
+class State(enum.Enum):
+    init = 0
+    finished = 1
+
+
+STATE_INIT = State.init.value
+STATE_FINISHED = State.finished.value
+
+LOGGING_IS_ENABLED_FOR = log.isEnabledFor
+LOGGING_DEBUG = logging.DEBUG
+
+
+class DNSOutgoing:
+
+    """Object representation of an outgoing packet"""
+
+    __slots__ = (
+        'flags',
+        'finished',
+        'id',
+        'multicast',
+        'packets_data',
+        'names',
+        'data',
+        'size',
+        'allow_long',
+        'state',
+        'questions',
+        'answers',
+        'authorities',
+        'additionals',
+    )
+
+    def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
+        self.flags = flags
+        self.finished = False
+        self.id = id_
+        self.multicast = multicast
+        self.packets_data: List[bytes] = []
+
+        # these 3 are per-packet -- see also _reset_for_next_packet()
+        self.names: Dict[str, int] = {}
+        self.data: List[bytes] = []
+        self.size: int = _DNS_PACKET_HEADER_LEN
+        self.allow_long: bool = True
+
+        self.state = STATE_INIT
+
+        self.questions: List[DNSQuestion] = []
+        self.answers: List[Tuple[DNSRecord, float]] = []
+        self.authorities: List[DNSPointer] = []
+        self.additionals: List[DNSRecord] = []
+
+    def is_query(self) -> bool:
+        """Returns true if this is a query."""
+        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY
+
+    def is_response(self) -> bool:
+        """Returns true if this is a response."""
+        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
+
+    def _reset_for_next_packet(self) -> None:
+        self.names = {}
+        self.data = []
+        self.size = _DNS_PACKET_HEADER_LEN
+        self.allow_long = True
+
+    def __repr__(self) -> str:
+        return '<DNSOutgoing:{%s}>' % ', '.join(
+            [
+                'multicast=%s' % self.multicast,
+                'flags=%s' % self.flags,
+                'questions=%s' % self.questions,
+                'answers=%s' % self.answers,
+                'authorities=%s' % self.authorities,
+                'additionals=%s' % self.additionals,
+            ]
+        )
+
+    def add_question(self, record: DNSQuestion) -> None:
+        """Adds a question"""
+        self.questions.append(record)
+
+    def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None:
+        """Adds an answer"""
+        if not record.suppressed_by(inp):
+            self.add_answer_at_time(record, 0.0)
+
+    def add_answer_at_time(self, record: Optional[DNSRecord], now: float_) -> None:
+        """Adds an answer if it does not expire by a certain time"""
+        now_double = now
+        if record is not None and (now_double == 0 or not record.is_expired(now_double)):
+            self.answers.append((record, now))
+
def add_authorative_answer(self, record: DNSPointer) -> None: + """Adds an authoritative answer""" + self.authorities.append(record) + + def add_additional_answer(self, record: DNSRecord) -> None: + """Adds an additional answer + + From: RFC 6763, DNS-Based Service Discovery, February 2013 + + 12. DNS Additional Record Generation + + DNS has an efficiency feature whereby a DNS server may place + additional records in the additional section of the DNS message. + These additional records are records that the client did not + explicitly request, but the server has reasonable grounds to expect + that the client might request them shortly, so including them can + save the client from having to issue additional queries. + + This section recommends which additional records SHOULD be generated + to improve network efficiency, for both Unicast and Multicast DNS-SD + responses. + + 12.1. PTR Records + + When including a DNS-SD Service Instance Enumeration or Selective + Instance Enumeration (subtype) PTR record in a response packet, the + server/responder SHOULD include the following additional records: + + o The SRV record(s) named in the PTR rdata. + o The TXT record(s) named in the PTR rdata. + o All address records (type "A" and "AAAA") named in the SRV rdata. + + 12.2. SRV Records + + When including an SRV record in a response packet, the + server/responder SHOULD include the following additional records: + + o All address records (type "A" and "AAAA") named in the SRV rdata. + + """ + self.additionals.append(record) + + def _write_byte(self, value: int_) -> None: + """Writes a single byte to the packet""" + self.data.append(BYTE_TABLE[value]) + self.size += 1 + + def _get_short(self, value: int_) -> bytes: + """Convert an unsigned short to 2 bytes.""" + return SHORT_LOOKUP[value] if value < SHORT_CACHE_MAX else PACK_SHORT(value) + + def _insert_short_at_start(self, value: int_) -> None: + """Inserts an unsigned short at the start of the packet""" + self.data.insert(0, self._get_short(value)) + + def _replace_short(self, index: int_, value: int_) -> None: + """Replaces an unsigned short in a certain position in the packet""" + self.data[index] = self._get_short(value) + + def write_short(self, value: int_) -> None: + """Writes an unsigned short to the packet""" + self.data.append(self._get_short(value)) + self.size += 2 + + def _write_int(self, value: Union[float, int]) -> None: + """Writes an unsigned integer to the packet""" + value_as_int = int(value) + long_bytes = LONG_LOOKUP.get(value_as_int) + if long_bytes is not None: + self.data.append(long_bytes) + else: + self.data.append(PACK_LONG(value_as_int)) + self.size += 4 + + def write_string(self, value: bytes_) -> None: + """Writes a string to the packet""" + if TYPE_CHECKING: + assert isinstance(value, bytes) + self.data.append(value) + self.size += len(value) + + def _write_utf(self, s: str_) -> None: + """Writes a UTF-8 string of a given length to the packet""" + utfstr = s.encode('utf-8') + length = len(utfstr) + if length > 64: + raise NamePartTooLongException + self._write_byte(length) + self.write_string(utfstr) + + def write_character_string(self, value: bytes) -> None: + if TYPE_CHECKING: + assert isinstance(value, bytes) + length = len(value) + if length > 256: + raise NamePartTooLongException + self._write_byte(length) + self.write_string(value) + + def write_name(self, name: str_) -> None: + """ + Write names to packet + + 18.14. 
Name Compression + + When generating Multicast DNS messages, implementations SHOULD use + name compression wherever possible to compress the names of resource + records, by replacing some or all of the resource record name with a + compact two-byte reference to an appearance of that data somewhere + earlier in the message [RFC1035]. + """ + + # split name into each label + if name.endswith('.'): + name = name[:-1] + + index = self.names.get(name, 0) + if index: + self._write_link_to_name(index) + return + + start_size = self.size + labels = name.split('.') + # Write each new label or a pointer to the existing one in the packet + self.names[name] = start_size + self._write_utf(labels[0]) + + name_length = 0 + for count in range(1, len(labels)): + partial_name = '.'.join(labels[count:]) + index = self.names.get(partial_name, 0) + if index: + self._write_link_to_name(index) + return + if name_length == 0: + name_length = len(name.encode('utf-8')) + self.names[partial_name] = start_size + name_length - len(partial_name.encode('utf-8')) + self._write_utf(labels[count]) + + # this is the end of a name + self._write_byte(0) + + def _write_link_to_name(self, index: int_) -> None: + # If part of the name already exists in the packet, + # create a pointer to it + self._write_byte((index >> 8) | 0xC0) + self._write_byte(index & 0xFF) + + def _write_question(self, question: DNSQuestion_) -> bool: + """Writes a question to the packet""" + start_data_length = len(self.data) + start_size = self.size + self.write_name(question.name) + self.write_short(question.type) + self._write_record_class(question) + return self._check_data_limit_or_rollback(start_data_length, start_size) + + def _write_record_class(self, record: Union[DNSQuestion_, DNSRecord_]) -> None: + """Write out the record class including the unique/unicast (QU) bit.""" + class_ = record.class_ + if record.unique is True and self.multicast: + self.write_short(class_ | _CLASS_UNIQUE) + else: + self.write_short(class_) + + def _write_ttl(self, record: DNSRecord_, now: float_) -> None: + """Write out the record ttl.""" + self._write_int(record.ttl if now == 0 else record.get_remaining_ttl(now)) + + def _write_record(self, record: DNSRecord_, now: float_) -> bool: + """Writes a record (answer, authoritative answer, additional) to + the packet. 
Returns True on success, or False if the record
+        did not fit in the available space."""
+        start_data_length = len(self.data)
+        start_size = self.size
+        self.write_name(record.name)
+        self.write_short(record.type)
+        self._write_record_class(record)
+        self._write_ttl(record, now)
+        index = len(self.data)
+        self.write_short(0)  # Will get replaced with the actual size
+        record.write(self)
+        # Adjust size for the short we will write before this record
+        length = 0
+        for d in self.data[index + 1 :]:
+            length += len(d)
+        # Here we replace the 0 length short we wrote
+        # before with the actual length
+        self._replace_short(index, length)
+        return self._check_data_limit_or_rollback(start_data_length, start_size)
+
+    def _check_data_limit_or_rollback(self, start_data_length: int_, start_size: int_) -> bool:
+        """Check data limit, if we go over, then rollback and return False."""
+        len_limit = _MAX_MSG_ABSOLUTE if self.allow_long else _MAX_MSG_TYPICAL
+        self.allow_long = False
+
+        if self.size <= len_limit:
+            return True
+
+        if LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG):  # pragma: no branch
+            log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
+        del self.data[start_data_length:]
+        self.size = start_size
+
+        start_size_int = start_size
+        rollback_names = [name for name, idx in self.names.items() if idx >= start_size_int]
+        for name in rollback_names:
+            del self.names[name]
+        return False
+
+    def _write_questions_from_offset(self, questions_offset: int_) -> int:
+        questions_written = 0
+        for question in self.questions[questions_offset:]:
+            if not self._write_question(question):
+                break
+            questions_written += 1
+        return questions_written
+
+    def _write_answers_from_offset(self, answer_offset: int_) -> int:
+        answers_written = 0
+        for answer, time_ in self.answers[answer_offset:]:
+            if not self._write_record(answer, time_):
+                break
+            answers_written += 1
+        return answers_written
+
+    def _write_records_from_offset(self, records: Sequence[DNSRecord], offset: int_) -> int:
+        records_written = 0
+        for record in records[offset:]:
+            if not self._write_record(record, 0):
+                break
+            records_written += 1
+        return records_written
+
+    def _has_more_to_add(
+        self, questions_offset: int_, answer_offset: int_, authority_offset: int_, additional_offset: int_
+    ) -> bool:
+        """Check if all questions, answers, authority, and additionals have been written to the packet."""
+        return (
+            questions_offset < len(self.questions)
+            or answer_offset < len(self.answers)
+            or authority_offset < len(self.authorities)
+            or additional_offset < len(self.additionals)
+        )
+
+    def packets(self) -> List[bytes]:
+        """Returns a list of bytestrings containing the packets' bytes
+
+        No further parts should be added to the packet once this
+        is done.
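+
+        A usage sketch (illustrative only; _TYPE_PTR and _CLASS_IN live
+        in ..const and are not imported by this module):
+
+            out = DNSOutgoing(_FLAGS_QR_QUERY)
+            out.add_question(DNSQuestion('_http._tcp.local.', _TYPE_PTR, _CLASS_IN))
+            payloads = out.packets()  # List[bytes], ready for the socket
+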
The packets are each restricted to _MAX_MSG_TYPICAL + or less in length, except for the case of a single answer which + will be written out to a single oversized packet no more than + _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP + fragmentation potentially).""" + packets_data = self.packets_data + + if self.state == STATE_FINISHED: + return packets_data + + questions_offset = 0 + answer_offset = 0 + authority_offset = 0 + additional_offset = 0 + # we have to at least write out the question + debug_enable = LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG) is True + has_more_to_add = True + + while has_more_to_add: + if debug_enable: + log.debug( + "offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + log.debug( + "lengths = questions=%d, answers=%d, authorities=%d, additionals=%d", + len(self.questions), + len(self.answers), + len(self.authorities), + len(self.additionals), + ) + + questions_written = self._write_questions_from_offset(questions_offset) + answers_written = self._write_answers_from_offset(answer_offset) + authorities_written = self._write_records_from_offset(self.authorities, authority_offset) + additionals_written = self._write_records_from_offset(self.additionals, additional_offset) + + made_progress = bool(self.data) + + self._insert_short_at_start(additionals_written) + self._insert_short_at_start(authorities_written) + self._insert_short_at_start(answers_written) + self._insert_short_at_start(questions_written) + + questions_offset += questions_written + answer_offset += answers_written + authority_offset += authorities_written + additional_offset += additionals_written + if debug_enable: + log.debug( + "now offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + + has_more_to_add = self._has_more_to_add( + questions_offset, answer_offset, authority_offset, additional_offset + ) + + if has_more_to_add and self.is_query(): + # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 + if debug_enable: # pragma: no branch + log.debug("Setting TC flag") + self._insert_short_at_start(self.flags | _FLAGS_TC) + else: + self._insert_short_at_start(self.flags) + + if self.multicast: + self._insert_short_at_start(0) + else: + self._insert_short_at_start(self.id) + + packets_data.append(b''.join(self.data)) + + if not made_progress: + # Generating an empty packet is not a desirable outcome, but currently + # too many internals rely on this behavior. So, we'll just return an + # empty packet and log a warning until this can be refactored at a later + # date. + log.warning("packets() made no progress adding records; returning") + break + + if has_more_to_add: + self._reset_for_next_packet() + + self.state = STATE_FINISHED + return packets_data diff --git a/zeroconf/_record_update.py b/zeroconf/_record_update.py new file mode 100644 index 00000000..8e0e4bdb --- /dev/null +++ b/zeroconf/_record_update.py @@ -0,0 +1,42 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. 
+ + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import Optional + +from ._dns import DNSRecord + + +class RecordUpdate: + __slots__ = ("new", "old") + + def __init__(self, new: DNSRecord, old: Optional[DNSRecord] = None): + """RecordUpdate represents a change in a DNS record.""" + self.new = new + self.old = old + + def __getitem__(self, index: int) -> Optional[DNSRecord]: + """Get the new or old record.""" + if index == 0: + return self.new + elif index == 1: + return self.old + raise IndexError(index) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py new file mode 100644 index 00000000..cf54d7f0 --- /dev/null +++ b/zeroconf/_services/__init__.py @@ -0,0 +1,75 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+ + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +from typing import TYPE_CHECKING, Any, Callable, List + +if TYPE_CHECKING: + from .._core import Zeroconf + + +@enum.unique +class ServiceStateChange(enum.Enum): + Added = 1 + Removed = 2 + Updated = 3 + + +class ServiceListener: + def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + +class Signal: + __slots__ = ('_handlers',) + + def __init__(self) -> None: + self._handlers: List[Callable[..., None]] = [] + + def fire(self, **kwargs: Any) -> None: + for h in self._handlers[:]: + h(**kwargs) + + @property + def registration_interface(self) -> 'SignalRegistrationInterface': + return SignalRegistrationInterface(self._handlers) + + +class SignalRegistrationInterface: + __slots__ = ('_handlers',) + + def __init__(self, handlers: List[Callable[..., None]]) -> None: + self._handlers = handlers + + def register_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': + self._handlers.append(handler) + return self + + def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': + self._handlers.remove(handler) + return self diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py new file mode 100644 index 00000000..2ff66074 --- /dev/null +++ b/zeroconf/_services/browser.py @@ -0,0 +1,806 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+ + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import heapq +import queue +import random +import threading +import time +import warnings +from functools import partial +from types import TracebackType # noqa # used in type hints +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) + +from .._dns import DNSPointer, DNSQuestion, DNSQuestionType +from .._logger import log +from .._protocol.outgoing import DNSOutgoing +from .._record_update import RecordUpdate +from .._services import ( + ServiceListener, + ServiceStateChange, + Signal, + SignalRegistrationInterface, +) +from .._updates import RecordUpdateListener +from .._utils.name import cached_possible_types, service_type_name +from .._utils.time import current_time_millis, millis_to_seconds +from ..const import ( + _ADDRESS_RECORD_TYPES, + _BROWSER_TIME, + _CLASS_IN, + _DNS_PACKET_HEADER_LEN, + _EXPIRE_REFRESH_TIME_PERCENT, + _FLAGS_QR_QUERY, + _MAX_MSG_TYPICAL, + _MDNS_ADDR, + _MDNS_ADDR6, + _MDNS_PORT, + _TYPE_PTR, +) + +# https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 +_FIRST_QUERY_DELAY_RANDOM_INTERVAL = (20, 120) # ms + +_ON_CHANGE_DISPATCH = { + ServiceStateChange.Added: "add_service", + ServiceStateChange.Removed: "remove_service", + ServiceStateChange.Updated: "update_service", +} + +SERVICE_STATE_CHANGE_ADDED = ServiceStateChange.Added +SERVICE_STATE_CHANGE_REMOVED = ServiceStateChange.Removed +SERVICE_STATE_CHANGE_UPDATED = ServiceStateChange.Updated + +QU_QUESTION = DNSQuestionType.QU + +STARTUP_QUERIES = 4 + +RESCUE_RECORD_RETRY_TTL_PERCENTAGE = 0.1 + +if TYPE_CHECKING: + from .._core import Zeroconf + +float_ = float +int_ = int +bool_ = bool +str_ = str + +_QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]] + +heappop = heapq.heappop +heappush = heapq.heappush + + +class _ScheduledPTRQuery: + + __slots__ = ('alias', 'name', 'ttl', 'cancelled', 'expire_time_millis', 'when_millis') + + def __init__( + self, alias: str, name: str, ttl: int, expire_time_millis: float, when_millis: float + ) -> None: + """Create a scheduled query.""" + self.alias = alias + self.name = name + self.ttl = ttl + # Since queries are stored in a heap we need to track if they are cancelled + # so we can remove them from the heap when they are cancelled as it would + # be too expensive to search the heap for the record to remove and instead + # we just mark it as cancelled and ignore it when we pop it off the heap + # when the query is due. + self.cancelled = False + # Expire time millis is the actual millisecond time the record will expire + self.expire_time_millis = expire_time_millis + # When millis is the millisecond time the query should be sent + # For the first query this is the refresh time which is 75% of the TTL + # + # For subsequent queries we increase the time by 10% of the TTL + # until we reach the expire time and then we stop because it means + # we failed to rescue the record. 
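+        # Illustrative numbers (not from the sources): with ttl=4500s the
+        # first refresh lands at 75% of the TTL, i.e. 3375s after creation,
+        # and each rescue retry is pushed out by a further 10% of the TTL
+        # (450s) until the expire time is reached.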
+ self.when_millis = when_millis + + def __repr__(self) -> str: + """Return a string representation of the scheduled query.""" + return ( + f"<{self.__class__.__name__} " + f"alias={self.alias} " + f"name={self.name} " + f"ttl={self.ttl} " + f"cancelled={self.cancelled} " + f"expire_time_millis={self.expire_time_millis} " + f"when_millis={self.when_millis}" + ">" + ) + + def __lt__(self, other: '_ScheduledPTRQuery') -> bool: + """Compare two scheduled queries.""" + if type(other) is _ScheduledPTRQuery: + return self.when_millis < other.when_millis + return NotImplemented + + def __le__(self, other: '_ScheduledPTRQuery') -> bool: + """Compare two scheduled queries.""" + if type(other) is _ScheduledPTRQuery: + return self.when_millis < other.when_millis or self.__eq__(other) + return NotImplemented + + def __eq__(self, other: Any) -> bool: + """Compare two scheduled queries.""" + if type(other) is _ScheduledPTRQuery: + return self.when_millis == other.when_millis + return NotImplemented + + def __ge__(self, other: '_ScheduledPTRQuery') -> bool: + """Compare two scheduled queries.""" + if type(other) is _ScheduledPTRQuery: + return self.when_millis > other.when_millis or self.__eq__(other) + return NotImplemented + + def __gt__(self, other: '_ScheduledPTRQuery') -> bool: + """Compare two scheduled queries.""" + if type(other) is _ScheduledPTRQuery: + return self.when_millis > other.when_millis + return NotImplemented + + +class _DNSPointerOutgoingBucket: + """A DNSOutgoing bucket.""" + + __slots__ = ('now_millis', 'out', 'bytes') + + def __init__(self, now_millis: float, multicast: bool) -> None: + """Create a bucket to wrap a DNSOutgoing.""" + self.now_millis = now_millis + self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast) + self.bytes = 0 + + def add(self, max_compressed_size: int_, question: DNSQuestion, answers: Set[DNSPointer]) -> None: + """Add a new set of questions and known answers to the outgoing.""" + self.out.add_question(question) + for answer in answers: + self.out.add_answer_at_time(answer, self.now_millis) + self.bytes += max_compressed_size + + +def group_ptr_queries_with_known_answers( + now: float_, multicast: bool_, question_with_known_answers: _QuestionWithKnownAnswers +) -> List[DNSOutgoing]: + """Aggregate queries so that as many known answers as possible fit in the same packet + without having known answers spill over into the next packet unless the + question and known answers are always going to exceed the packet size. + + Some responders do not implement multi-packet known answer suppression + so we try to keep all the known answers in the same packet as the + questions. + """ + return _group_ptr_queries_with_known_answers(now, multicast, question_with_known_answers) + + +def _group_ptr_queries_with_known_answers( + now_millis: float_, multicast: bool_, question_with_known_answers: _QuestionWithKnownAnswers +) -> List[DNSOutgoing]: + """Inner wrapper for group_ptr_queries_with_known_answers.""" + # This is the maximum size the query + known answers can be with name compression. + # The actual size of the query + known answers may be a bit smaller since other + # parts may be shared when the final DNSOutgoing packets are constructed. The + # goal of this algorithm is to quickly bucket the query + known answers without + # the overhead of actually constructing the packets. 
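+    # Illustrative example (assumed sizes): if the per-packet budget works
+    # out to ~1448 bytes, questions whose compressed sizes are 900, 600 and
+    # 400 bytes are packed largest-first into buckets: 900 and 400 share one
+    # DNSOutgoing, 600 starts a second one.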
+ query_by_size: Dict[DNSQuestion, int] = { + question: (question.max_size + sum(answer.max_size_compressed for answer in known_answers)) + for question, known_answers in question_with_known_answers.items() + } + max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN + query_buckets: List[_DNSPointerOutgoingBucket] = [] + for question in sorted( + query_by_size, + key=query_by_size.get, # type: ignore + reverse=True, + ): + max_compressed_size = query_by_size[question] + answers = question_with_known_answers[question] + for query_bucket in query_buckets: + if query_bucket.bytes + max_compressed_size <= max_bucket_size: + query_bucket.add(max_compressed_size, question, answers) + break + else: + # If a single question and known answers won't fit in a packet + # we will end up generating multiple packets, but there will never + # be multiple questions + query_bucket = _DNSPointerOutgoingBucket(now_millis, multicast) + query_bucket.add(max_compressed_size, question, answers) + query_buckets.append(query_bucket) + + return [query_bucket.out for query_bucket in query_buckets] + + +def generate_service_query( + zc: 'Zeroconf', + now_millis: float_, + types_: Set[str], + multicast: bool, + question_type: Optional[DNSQuestionType], +) -> List[DNSOutgoing]: + """Generate a service query for sending with zeroconf.send.""" + questions_with_known_answers: _QuestionWithKnownAnswers = {} + qu_question = not multicast if question_type is None else question_type is QU_QUESTION + question_history = zc.question_history + cache = zc.cache + for type_ in types_: + question = DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) + question.unicast = qu_question + known_answers = { + record + for record in cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN) + if not record.is_stale(now_millis) + } + if not qu_question and question_history.suppresses(question, now_millis, known_answers): + log.debug("Asking %s was suppressed by the question history", question) + continue + if TYPE_CHECKING: + pointer_known_answers = cast(Set[DNSPointer], known_answers) + else: + pointer_known_answers = known_answers + questions_with_known_answers[question] = pointer_known_answers + if not qu_question: + question_history.add_question_at_time(question, now_millis, known_answers) + + return _group_ptr_queries_with_known_answers(now_millis, multicast, questions_with_known_answers) + + +def _on_change_dispatcher( + listener: ServiceListener, + zeroconf: 'Zeroconf', + service_type: str, + name: str, + state_change: ServiceStateChange, +) -> None: + """Dispatch a service state change to a listener.""" + getattr(listener, _ON_CHANGE_DISPATCH[state_change])(zeroconf, service_type, name) + + +def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]: + """Generate a service_state_changed handlers from a listener.""" + assert listener is not None + if not hasattr(listener, 'update_service'): + warnings.warn( + "%r has no update_service method. Provide one (it can be empty if you " + "don't care about the updates), it'll become mandatory." 
% (listener,), + FutureWarning, + ) + return partial(_on_change_dispatcher, listener) + + +class QueryScheduler: + """Schedule outgoing PTR queries for Continuous Multicast DNS Querying + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 + + """ + + __slots__ = ( + '_zc', + '_types', + '_addr', + '_port', + '_multicast', + '_first_random_delay_interval', + '_min_time_between_queries_millis', + '_loop', + '_startup_queries_sent', + '_next_scheduled_for_alias', + '_query_heap', + '_next_run', + '_clock_resolution_millis', + '_question_type', + ) + + def __init__( + self, + zc: "Zeroconf", + types: Set[str], + addr: Optional[str], + port: int, + multicast: bool, + delay: int, + first_random_delay_interval: Tuple[int, int], + question_type: Optional[DNSQuestionType], + ) -> None: + self._zc = zc + self._types = types + self._addr = addr + self._port = port + self._multicast = multicast + self._first_random_delay_interval = first_random_delay_interval + self._min_time_between_queries_millis = delay + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._startup_queries_sent = 0 + self._next_scheduled_for_alias: Dict[str, _ScheduledPTRQuery] = {} + self._query_heap: list[_ScheduledPTRQuery] = [] + self._next_run: Optional[asyncio.TimerHandle] = None + self._clock_resolution_millis = time.get_clock_info('monotonic').resolution * 1000 + self._question_type = question_type + + def start(self, loop: asyncio.AbstractEventLoop) -> None: + """Start the scheduler. + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 + To avoid accidental synchronization when, for some reason, multiple + clients begin querying at exactly the same moment (e.g., because of + some common external trigger event), a Multicast DNS querier SHOULD + also delay the first query of the series by a randomly chosen amount + in the range 20-120 ms. 
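+
+        (The 20-120 ms window quoted above is the same range stored in
+        _FIRST_QUERY_DELAY_RANDOM_INTERVAL at the top of this module.)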
+ """ + start_delay = millis_to_seconds(random.randint(*self._first_random_delay_interval)) + self._loop = loop + self._next_run = loop.call_later(start_delay, self._process_startup_queries) + + def stop(self) -> None: + """Stop the scheduler.""" + if self._next_run is not None: + self._next_run.cancel() + self._next_run = None + self._next_scheduled_for_alias.clear() + self._query_heap.clear() + + def _schedule_ptr_refresh( + self, pointer: DNSPointer, expire_time_millis: float_, refresh_time_millis: float_ + ) -> None: + """Schedule a query for a pointer.""" + ttl = int(pointer.ttl) if isinstance(pointer.ttl, float) else pointer.ttl + scheduled_ptr_query = _ScheduledPTRQuery( + pointer.alias, pointer.name, ttl, expire_time_millis, refresh_time_millis + ) + self._schedule_ptr_query(scheduled_ptr_query) + + def _schedule_ptr_query(self, scheduled_query: _ScheduledPTRQuery) -> None: + """Schedule a query for a pointer.""" + self._next_scheduled_for_alias[scheduled_query.alias] = scheduled_query + heappush(self._query_heap, scheduled_query) + + def cancel_ptr_refresh(self, pointer: DNSPointer) -> None: + """Cancel a query for a pointer.""" + scheduled = self._next_scheduled_for_alias.pop(pointer.alias, None) + if scheduled: + scheduled.cancelled = True + + def reschedule_ptr_first_refresh(self, pointer: DNSPointer) -> None: + """Reschedule a query for a pointer.""" + current = self._next_scheduled_for_alias.get(pointer.alias) + refresh_time_millis = pointer.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) + if current is not None: + # If the expire time is within self._min_time_between_queries_millis + # of the current scheduled time avoid churn by not rescheduling + if ( + -self._min_time_between_queries_millis + <= refresh_time_millis - current.when_millis + <= self._min_time_between_queries_millis + ): + return + current.cancelled = True + del self._next_scheduled_for_alias[pointer.alias] + expire_time_millis = pointer.get_expiration_time(100) + self._schedule_ptr_refresh(pointer, expire_time_millis, refresh_time_millis) + + def schedule_rescue_query( + self, query: _ScheduledPTRQuery, now_millis: float_, additional_percentage: float_ + ) -> None: + """Reschedule a query for a pointer at an additional percentage of expiration.""" + ttl_millis = query.ttl * 1000 + additional_wait = ttl_millis * additional_percentage + next_query_time = now_millis + additional_wait + if next_query_time >= query.expire_time_millis: + # If we would schedule past the expire time + # there is no point in scheduling as we already + # tried to rescue the record and failed + return + scheduled_ptr_query = _ScheduledPTRQuery( + query.alias, query.name, query.ttl, query.expire_time_millis, next_query_time + ) + self._schedule_ptr_query(scheduled_ptr_query) + + def _process_startup_queries(self) -> None: + if TYPE_CHECKING: + assert self._loop is not None + # This is a safety to ensure we stop sending queries if Zeroconf instance + # is stopped without the browser being cancelled + if self._zc.done: + return + + now_millis = current_time_millis() + + # At first we will send STARTUP_QUERIES queries to get the cache populated + self.async_send_ready_queries(self._startup_queries_sent == 0, now_millis, self._types) + self._startup_queries_sent += 1 + + # Once we finish sending the initial queries we will + # switch to a strategy of sending queries only when we + # need to refresh records that are about to expire + if self._startup_queries_sent >= STARTUP_QUERIES: + self._next_run = self._loop.call_at( + 
millis_to_seconds(now_millis + self._min_time_between_queries_millis), + self._process_ready_types, + ) + return + + self._next_run = self._loop.call_later(self._startup_queries_sent**2, self._process_startup_queries) + + def _process_ready_types(self) -> None: + """Generate a list of ready types that is due and schedule the next time.""" + if TYPE_CHECKING: + assert self._loop is not None + # This is a safety to ensure we stop sending queries if Zeroconf instance + # is stopped without the browser being cancelled + if self._zc.done: + return + + now_millis = current_time_millis() + # Refresh records that are about to expire (aka + # _EXPIRE_REFRESH_TIME_PERCENT which is currently 75% of the TTL) and + # additional rescue queries if the 75% query failed to refresh the record + # with a minimum time between queries of _min_time_between_queries + # which defaults to 10s + + ready_types: Set[str] = set() + next_scheduled: Optional[_ScheduledPTRQuery] = None + end_time_millis = now_millis + self._clock_resolution_millis + schedule_rescue: List[_ScheduledPTRQuery] = [] + + while self._query_heap: + query = self._query_heap[0] + if query.cancelled: + heappop(self._query_heap) + continue + if query.when_millis > end_time_millis: + next_scheduled = query + break + query = heappop(self._query_heap) + ready_types.add(query.name) + del self._next_scheduled_for_alias[query.alias] + # If there is still more than 10% of the TTL remaining + # schedule a query again to try to rescue the record + # from expiring. If the record is refreshed before + # the query, the query will get cancelled. + schedule_rescue.append(query) + + for query in schedule_rescue: + self.schedule_rescue_query(query, now_millis, RESCUE_RECORD_RETRY_TTL_PERCENTAGE) + + if ready_types: + self.async_send_ready_queries(False, now_millis, ready_types) + + next_time_millis = now_millis + self._min_time_between_queries_millis + + if next_scheduled is not None and next_scheduled.when_millis > next_time_millis: + next_when_millis = next_scheduled.when_millis + else: + next_when_millis = next_time_millis + + self._next_run = self._loop.call_at(millis_to_seconds(next_when_millis), self._process_ready_types) + + def async_send_ready_queries( + self, first_request: bool, now_millis: float_, ready_types: Set[str] + ) -> None: + """Send any ready queries.""" + # If they did not specify and this is the first request, ask QU questions + # https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 since we are + # just starting up and we know our cache is likely empty. This ensures + # the next outgoing will be sent with the known answers list. 
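+        # (QU questions ask for unicast replies per RFC 6762 section 5.4;
+        # QM is the ordinary multicast-response form.)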
+ question_type = QU_QUESTION if self._question_type is None and first_request else self._question_type + outs = generate_service_query(self._zc, now_millis, ready_types, self._multicast, question_type) + if outs: + for out in outs: + self._zc.async_send(out, self._addr, self._port) + + +class _ServiceBrowserBase(RecordUpdateListener): + """Base class for ServiceBrowser.""" + + __slots__ = ( + 'types', + 'zc', + '_cache', + '_loop', + '_pending_handlers', + '_service_state_changed', + 'query_scheduler', + 'done', + '_query_sender_task', + ) + + def __init__( + self, + zc: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[ServiceListener] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + question_type: Optional[DNSQuestionType] = None, + ) -> None: + """Used to browse for a service for specific type(s). + + Constructor parameters are as follows: + + * `zc`: A Zeroconf instance + * `type_`: fully qualified service type name + * `handler`: ServiceListener or Callable that knows how to process ServiceStateChange events + * `listener`: ServiceListener + * `addr`: address to send queries (will default to multicast) + * `port`: port to send queries (will default to mdns 5353) + * `delay`: The initial delay between answering questions + * `question_type`: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU) + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability. + """ + assert handlers or listener, 'You need to specify at least one handler' + self.types: Set[str] = set(type_ if isinstance(type_, list) else [type_]) + for check_type_ in self.types: + # Will generate BadTypeInNameException on a bad name + service_type_name(check_type_, strict=False) + self.zc = zc + self._cache = zc.cache + assert zc.loop is not None + self._loop = zc.loop + self._pending_handlers: Dict[Tuple[str, str], ServiceStateChange] = {} + self._service_state_changed = Signal() + self.query_scheduler = QueryScheduler( + zc, + self.types, + addr, + port, + addr in (None, _MDNS_ADDR, _MDNS_ADDR6), + delay, + _FIRST_QUERY_DELAY_RANDOM_INTERVAL, + question_type, + ) + self.done = False + self._query_sender_task: Optional[asyncio.Task] = None + + if hasattr(handlers, 'add_service'): + listener = cast('ServiceListener', handlers) + handlers = None + + handlers = cast(List[Callable[..., None]], handlers or []) + + if listener: + handlers.append(_service_state_changed_from_listener(listener)) + + for h in handlers: + self.service_state_changed.register_handler(h) + + def _async_start(self) -> None: + """Generate the next time and setup listeners. + + Must be called by uses of this base class after they + have finished setting their properties. 
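+
+        (ServiceBrowser, for example, schedules this via
+        zc.loop.call_soon_threadsafe from its constructor.)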
+ """ + self.zc.async_add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) + # Only start queries after the listener is installed + self._query_sender_task = asyncio.ensure_future(self._async_start_query_sender()) + + @property + def service_state_changed(self) -> SignalRegistrationInterface: + return self._service_state_changed.registration_interface + + def _names_matching_types(self, names: Iterable[str]) -> List[Tuple[str, str]]: + """Return the type and name for records matching the types we are browsing.""" + return [ + (type_, name) for name in names for type_ in self.types.intersection(cached_possible_types(name)) + ] + + def _enqueue_callback( + self, + state_change: ServiceStateChange, + type_: str_, + name: str_, + ) -> None: + # Code to ensure we only do a single update message + # Precedence is; Added, Remove, Update + key = (name, type_) + if ( + state_change is SERVICE_STATE_CHANGE_ADDED + or ( + state_change is SERVICE_STATE_CHANGE_REMOVED + and self._pending_handlers.get(key) is not SERVICE_STATE_CHANGE_ADDED + ) + or (state_change is SERVICE_STATE_CHANGE_UPDATED and key not in self._pending_handlers) + ): + self._pending_handlers[key] = state_change + + def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[RecordUpdate]) -> None: + """Callback invoked by Zeroconf when new information arrives. + + Updates information required by browser in the Zeroconf cache. + + Ensures that there is are no unnecessary duplicates in the list. + + This method will be run in the event loop. + """ + for record_update in records: + record = record_update.new + old_record = record_update.old + record_type = record.type + + if record_type is _TYPE_PTR: + if TYPE_CHECKING: + record = cast(DNSPointer, record) + pointer = record + for type_ in self.types.intersection(cached_possible_types(pointer.name)): + if old_record is None: + self._enqueue_callback(SERVICE_STATE_CHANGE_ADDED, type_, pointer.alias) + self.query_scheduler.reschedule_ptr_first_refresh(pointer) + elif pointer.is_expired(now): + self._enqueue_callback(SERVICE_STATE_CHANGE_REMOVED, type_, pointer.alias) + self.query_scheduler.cancel_ptr_refresh(pointer) + else: + self.query_scheduler.reschedule_ptr_first_refresh(pointer) + continue + + # If its expired or already exists in the cache it cannot be updated. + if old_record is not None or record.is_expired(now): + continue + + if record_type in _ADDRESS_RECORD_TYPES: + cache = self._cache + names = {service.name for service in cache.async_entries_with_server(record.name)} + # Iterate through the DNSCache and callback any services that use this address + for type_, name in self._names_matching_types(names): + self._enqueue_callback(SERVICE_STATE_CHANGE_UPDATED, type_, name) + continue + + for type_, name in self._names_matching_types((record.name,)): + self._enqueue_callback(SERVICE_STATE_CHANGE_UPDATED, type_, name) + + def async_update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + + This method will be run in the event loop. + + This method is expected to be overridden by subclasses. + """ + for pending in self._pending_handlers.items(): + self._fire_service_state_changed_event(pending) + self._pending_handlers.clear() + + def _fire_service_state_changed_event(self, event: Tuple[Tuple[str, str], ServiceStateChange]) -> None: + """Fire a service state changed event. 
+ + When running with ServiceBrowser, this will happen in the dedicated + thread. + + When running with AsyncServiceBrowser, this will happen in the event loop. + """ + name_type = event[0] + state_change = event[1] + self._service_state_changed.fire( + zeroconf=self.zc, + service_type=name_type[1], + name=name_type[0], + state_change=state_change, + ) + + def _async_cancel(self) -> None: + """Cancel the browser.""" + self.done = True + self.query_scheduler.stop() + self.zc.async_remove_listener(self) + assert self._query_sender_task is not None, "Attempted to cancel a browser that was not started" + self._query_sender_task.cancel() + self._query_sender_task = None + + async def _async_start_query_sender(self) -> None: + """Start scheduling queries.""" + if not self.zc.started: + await self.zc.async_wait_for_start() + self.query_scheduler.start(self._loop) + + +class ServiceBrowser(_ServiceBrowserBase, threading.Thread): + """Used to browse for a service of a specific type. + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability.""" + + def __init__( + self, + zc: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[ServiceListener] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + question_type: Optional[DNSQuestionType] = None, + ) -> None: + assert zc.loop is not None + if not zc.loop.is_running(): + raise RuntimeError("The event loop is not running") + threading.Thread.__init__(self) + super().__init__(zc, type_, handlers, listener, addr, port, delay, question_type) + # Add the queue before the listener is installed in _setup + # to ensure that events run in the dedicated thread and do + # not block the event loop + self.queue: queue.SimpleQueue = queue.SimpleQueue() + self.daemon = True + self.start() + zc.loop.call_soon_threadsafe(self._async_start) + self.name = "zeroconf-ServiceBrowser-{}-{}".format( + '-'.join([type_[:-7] for type_ in self.types]), + getattr(self, 'native_id', self.ident), + ) + + def cancel(self) -> None: + """Cancel the browser.""" + assert self.zc.loop is not None + self.queue.put(None) + self.zc.loop.call_soon_threadsafe(self._async_cancel) + self.join() + + def run(self) -> None: + """Run the browser thread.""" + while True: + event = self.queue.get() + if event is None: + return + self._fire_service_state_changed_event(event) + + def async_update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + + This method will be run in the event loop. + """ + for pending in self._pending_handlers.items(): + self.queue.put(pending) + self._pending_handlers.clear() + + def __enter__(self) -> 'ServiceBrowser': + return self + + def __exit__( # pylint: disable=useless-return + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self.cancel() + return None diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py new file mode 100644 index 00000000..48ad1140 --- /dev/null +++ b/zeroconf/_services/info.py @@ -0,0 +1,926 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. 
+ + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import random +import sys +from ipaddress import IPv4Address, IPv6Address, _BaseAddress +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast + +from .._cache import DNSCache +from .._dns import ( + DNSAddress, + DNSNsec, + DNSPointer, + DNSQuestion, + DNSQuestionType, + DNSRecord, + DNSService, + DNSText, +) +from .._exceptions import BadTypeInNameException +from .._history import QuestionHistory +from .._logger import log +from .._protocol.outgoing import DNSOutgoing +from .._record_update import RecordUpdate +from .._updates import RecordUpdateListener +from .._utils.asyncio import ( + _resolve_all_futures_to_none, + get_running_loop, + run_coro_with_timeout, + wait_for_future_set_or_timeout, +) +from .._utils.ipaddress import ( + cached_ip_addresses, + get_ip_address_object_from_record, + ip_bytes_and_scope_to_address, + str_without_scope_id, +) +from .._utils.name import service_type_name +from .._utils.net import IPVersion, _encode_address +from .._utils.time import current_time_millis +from ..const import ( + _ADDRESS_RECORD_TYPES, + _CLASS_IN, + _CLASS_IN_UNIQUE, + _DNS_HOST_TTL, + _DNS_OTHER_TTL, + _DUPLICATE_QUESTION_INTERVAL, + _FLAGS_QR_QUERY, + _LISTENER_TIME, + _MDNS_PORT, + _TYPE_A, + _TYPE_AAAA, + _TYPE_NSEC, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) + +IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0) + +_IPVersion_All_value = IPVersion.All.value +_IPVersion_V4Only_value = IPVersion.V4Only.value +# https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 +# The most common case for calling ServiceInfo is from a +# ServiceBrowser. After the first request we add a few random +# milliseconds to the delay between requests to reduce the chance +# that there are multiple ServiceBrowser callbacks running on +# the network that are firing at the same time when they +# see the same multicast response and decide to refresh +# the A/AAAA/SRV records for a host. +_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120) + +bytes_ = bytes +float_ = float +int_ = int +str_ = str + +QU_QUESTION = DNSQuestionType.QU +QM_QUESTION = DNSQuestionType.QM + +randint = random.randint + +if TYPE_CHECKING: + from .._core import Zeroconf + + +def instance_name_from_service_info(info: "ServiceInfo", strict: bool = True) -> str: + """Calculate the instance name from the ServiceInfo.""" + # This is kind of funky because of the subtype based tests + # need to make subtypes a first class citizen + service_name = service_type_name(info.name, strict=strict) + if not info.type.endswith(service_name): + raise BadTypeInNameException + return info.name[: -len(service_name) - 1] + + +class ServiceInfo(RecordUpdateListener): + """Service information. 
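+
+    Instances are either constructed locally (for example, to be registered
+    with a Zeroconf instance) or populated from the network via `request()`
+    / `async_request()` below.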
+ + Constructor parameters are as follows: + + * `type_`: fully qualified service type name + * `name`: fully qualified service name + * `port`: port that the service runs on + * `weight`: weight of the service + * `priority`: priority of the service + * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field). + converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to + value-less attributes. + * `server`: fully qualified name for service host (defaults to name) + * `host_ttl`: ttl used for A/SRV records + * `other_ttl`: ttl used for PTR/TXT records + * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, + or in parsed form as text; at most one of those parameters can be provided) + * interface_index: scope_id or zone_id for IPv6 link-local addresses i.e. an identifier of the interface + where the peer is connected to + """ + + __slots__ = ( + "text", + "type", + "_name", + "key", + "_ipv4_addresses", + "_ipv6_addresses", + "port", + "weight", + "priority", + "server", + "server_key", + "_properties", + "_decoded_properties", + "host_ttl", + "other_ttl", + "interface_index", + "_new_records_futures", + "_dns_pointer_cache", + "_dns_service_cache", + "_dns_text_cache", + "_dns_address_cache", + "_get_address_and_nsec_records_cache", + ) + + def __init__( + self, + type_: str, + name: str, + port: Optional[int] = None, + weight: int = 0, + priority: int = 0, + properties: Union[bytes, Dict] = b'', + server: Optional[str] = None, + host_ttl: int = _DNS_HOST_TTL, + other_ttl: int = _DNS_OTHER_TTL, + *, + addresses: Optional[List[bytes]] = None, + parsed_addresses: Optional[List[str]] = None, + interface_index: Optional[int] = None, + ) -> None: + # Accept both none, or one, but not both. 
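+        # e.g. (illustrative) addresses=[socket.inet_aton("192.168.1.5")]
+        # or parsed_addresses=["192.168.1.5"], but never both at once.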
+ if addresses is not None and parsed_addresses is not None: + raise TypeError("addresses and parsed_addresses cannot be provided together") + if not type_.endswith(service_type_name(name, strict=False)): + raise BadTypeInNameException + self.interface_index = interface_index + self.text = b'' + self.type = type_ + self._name = name + self.key = name.lower() + self._ipv4_addresses: List[IPv4Address] = [] + self._ipv6_addresses: List[IPv6Address] = [] + if addresses is not None: + self.addresses = addresses + elif parsed_addresses is not None: + self.addresses = [_encode_address(a) for a in parsed_addresses] + self.port = port + self.weight = weight + self.priority = priority + self.server = server if server else None + self.server_key = server.lower() if server else None + self._properties: Optional[Dict[bytes, Optional[bytes]]] = None + self._decoded_properties: Optional[Dict[str, Optional[str]]] = None + if isinstance(properties, bytes): + self._set_text(properties) + else: + self._set_properties(properties) + self.host_ttl = host_ttl + self.other_ttl = other_ttl + self._new_records_futures: Optional[Set[asyncio.Future]] = None + self._dns_address_cache: Optional[List[DNSAddress]] = None + self._dns_pointer_cache: Optional[DNSPointer] = None + self._dns_service_cache: Optional[DNSService] = None + self._dns_text_cache: Optional[DNSText] = None + self._get_address_and_nsec_records_cache: Optional[Set[DNSRecord]] = None + + @property + def name(self) -> str: + """The name of the service.""" + return self._name + + @name.setter + def name(self, name: str) -> None: + """Replace the the name and reset the key.""" + self._name = name + self.key = name.lower() + self._dns_service_cache = None + self._dns_pointer_cache = None + self._dns_text_cache = None + + @property + def addresses(self) -> List[bytes]: + """IPv4 addresses of this service. + + Only IPv4 addresses are returned for backward compatibility. + Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to + include IPv6 addresses as well. + """ + return self.addresses_by_version(IPVersion.V4Only) + + @addresses.setter + def addresses(self, value: List[bytes]) -> None: + """Replace the addresses list. + + This replaces all currently stored addresses, both IPv4 and IPv6. + """ + self._ipv4_addresses.clear() + self._ipv6_addresses.clear() + self._dns_address_cache = None + self._get_address_and_nsec_records_cache = None + + for address in value: + if IPADDRESS_SUPPORTS_SCOPE_ID and len(address) == 16 and self.interface_index is not None: + addr = ip_bytes_and_scope_to_address(address, self.interface_index) + else: + addr = cached_ip_addresses(address) + if addr is None: + raise TypeError( + "Addresses must either be IPv4 or IPv6 strings, bytes, or integers;" + f" got {address!r}. 
Hint: convert string addresses with socket.inet_pton" + ) + if addr.version == 4: + if TYPE_CHECKING: + assert isinstance(addr, IPv4Address) + self._ipv4_addresses.append(addr) + else: + if TYPE_CHECKING: + assert isinstance(addr, IPv6Address) + self._ipv6_addresses.append(addr) + + @property + def properties(self) -> Dict[bytes, Optional[bytes]]: + """Return properties as bytes.""" + if self._properties is None: + self._unpack_text_into_properties() + if TYPE_CHECKING: + assert self._properties is not None + return self._properties + + @property + def decoded_properties(self) -> Dict[str, Optional[str]]: + """Return properties as strings.""" + if self._decoded_properties is None: + self._generate_decoded_properties() + if TYPE_CHECKING: + assert self._decoded_properties is not None + return self._decoded_properties + + def async_clear_cache(self) -> None: + """Clear the cache for this service info.""" + self._dns_address_cache = None + self._dns_pointer_cache = None + self._dns_service_cache = None + self._dns_text_cache = None + self._get_address_and_nsec_records_cache = None + + async def async_wait(self, timeout: float, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + """Calling task waits for a given number of milliseconds or until notified.""" + if not self._new_records_futures: + self._new_records_futures = set() + await wait_for_future_set_or_timeout( + loop or asyncio.get_running_loop(), self._new_records_futures, timeout + ) + + def addresses_by_version(self, version: IPVersion) -> List[bytes]: + """List addresses matching IP version. + + Addresses are guaranteed to be returned in LIFO (last in, first out) + order with IPv4 addresses first and IPv6 addresses second. + + This means the first address will always be the most recently added + address of the given IP version. + """ + version_value = version.value + if version_value == _IPVersion_All_value: + ip_v4_packed = [addr.packed for addr in self._ipv4_addresses] + ip_v6_packed = [addr.packed for addr in self._ipv6_addresses] + return [*ip_v4_packed, *ip_v6_packed] + if version_value == _IPVersion_V4Only_value: + return [addr.packed for addr in self._ipv4_addresses] + return [addr.packed for addr in self._ipv6_addresses] + + def ip_addresses_by_version( + self, version: IPVersion + ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]: + """List ip_address objects matching IP version. + + Addresses are guaranteed to be returned in LIFO (last in, first out) + order with IPv4 addresses first and IPv6 addresses second. + + This means the first address will always be the most recently added + address of the given IP version. + """ + return self._ip_addresses_by_version_value(version.value) + + def _ip_addresses_by_version_value( + self, version_value: int_ + ) -> Union[List[IPv4Address], List[IPv6Address]]: + """Backend for addresses_by_version that uses the raw value.""" + if version_value == _IPVersion_All_value: + return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value] + if version_value == _IPVersion_V4Only_value: + return self._ipv4_addresses + return self._ipv6_addresses + + def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + """List addresses in their parsed string form. + + Addresses are guaranteed to be returned in LIFO (last in, first out) + order with IPv4 addresses first and IPv6 addresses second. + + This means the first address will always be the most recently added + address of the given IP version. 
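+
+        e.g. (illustrative) ["192.168.1.5", "fe80::1"] for a service with
+        one IPv4 and one link-local IPv6 address.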
+ """ + return [str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)] + + def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + """Equivalent to parsed_addresses, with the exception that IPv6 Link-Local + addresses are qualified with % when available + + Addresses are guaranteed to be returned in LIFO (last in, first out) + order with IPv4 addresses first and IPv6 addresses second. + + This means the first address will always be the most recently added + address of the given IP version. + """ + return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)] + + def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None: + """Sets properties and text of this info from a dictionary""" + list_: List[bytes] = [] + properties_contain_str = False + result = b'' + for key, value in properties.items(): + if isinstance(key, str): + key = key.encode('utf-8') + properties_contain_str = True + + record = key + if value is not None: + if not isinstance(value, bytes): + value = str(value).encode('utf-8') + properties_contain_str = True + record += b'=' + value + list_.append(record) + for item in list_: + result = b''.join((result, bytes((len(item),)), item)) + if not properties_contain_str: + # If there are no str keys or values, we can use the properties + # as-is, without decoding them, otherwise calling + # self.properties will lazy decode them, which is expensive. + if TYPE_CHECKING: + self._properties = cast("Dict[bytes, Optional[bytes]]", properties) + else: + self._properties = properties + self.text = result + + def _set_text(self, text: bytes) -> None: + """Sets properties and text given a text field""" + if text == self.text: + return + self.text = text + # Clear the properties cache + self._properties = None + self._decoded_properties = None + + def _generate_decoded_properties(self) -> None: + """Generates decoded properties from the properties""" + self._decoded_properties = { + k.decode("ascii", "replace"): None if v is None else v.decode("utf-8", "replace") + for k, v in self.properties.items() + } + + def _unpack_text_into_properties(self) -> None: + """Unpacks the text field into properties""" + text = self.text + end = len(text) + if end == 0: + # Properties should be set atomically + # in case another thread is reading them + self._properties = {} + return + + index = 0 + properties: Dict[bytes, Optional[bytes]] = {} + while index < end: + length = text[index] + index += 1 + key_value = text[index : index + length] + key_sep_value = key_value.partition(b'=') + key = key_sep_value[0] + if key not in properties: + properties[key] = key_sep_value[2] or None + index += length + + self._properties = properties + + def get_name(self) -> str: + """Name accessor""" + return self._name[: len(self._name) - len(self.type) - 1] + + def _get_ip_addresses_from_cache_lifo( + self, zc: 'Zeroconf', now: float_, type: int_ + ) -> List[Union[IPv4Address, IPv6Address]]: + """Set IPv6 addresses from the cache.""" + address_list: List[Union[IPv4Address, IPv6Address]] = [] + for record in self._get_address_records_from_cache_by_type(zc, type): + if record.is_expired(now): + continue + ip_addr = get_ip_address_object_from_record(record) + if ip_addr is not None and ip_addr not in address_list: + address_list.append(ip_addr) + address_list.reverse() # Reverse to get LIFO order + return address_list + + def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None: + 
"""Set IPv6 addresses from the cache.""" + if TYPE_CHECKING: + self._ipv6_addresses = cast( + "List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) + ) + else: + self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) + + def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None: + """Set IPv4 addresses from the cache.""" + if TYPE_CHECKING: + self._ipv4_addresses = cast( + "List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) + ) + else: + self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) + + def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[RecordUpdate]) -> None: + """Updates service information from a DNS record. + + This method will be run in the event loop. + """ + new_records_futures = self._new_records_futures + updated: bool = False + for record_update in records: + updated |= self._process_record_threadsafe(zc, record_update.new, now) + if updated and new_records_futures: + _resolve_all_futures_to_none(new_records_futures) + + def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float_) -> bool: + """Thread safe record updating. + + Returns True if a new record was added. + """ + if record.is_expired(now): + return False + + record_key = record.key + record_type = type(record) + if record_type is DNSAddress and record_key == self.server_key: + dns_address_record = record + if TYPE_CHECKING: + assert isinstance(dns_address_record, DNSAddress) + ip_addr = get_ip_address_object_from_record(dns_address_record) + if ip_addr is None: + log.warning( + "Encountered invalid address while processing %s: %s", + dns_address_record, + dns_address_record.address, + ) + return False + + if ip_addr.version == 4: + if TYPE_CHECKING: + assert isinstance(ip_addr, IPv4Address) + ipv4_addresses = self._ipv4_addresses + if ip_addr not in ipv4_addresses: + ipv4_addresses.insert(0, ip_addr) + return True + elif ip_addr != ipv4_addresses[0]: + ipv4_addresses.remove(ip_addr) + ipv4_addresses.insert(0, ip_addr) + + return False + + if TYPE_CHECKING: + assert isinstance(ip_addr, IPv6Address) + ipv6_addresses = self._ipv6_addresses + if ip_addr not in self._ipv6_addresses: + ipv6_addresses.insert(0, ip_addr) + return True + elif ip_addr != self._ipv6_addresses[0]: + ipv6_addresses.remove(ip_addr) + ipv6_addresses.insert(0, ip_addr) + + return False + + if record_key != self.key: + return False + + if record_type is DNSText: + dns_text_record = record + if TYPE_CHECKING: + assert isinstance(dns_text_record, DNSText) + self._set_text(dns_text_record.text) + return True + + if record_type is DNSService: + dns_service_record = record + if TYPE_CHECKING: + assert isinstance(dns_service_record, DNSService) + old_server_key = self.server_key + self._name = dns_service_record.name + self.key = dns_service_record.key + self.server = dns_service_record.server + self.server_key = dns_service_record.server_key + self.port = dns_service_record.port + self.weight = dns_service_record.weight + self.priority = dns_service_record.priority + if old_server_key != self.server_key: + self._set_ipv4_addresses_from_cache(zc, now) + self._set_ipv6_addresses_from_cache(zc, now) + return True + + return False + + def dns_addresses( + self, + override_ttl: Optional[int] = None, + version: IPVersion = IPVersion.All, + ) -> List[DNSAddress]: + """Return matching DNSAddress from ServiceInfo.""" + return self._dns_addresses(override_ttl, version) + + def 
_dns_addresses( + self, + override_ttl: Optional[int], + version: IPVersion, + ) -> List[DNSAddress]: + """Return matching DNSAddress from ServiceInfo.""" + cacheable = version is IPVersion.All and override_ttl is None + if self._dns_address_cache is not None and cacheable: + return self._dns_address_cache + name = self.server or self._name + ttl = override_ttl if override_ttl is not None else self.host_ttl + class_ = _CLASS_IN_UNIQUE + version_value = version.value + records = [ + DNSAddress( + name, + _TYPE_AAAA if ip_addr.version == 6 else _TYPE_A, + class_, + ttl, + ip_addr.packed, + created=0.0, + ) + for ip_addr in self._ip_addresses_by_version_value(version_value) + ] + if cacheable: + self._dns_address_cache = records + return records + + def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: + """Return DNSPointer from ServiceInfo.""" + return self._dns_pointer(override_ttl) + + def _dns_pointer(self, override_ttl: Optional[int]) -> DNSPointer: + """Return DNSPointer from ServiceInfo.""" + cacheable = override_ttl is None + if self._dns_pointer_cache is not None and cacheable: + return self._dns_pointer_cache + record = DNSPointer( + self.type, + _TYPE_PTR, + _CLASS_IN, + override_ttl if override_ttl is not None else self.other_ttl, + self._name, + 0.0, + ) + if cacheable: + self._dns_pointer_cache = record + return record + + def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: + """Return DNSService from ServiceInfo.""" + return self._dns_service(override_ttl) + + def _dns_service(self, override_ttl: Optional[int]) -> DNSService: + """Return DNSService from ServiceInfo.""" + cacheable = override_ttl is None + if self._dns_service_cache is not None and cacheable: + return self._dns_service_cache + port = self.port + if TYPE_CHECKING: + assert isinstance(port, int) + record = DNSService( + self._name, + _TYPE_SRV, + _CLASS_IN_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + self.priority, + self.weight, + port, + self.server or self._name, + 0.0, + ) + if cacheable: + self._dns_service_cache = record + return record + + def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: + """Return DNSText from ServiceInfo.""" + return self._dns_text(override_ttl) + + def _dns_text(self, override_ttl: Optional[int]) -> DNSText: + """Return DNSText from ServiceInfo.""" + cacheable = override_ttl is None + if self._dns_text_cache is not None and cacheable: + return self._dns_text_cache + record = DNSText( + self._name, + _TYPE_TXT, + _CLASS_IN_UNIQUE, + override_ttl if override_ttl is not None else self.other_ttl, + self.text, + 0.0, + ) + if cacheable: + self._dns_text_cache = record + return record + + def dns_nsec(self, missing_types: List[int], override_ttl: Optional[int] = None) -> DNSNsec: + """Return DNSNsec from ServiceInfo.""" + return self._dns_nsec(missing_types, override_ttl) + + def _dns_nsec(self, missing_types: List[int], override_ttl: Optional[int]) -> DNSNsec: + """Return DNSNsec from ServiceInfo.""" + return DNSNsec( + self._name, + _TYPE_NSEC, + _CLASS_IN_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + self._name, + missing_types, + 0.0, + ) + + def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Set[DNSRecord]: + """Build a set of address records and NSEC records for non-present record types.""" + return self._get_address_and_nsec_records(override_ttl) + + def _get_address_and_nsec_records(self, override_ttl: Optional[int]) -> Set[DNSRecord]: + """Build 
a set of address records and NSEC records for non-present record types.""" + cacheable = override_ttl is None + if self._get_address_and_nsec_records_cache is not None and cacheable: + return self._get_address_and_nsec_records_cache + missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy() + records: Set[DNSRecord] = set() + for dns_address in self._dns_addresses(override_ttl, IPVersion.All): + missing_types.discard(dns_address.type) + records.add(dns_address) + if missing_types: + assert self.server is not None, "Service server must be set for NSEC record." + records.add(self._dns_nsec(list(missing_types), override_ttl)) + if cacheable: + self._get_address_and_nsec_records_cache = records + return records + + def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int_) -> List[DNSAddress]: + """Get the addresses from the cache.""" + if self.server_key is None: + return [] + cache = zc.cache + if TYPE_CHECKING: + records = cast("List[DNSAddress]", cache.get_all_by_details(self.server_key, _type, _CLASS_IN)) + else: + records = cache.get_all_by_details(self.server_key, _type, _CLASS_IN) + return records + + def set_server_if_missing(self) -> None: + """Set the server if it is missing. + + This function is for backwards compatibility. + """ + if self.server is None: + self.server = self._name + self.server_key = self.key + + def load_from_cache(self, zc: 'Zeroconf', now: Optional[float_] = None) -> bool: + """Populate the service info from the cache. + + This method is designed to be threadsafe. + """ + return self._load_from_cache(zc, now or current_time_millis()) + + def _load_from_cache(self, zc: 'Zeroconf', now: float_) -> bool: + """Populate the service info from the cache. + + This method is designed to be threadsafe. + """ + cache = zc.cache + original_server_key = self.server_key + cached_srv_record = cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN) + if cached_srv_record: + self._process_record_threadsafe(zc, cached_srv_record, now) + cached_txt_record = cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN) + if cached_txt_record: + self._process_record_threadsafe(zc, cached_txt_record, now) + if original_server_key == self.server_key: + # If there is a srv which changes the server_key, + # A and AAAA will already be loaded from the cache + # and we do not want to do it twice + for record in self._get_address_records_from_cache_by_type(zc, _TYPE_A): + self._process_record_threadsafe(zc, record, now) + for record in self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA): + self._process_record_threadsafe(zc, record, now) + return self._is_complete + + @property + def _is_complete(self) -> bool: + """The ServiceInfo has all expected properties.""" + return bool(self.text is not None and (self._ipv4_addresses or self._ipv6_addresses)) + + def request( + self, + zc: 'Zeroconf', + timeout: float, + question_type: Optional[DNSQuestionType] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + ) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + + While it is not expected during normal operation, + this function may raise EventLoopBlocked if the underlying + call to `async_request` cannot be completed. 
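+
+        (EventLoopBlocked comes from the run_coro_with_timeout helper when
+        the event loop cannot service the call within the timeout.)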
+ """ + assert zc.loop is not None and zc.loop.is_running() + if zc.loop == get_running_loop(): + raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop") + return bool( + run_coro_with_timeout( + self.async_request(zc, timeout, question_type, addr, port), zc.loop, timeout + ) + ) + + def _get_initial_delay(self) -> float_: + return _LISTENER_TIME + + def _get_random_delay(self) -> int_: + return randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL) + + async def async_request( + self, + zc: 'Zeroconf', + timeout: float, + question_type: Optional[DNSQuestionType] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + ) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + + This method will be run in the event loop. + + Passing addr and port is optional, and will default to the + mDNS multicast address and port. This is useful for directing + requests to a specific host that may be able to respond across + subnets. + """ + if not zc.started: + await zc.async_wait_for_start() + + now = current_time_millis() + + if self._load_from_cache(zc, now): + return True + + if TYPE_CHECKING: + assert zc.loop is not None + + first_request = True + delay = self._get_initial_delay() + next_ = now + last = now + timeout + try: + zc.async_add_listener(self, None) + while not self._is_complete: + if last <= now: + return False + if next_ <= now: + this_question_type = question_type or QU_QUESTION if first_request else QM_QUESTION + out = self._generate_request_query(zc, now, this_question_type) + first_request = False + if out.questions: + # All questions may have been suppressed + # by the question history, so nothing to send, + # but keep waiting for answers in case another + # client on the network is asking the same + # question or they have not arrived yet. + zc.async_send(out, addr, port) + next_ = now + delay + next_ += self._get_random_delay() + if this_question_type is QM_QUESTION and delay < _DUPLICATE_QUESTION_INTERVAL: + # If we just asked a QM question, we need to + # wait at least the duplicate question interval + # before asking another QM question otherwise + # its likely to be suppressed by the question + # history of the remote responder. 
+ delay = _DUPLICATE_QUESTION_INTERVAL + + await self.async_wait(min(next_, last) - now, zc.loop) + now = current_time_millis() + finally: + zc.async_remove_listener(self) + + return True + + def _add_question_with_known_answers( + self, + out: DNSOutgoing, + qu_question: bool, + question_history: QuestionHistory, + cache: DNSCache, + now: float_, + name: str_, + type_: int_, + class_: int_, + skip_if_known_answers: bool, + ) -> None: + """Add a question with known answers if its not suppressed.""" + known_answers = { + answer for answer in cache.get_all_by_details(name, type_, class_) if not answer.is_stale(now) + } + if skip_if_known_answers and known_answers: + return + question = DNSQuestion(name, type_, class_) + if qu_question: + question.unicast = True + elif question_history.suppresses(question, now, known_answers): + return + else: + question_history.add_question_at_time(question, now, known_answers) + out.add_question(question) + for answer in known_answers: + out.add_answer_at_time(answer, now) + + def _generate_request_query( + self, zc: 'Zeroconf', now: float_, question_type: DNSQuestionType + ) -> DNSOutgoing: + """Generate the request query.""" + out = DNSOutgoing(_FLAGS_QR_QUERY) + name = self._name + server = self.server or name + cache = zc.cache + history = zc.question_history + qu_question = question_type is QU_QUESTION + self._add_question_with_known_answers( + out, qu_question, history, cache, now, name, _TYPE_SRV, _CLASS_IN, True + ) + self._add_question_with_known_answers( + out, qu_question, history, cache, now, name, _TYPE_TXT, _CLASS_IN, True + ) + self._add_question_with_known_answers( + out, qu_question, history, cache, now, server, _TYPE_A, _CLASS_IN, False + ) + self._add_question_with_known_answers( + out, qu_question, history, cache, now, server, _TYPE_AAAA, _CLASS_IN, False + ) + return out + + def __repr__(self) -> str: + """String representation""" + return '{}({})'.format( + type(self).__name__, + ', '.join( + f'{name}={getattr(self, name)!r}' + for name in ( + 'type', + 'name', + 'addresses', + 'port', + 'weight', + 'priority', + 'server', + 'properties', + 'interface_index', + ) + ), + ) diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py new file mode 100644 index 00000000..261e8e9c --- /dev/null +++ b/zeroconf/_services/registry.py @@ -0,0 +1,112 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import Dict, List, Optional, Union + +from .._exceptions import ServiceNameAlreadyRegistered +from .info import ServiceInfo + +_str = str + + +class ServiceRegistry: + """A registry to keep track of services. 
+
+    The registry must only be accessed from
+    the event loop as it is not thread safe.
+    """
+
+    __slots__ = ("_services", "types", "servers", "has_entries")
+
+    def __init__(
+        self,
+    ) -> None:
+        """Create the ServiceRegistry class."""
+        self._services: Dict[str, ServiceInfo] = {}
+        self.types: Dict[str, List] = {}
+        self.servers: Dict[str, List] = {}
+        self.has_entries: bool = False
+
+    def async_add(self, info: ServiceInfo) -> None:
+        """Add a new service to the registry."""
+        self._add(info)
+
+    def async_remove(self, info: Union[List[ServiceInfo], ServiceInfo]) -> None:
+        """Remove a service from the registry."""
+        self._remove(info if isinstance(info, list) else [info])
+
+    def async_update(self, info: ServiceInfo) -> None:
+        """Update an existing service in the registry."""
+        self._remove([info])
+        self._add(info)
+
+    def async_get_service_infos(self) -> List[ServiceInfo]:
+        """Return all ServiceInfo."""
+        return list(self._services.values())
+
+    def async_get_info_name(self, name: str) -> Optional[ServiceInfo]:
+        """Return the ServiceInfo for the name, if registered."""
+        return self._services.get(name)
+
+    def async_get_types(self) -> List[str]:
+        """Return all types."""
+        return list(self.types)
+
+    def async_get_infos_type(self, type_: str) -> List[ServiceInfo]:
+        """Return all ServiceInfo matching type."""
+        return self._async_get_by_index(self.types, type_)
+
+    def async_get_infos_server(self, server: str) -> List[ServiceInfo]:
+        """Return all ServiceInfo matching server."""
+        return self._async_get_by_index(self.servers, server)
+
+    def _async_get_by_index(self, records: Dict[str, List], key: _str) -> List[ServiceInfo]:
+        """Return all ServiceInfo matching the index."""
+        record_list = records.get(key)
+        if record_list is None:
+            return []
+        return [self._services[name] for name in record_list]
+
+    def _add(self, info: ServiceInfo) -> None:
+        """Add a new service under the lock."""
+        assert info.server_key is not None, "ServiceInfo must have a server"
+        if info.key in self._services:
+            raise ServiceNameAlreadyRegistered
+
+        info.async_clear_cache()
+        self._services[info.key] = info
+        self.types.setdefault(info.type.lower(), []).append(info.key)
+        self.servers.setdefault(info.server_key, []).append(info.key)
+        self.has_entries = True
+
+    def _remove(self, infos: List[ServiceInfo]) -> None:
+        """Remove services under the lock."""
+        for info in infos:
+            old_service_info = self._services.get(info.key)
+            if old_service_info is None:
+                continue
+            assert old_service_info.server_key is not None
+            self.types[old_service_info.type.lower()].remove(info.key)
+            self.servers[old_service_info.server_key].remove(info.key)
+            del self._services[info.key]
+
+        self.has_entries = bool(self._services)
diff --git a/zeroconf/_services/types.py b/zeroconf/_services/types.py
new file mode 100644
index 00000000..70db2d60
--- /dev/null
+++ b/zeroconf/_services/types.py
@@ -0,0 +1,83 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+    Lesser General Public License for more details.
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+    USA
+"""
+
+import time
+from typing import Optional, Set, Tuple, Union
+
+from .._core import Zeroconf
+from .._services import ServiceListener
+from .._utils.net import InterfaceChoice, InterfacesType, IPVersion
+from ..const import _SERVICE_TYPE_ENUMERATION_NAME
+from .browser import ServiceBrowser
+
+
+class ZeroconfServiceTypes(ServiceListener):
+    """
+    Return all of the advertised services on any local networks
+    """
+
+    def __init__(self) -> None:
+        """Keep track of found services in a set."""
+        self.found_services: Set[str] = set()
+
+    def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
+        """Service added."""
+        self.found_services.add(name)
+
+    def update_service(self, zc: Zeroconf, type_: str, name: str) -> None:
+        """Service updated."""
+
+    def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None:
+        """Service removed."""
+
+    @classmethod
+    def find(
+        cls,
+        zc: Optional[Zeroconf] = None,
+        timeout: Union[int, float] = 5,
+        interfaces: InterfacesType = InterfaceChoice.All,
+        ip_version: Optional[IPVersion] = None,
+    ) -> Tuple[str, ...]:
+        """
+        Return all of the advertised services on any local networks.
+
+        :param zc: Zeroconf() instance. Pass in if you already have an
+            instance running or if non-default interfaces are needed
+        :param timeout: seconds to wait for any responses
+        :param interfaces: interfaces to listen on.
+        :param ip_version: IP protocol version to use.
+        :return: tuple of service type strings
+        """
+        local_zc = zc or Zeroconf(interfaces=interfaces, ip_version=ip_version)
+        listener = cls()
+        browser = ServiceBrowser(local_zc, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener)
+
+        # wait for responses
+        time.sleep(timeout)
+
+        browser.cancel()
+
+        # close down anything we opened
+        if zc is None:
+            local_zc.close()
+
+        return tuple(sorted(listener.found_services))
diff --git a/zeroconf/_transport.py b/zeroconf/_transport.py
new file mode 100644
index 00000000..c37af2ef
--- /dev/null
+++ b/zeroconf/_transport.py
@@ -0,0 +1,67 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+    Lesser General Public License for more details.
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+    USA
+"""
+
+import asyncio
+import socket
+from typing import Tuple
+
+
+class _WrappedTransport:
+    """A wrapper for transports."""
+
+    __slots__ = (
+        'transport',
+        'is_ipv6',
+        'sock',
+        'fileno',
+        'sock_name',
+    )
+
+    def __init__(
+        self,
+        transport: asyncio.DatagramTransport,
+        is_ipv6: bool,
+        sock: socket.socket,
+        fileno: int,
+        sock_name: Tuple,
+    ) -> None:
+        """Initialize the wrapped transport.
+
+        These attributes are used when sending packets.
+        """
+        self.transport = transport
+        self.is_ipv6 = is_ipv6
+        self.sock = sock
+        self.fileno = fileno
+        self.sock_name = sock_name
+
+
+def make_wrapped_transport(transport: asyncio.DatagramTransport) -> _WrappedTransport:
+    """Make a wrapped transport."""
+    sock: socket.socket = transport.get_extra_info('socket')
+    return _WrappedTransport(
+        transport=transport,
+        is_ipv6=sock.family == socket.AF_INET6,
+        sock=sock,
+        fileno=sock.fileno(),
+        sock_name=sock.getsockname(),
+    )
diff --git a/zeroconf/_updates.py b/zeroconf/_updates.py
new file mode 100644
index 00000000..42fa8285
--- /dev/null
+++ b/zeroconf/_updates.py
@@ -0,0 +1,79 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+    Lesser General Public License for more details.
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+    USA
+"""
+
+from typing import TYPE_CHECKING, List
+
+from ._dns import DNSRecord
+from ._record_update import RecordUpdate
+
+if TYPE_CHECKING:
+    from ._core import Zeroconf
+
+
+float_ = float
+
+
+class RecordUpdateListener:
+    """Base class for all record listeners.
+
+    All listeners passed to async_add_listener should use RecordUpdateListener
+    as a base class. In the future it will be required.
+    """
+
+    def update_record(  # pylint: disable=no-self-use
+        self, zc: 'Zeroconf', now: float, record: DNSRecord
+    ) -> None:
+        """Update a single record.
+
+        This method is deprecated and will be removed in a future version.
+        update_records should be implemented instead.
+        """
+        raise RuntimeError("update_record is deprecated and will be removed in a future version.")
+
+    def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[RecordUpdate]) -> None:
+        """Update multiple records in one shot.
+
+        All records that are received in a single packet are passed
+        to async_update_records.
+
+        This implementation is a compatibility shim to ensure older code
+        that uses RecordUpdateListener as a base class will continue to
+        get calls to update_record. This method will raise
+        NotImplementedError in a future version.
+ + At this point the cache will not have the new records + + Records are passed as a list of RecordUpdate. This + allows consumers of async_update_records to avoid cache lookups. + + This method will be run in the event loop. + """ + for record in records: + self.update_record(zc, now, record.new) + + def async_update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + + This method will be run in the event loop. + """ diff --git a/zeroconf/_utils/__init__.py b/zeroconf/_utils/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/zeroconf/_utils/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/zeroconf/_utils/asyncio.py b/zeroconf/_utils/asyncio.py new file mode 100644 index 00000000..c57b4d36 --- /dev/null +++ b/zeroconf/_utils/asyncio.py @@ -0,0 +1,137 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+    USA
+"""
+
+import asyncio
+import concurrent.futures
+import contextlib
+import sys
+from typing import Any, Awaitable, Coroutine, Optional, Set
+
+if sys.version_info[:2] < (3, 11):
+    from dmx.async_timeout import timeout as asyncio_timeout
+else:
+    from asyncio import timeout as asyncio_timeout
+
+from .._exceptions import EventLoopBlocked
+from ..const import _LOADED_SYSTEM_TIMEOUT
+from .time import millis_to_seconds
+
+# The combined timeouts should be lower than _CLOSE_TIMEOUT + _WAIT_FOR_LOOP_TASKS_TIMEOUT
+_TASK_AWAIT_TIMEOUT = 1
+_GET_ALL_TASKS_TIMEOUT = 3
+_WAIT_FOR_LOOP_TASKS_TIMEOUT = 3  # Must be larger than _TASK_AWAIT_TIMEOUT
+
+
+def _set_future_none_if_not_done(fut: asyncio.Future) -> None:
+    """Set a future to None if it is not done."""
+    if not fut.done():  # pragma: no branch
+        fut.set_result(None)
+
+
+def _resolve_all_futures_to_none(futures: Set[asyncio.Future]) -> None:
+    """Resolve all futures to None."""
+    for fut in futures:
+        _set_future_none_if_not_done(fut)
+    futures.clear()
+
+
+async def wait_for_future_set_or_timeout(
+    loop: asyncio.AbstractEventLoop, future_set: Set[asyncio.Future], timeout: float
+) -> None:
+    """Wait for a future or timeout (in milliseconds)."""
+    future = loop.create_future()
+    future_set.add(future)
+    handle = loop.call_later(millis_to_seconds(timeout), _set_future_none_if_not_done, future)
+    try:
+        await future
+    finally:
+        handle.cancel()
+        future_set.discard(future)
+
+
+async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None:
+    """Wait for an event or timeout."""
+    with contextlib.suppress(asyncio.TimeoutError):
+        async with asyncio_timeout(timeout):
+            await event.wait()
+
+
+async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.Task]:
+    """Return all tasks running."""
+    await asyncio.sleep(0)  # flush out any call_soon_threadsafe
+    # If there are multiple event loops running, all_tasks is not
+    # safe EVEN WHEN CALLED FROM THE EVENTLOOP
+    # under PyPy so we have to try a few times.
+    for _ in range(3):
+        with contextlib.suppress(RuntimeError):
+            return asyncio.all_tasks(loop)
+    return set()
+
+
+async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None:
+    """Wait for the event loop thread we started to shut down."""
+    await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT)
+
+
+async def await_awaitable(aw: Awaitable) -> None:
+    """Wait on an awaitable and the task it returns."""
+    task = await aw
+    await task
+
+
+def run_coro_with_timeout(aw: Coroutine, loop: asyncio.AbstractEventLoop, timeout: float) -> Any:
+    """Run a coroutine with a timeout.
+
+    The timeout should only be used as a safeguard to prevent
+    the program from blocking forever. The timeout should
+    never be expected to be reached during normal operation.
+
+    While not expected during normal operations, the
+    function raises `EventLoopBlocked` if the coroutine takes
+    longer to complete than the timeout.
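+
+    A minimal usage sketch (``loop`` is assumed to be an event loop
+    running in another thread, and ``fetch()`` a stand-in coroutine
+    defined elsewhere; the timeout is in milliseconds)::
+
+        result = run_coro_with_timeout(fetch(), loop, 1000)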
+ """ + try: + return asyncio.run_coroutine_threadsafe(aw, loop).result( + millis_to_seconds(timeout) + _LOADED_SYSTEM_TIMEOUT + ) + except concurrent.futures.TimeoutError as ex: + raise EventLoopBlocked from ex + + +def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: + """Wait for pending tasks and stop an event loop.""" + pending_tasks = set( + asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result(_GET_ALL_TASKS_TIMEOUT) + ) + pending_tasks -= {task for task in pending_tasks if task.done()} + if pending_tasks: + asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result( + _WAIT_FOR_LOOP_TASKS_TIMEOUT + ) + loop.call_soon_threadsafe(loop.stop) + + +def get_running_loop() -> Optional[asyncio.AbstractEventLoop]: + """Check if an event loop is already running.""" + with contextlib.suppress(RuntimeError): + return asyncio.get_running_loop() + return None diff --git a/zeroconf/_utils/ipaddress.py b/zeroconf/_utils/ipaddress.py new file mode 100644 index 00000000..b0b551ff --- /dev/null +++ b/zeroconf/_utils/ipaddress.py @@ -0,0 +1,134 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+ + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" +import sys +from functools import lru_cache +from ipaddress import AddressValueError, IPv4Address, IPv6Address, NetmaskValueError +from typing import Any, Optional, Union + +from .._dns import DNSAddress +from ..const import _TYPE_AAAA + +bytes_ = bytes +int_ = int +IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0) + + +class ZeroconfIPv4Address(IPv4Address): + + __slots__ = ("_str", "_is_link_local", "_is_unspecified") + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize a new IPv4 address.""" + super().__init__(*args, **kwargs) + self._str = super().__str__() + self._is_link_local = super().is_link_local + self._is_unspecified = super().is_unspecified + + def __str__(self) -> str: + """Return the string representation of the IPv4 address.""" + return self._str + + @property + def is_link_local(self) -> bool: + """Return True if this is a link-local address.""" + return self._is_link_local + + @property + def is_unspecified(self) -> bool: + """Return True if this is an unspecified address.""" + return self._is_unspecified + + +class ZeroconfIPv6Address(IPv6Address): + + __slots__ = ("_str", "_is_link_local", "_is_unspecified") + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize a new IPv6 address.""" + super().__init__(*args, **kwargs) + self._str = super().__str__() + self._is_link_local = super().is_link_local + self._is_unspecified = super().is_unspecified + + def __str__(self) -> str: + """Return the string representation of the IPv6 address.""" + return self._str + + @property + def is_link_local(self) -> bool: + """Return True if this is a link-local address.""" + return self._is_link_local + + @property + def is_unspecified(self) -> bool: + """Return True if this is an unspecified address.""" + return self._is_unspecified + + +@lru_cache(maxsize=512) +def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4Address, IPv6Address]]: + """Cache IP addresses.""" + try: + return ZeroconfIPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return ZeroconfIPv6Address(address) + except (AddressValueError, NetmaskValueError): + return None + + +cached_ip_addresses_wrapper = _cached_ip_addresses +cached_ip_addresses = cached_ip_addresses_wrapper + + +def get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]: + """Get the IP address object from the record.""" + if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None: + return ip_bytes_and_scope_to_address(record.address, record.scope_id) + return cached_ip_addresses_wrapper(record.address) + + +def ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]: + """Convert the bytes and scope to an IP address object.""" + base_address = cached_ip_addresses_wrapper(address) + if base_address is not None and base_address.is_link_local: + # Avoid expensive __format__ call by using PyUnicode_Join + return cached_ip_addresses_wrapper("".join((str(base_address), "%", str(scope)))) + return base_address + + +def str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str: + """Return the string representation of the address without the scope id.""" + if IPADDRESS_SUPPORTS_SCOPE_ID and 
addr.version == 6:
+        address_str = str(addr)
+        return address_str.partition('%')[0]
+    return str(addr)
+
+
+__all__ = (
+    "cached_ip_addresses",
+    "get_ip_address_object_from_record",
+    "ip_bytes_and_scope_to_address",
+    "str_without_scope_id",
+)
diff --git a/zeroconf/_utils/name.py b/zeroconf/_utils/name.py
new file mode 100644
index 00000000..adccb3e5
--- /dev/null
+++ b/zeroconf/_utils/name.py
@@ -0,0 +1,177 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+    Lesser General Public License for more details.
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+    USA
+"""
+
+from functools import lru_cache
+from typing import Set
+
+from .._exceptions import BadTypeInNameException
+from ..const import (
+    _HAS_A_TO_Z,
+    _HAS_ASCII_CONTROL_CHARS,
+    _HAS_ONLY_A_TO_Z_NUM_HYPHEN,
+    _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE,
+    _LOCAL_TRAILER,
+    _NONTCP_PROTOCOL_LOCAL_TRAILER,
+    _TCP_PROTOCOL_LOCAL_TRAILER,
+)
+
+
+@lru_cache(maxsize=512)
+def service_type_name(type_: str, *, strict: bool = True) -> str:  # pylint: disable=too-many-branches
+    """
+    Validate a fully qualified service name, instance or subtype. [rfc6763]
+
+    Returns fully qualified service name.
+
+    Domain names used by mDNS-SD take the following forms:
+
+        <sn> . <_tcp|_udp> . local.
+        <Instance> . <sn> . <_tcp|_udp> . local.
+        <sub>._sub . <sn> . <_tcp|_udp> . local.
+
+    1) must end with 'local.'
+
+       This is true because we are implementing mDNS and since the 'm' means
+       multi-cast, the 'local.' domain is mandatory.
+
+    2) local is preceded with either '_udp.' or '_tcp.' unless
+       strict is False
+
+    3) service name <sn> precedes <_tcp|_udp> unless
+       strict is False
+
+    The rules for Service Names [RFC6335] state that they may be no more
+    than fifteen characters long (not counting the mandatory underscore),
+    consisting of only letters, digits, and hyphens, must begin and end
+    with a letter or digit, must not contain consecutive hyphens, and
+    must contain at least one letter.
+
+    The instance name <Instance> and sub type <sub> may be up to 63 bytes.
+
+    The <Instance> portion of the Service Instance Name is a user-
+    friendly name consisting of arbitrary Net-Unicode text [RFC5198]. It
+    MUST NOT contain ASCII control characters (byte values 0x00-0x1F and
+    0x7F) [RFC20] but otherwise is allowed to contain any characters,
+    without restriction, including spaces, uppercase, lowercase,
+    punctuation -- including dots -- accented characters, non-Roman text,
+    and anything else that may be represented using Net-Unicode.
+
+    :param type_: Type, SubType or service name to validate
+    :return: fully qualified service name (eg: _http._tcp.local.)
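+
+    Illustrative calls (a sketch; the instance and subtype names are
+    invented for this example)::
+
+        service_type_name('_http._tcp.local.')
+        # -> '_http._tcp.local.'
+        service_type_name('My Printer._http._tcp.local.')
+        # -> '_http._tcp.local.' (instance portion validated, then dropped)
+        service_type_name('_printer._sub._http._tcp.local.')
+        # -> '_http._tcp.local.' (subtype portion validated, then dropped)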
+ """ + if len(type_) > 256: + # https://datatracker.ietf.org/doc/html/rfc6763#section-7.2 + raise BadTypeInNameException("Full name (%s) must be > 256 bytes" % type_) + + if type_.endswith((_TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER)): + remaining = type_[: -len(_TCP_PROTOCOL_LOCAL_TRAILER)].split('.') + trailer = type_[-len(_TCP_PROTOCOL_LOCAL_TRAILER) :] + has_protocol = True + elif strict: + raise BadTypeInNameException( + "Type '%s' must end with '%s' or '%s'" + % (type_, _TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER) + ) + elif type_.endswith(_LOCAL_TRAILER): + remaining = type_[: -len(_LOCAL_TRAILER)].split('.') + trailer = type_[-len(_LOCAL_TRAILER) + 1 :] + has_protocol = False + else: + raise BadTypeInNameException(f"Type '{type_}' must end with '{_LOCAL_TRAILER}'") + + if strict or has_protocol: + service_name = remaining.pop() + if not service_name: + raise BadTypeInNameException("No Service name found") + + if len(remaining) == 1 and len(remaining[0]) == 0: + raise BadTypeInNameException("Type '%s' must not start with '.'" % type_) + + if service_name[0] != '_': + raise BadTypeInNameException("Service name (%s) must start with '_'" % service_name) + + test_service_name = service_name[1:] + + if strict and len(test_service_name) > 15: + # https://datatracker.ietf.org/doc/html/rfc6763#section-7.2 + raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % test_service_name) + + if '--' in test_service_name: + raise BadTypeInNameException("Service name (%s) must not contain '--'" % test_service_name) + + if '-' in (test_service_name[0], test_service_name[-1]): + raise BadTypeInNameException( + "Service name (%s) may not start or end with '-'" % test_service_name + ) + + if not _HAS_A_TO_Z.search(test_service_name): + raise BadTypeInNameException( + "Service name (%s) must contain at least one letter (eg: 'A-Z')" % test_service_name + ) + + allowed_characters_re = ( + _HAS_ONLY_A_TO_Z_NUM_HYPHEN if strict else _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE + ) + + if not allowed_characters_re.search(test_service_name): + raise BadTypeInNameException( + "Service name (%s) must contain only these characters: " + "A-Z, a-z, 0-9, hyphen ('-')%s" % (test_service_name, "" if strict else ", underscore ('_')") + ) + else: + service_name = '' + + if remaining and remaining[-1] == '_sub': + remaining.pop() + if len(remaining) == 0 or len(remaining[0]) == 0: + raise BadTypeInNameException("_sub requires a subtype name") + + if len(remaining) > 1: + remaining = ['.'.join(remaining)] + + if remaining: + length = len(remaining[0].encode('utf-8')) + if length > 63: + raise BadTypeInNameException("Too long: '%s'" % remaining[0]) + + if _HAS_ASCII_CONTROL_CHARS.search(remaining[0]): + raise BadTypeInNameException( + "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" % remaining[0] + ) + + return service_name + trailer + + +def possible_types(name: str) -> Set[str]: + """Build a set of all possible types from a fully qualified name.""" + labels = name.split('.') + label_count = len(labels) + types = set() + for count in range(label_count): + parts = labels[label_count - count - 4 :] + if not parts[0].startswith('_'): + break + types.add('.'.join(parts)) + return types + + +cached_possible_types = lru_cache(maxsize=256)(possible_types) diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py new file mode 100644 index 00000000..dffac896 --- /dev/null +++ b/zeroconf/_utils/net.py @@ -0,0 +1,418 @@ +""" Multicast DNS Service Discovery for Python, 
v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +import errno +import ipaddress +import socket +import struct +import sys +from typing import Any, List, Optional, Sequence, Tuple, Union, cast + +import dmx.ifaddr as ifaddr + +from .._logger import log +from ..const import _IPPROTO_IPV6, _MDNS_ADDR, _MDNS_ADDR6, _MDNS_PORT + + +@enum.unique +class InterfaceChoice(enum.Enum): + Default = 1 + All = 2 + + +InterfacesType = Union[Sequence[Union[str, int, Tuple[Tuple[str, int, int], int]]], InterfaceChoice] + + +@enum.unique +class ServiceStateChange(enum.Enum): + Added = 1 + Removed = 2 + Updated = 3 + + +@enum.unique +class IPVersion(enum.Enum): + V4Only = 1 + V6Only = 2 + All = 3 + + +# utility functions + + +def _is_v6_address(addr: bytes) -> bool: + return len(addr) == 16 + + +def _encode_address(address: str) -> bytes: + is_ipv6 = ':' in address + address_family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + return socket.inet_pton(address_family, address) + + +def get_all_addresses() -> List[str]: + return list({addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4}) + + +def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]: + # IPv6 multicast uses positive indexes for interfaces + # TODO: What about multi-address interfaces? + return list( + {(addr.ip, iface.index) for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv6} + ) + + +def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, int, int], int]: + if '%' in ip: + ip = ip[: ip.index('%')] # Strip scope_id. + ipaddr = ipaddress.ip_address(ip) + for adapter in adapters: + for adapter_ip in adapter.ips: + # IPv6 addresses are represented as tuples + if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr: + return (cast(Tuple[str, int, int], adapter_ip.ip), cast(int, adapter.index)) + + raise RuntimeError('No adapter found for IP address %s' % ip) + + +def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str, int, int]: + for adapter in adapters: + if adapter.index == index: + for adapter_ip in adapter.ips: + # IPv6 addresses are represented as tuples + if isinstance(adapter_ip.ip, tuple): + return cast(Tuple[str, int, int], adapter_ip.ip) + + raise RuntimeError('No adapter found for index %s' % index) + + +def ip6_addresses_to_indexes( + interfaces: Sequence[Union[str, int, Tuple[Tuple[str, int, int], int]]] +) -> List[Tuple[Tuple[str, int, int], int]]: + """Convert IPv6 interface addresses to interface indexes. + + IPv4 addresses are ignored. + + :param interfaces: List of IP addresses and indexes. + :returns: List of indexes. 
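+
+    A shape-only sketch (the addresses and index values here are
+    invented)::
+
+        ip6_addresses_to_indexes(['fe80::1', 3])
+        # -> [(('fe80::1', 0, 0), 2), (('fe80::2', 0, 0), 3)]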
+ """ + result = [] + adapters = ifaddr.get_adapters() + + for iface in interfaces: + if isinstance(iface, int): + result.append((interface_index_to_ip6_address(adapters, iface), iface)) + elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6: + result.append(ip6_to_address_and_index(adapters, iface)) + + return result + + +def normalize_interface_choice( + choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only +) -> List[Union[str, Tuple[Tuple[str, int, int], int]]]: + """Convert the interfaces choice into internal representation. + + :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only). + :param ip_address: IP version to use (ignored if `choice` is a list). + :returns: List of IP addresses (for IPv4) and indexes (for IPv6). + """ + result: List[Union[str, Tuple[Tuple[str, int, int], int]]] = [] + if choice is InterfaceChoice.Default: + if ip_version != IPVersion.V4Only: + # IPv6 multicast uses interface 0 to mean the default + result.append((('', 0, 0), 0)) + if ip_version != IPVersion.V6Only: + result.append('0.0.0.0') + elif choice is InterfaceChoice.All: + if ip_version != IPVersion.V4Only: + result.extend(get_all_addresses_v6()) + if ip_version != IPVersion.V6Only: + result.extend(get_all_addresses()) + if not result: + raise RuntimeError( + 'No interfaces to listen on, check that any interfaces have IP version %s' % ip_version + ) + elif isinstance(choice, list): + # First, take IPv4 addresses. + result = [i for i in choice if isinstance(i, str) and ipaddress.ip_address(i).version == 4] + # Unlike IP_ADD_MEMBERSHIP, IPV6_JOIN_GROUP requires interface indexes. + result += ip6_addresses_to_indexes(choice) + else: + raise TypeError("choice must be a list or InterfaceChoice, got %r" % choice) + return result + + +def disable_ipv6_only_or_raise(s: socket.socket) -> None: + """Make V6 sockets work for both V4 and V6 (required for Windows).""" + try: + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False) + except OSError: + log.error('Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6') + raise + + +def set_so_reuseport_if_available(s: socket.socket) -> None: + """Set SO_REUSEADDR on a socket if available.""" + # SO_REUSEADDR should be equivalent to SO_REUSEPORT for + # multicast UDP sockets (p 731, "TCP/IP Illustrated, + # Volume 2"), but some BSD-derived systems require + # SO_REUSEPORT to be specified explicitly. Also, not all + # versions of Python have SO_REUSEPORT available. + # Catch OSError and socket.error for kernel versions <3.9 because lacking + # SO_REUSEPORT support. + if not hasattr(socket, 'SO_REUSEPORT'): + return + + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # pylint: disable=no-member + except OSError as err: + if err.errno != errno.ENOPROTOOPT: + raise + + +def set_mdns_port_socket_options_for_ip_version( + s: socket.socket, bind_addr: Union[Tuple[str], Tuple[str, int, int]], ip_version: IPVersion +) -> None: + """Set ttl/hops and loop for mdns port.""" + if ip_version != IPVersion.V6Only: + ttl = struct.pack(b'B', 255) + loop = struct.pack(b'B', 1) + # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and + # IP_MULTICAST_LOOP socket options as an unsigned char. 
+ try: + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) + except OSError as e: + if bind_addr[0] != '' or get_errno(e) != errno.EINVAL: # Fails to set on MacOS + raise + + if ip_version != IPVersion.V4Only: + # However, char doesn't work here (at least on Linux) + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True) + + +def new_socket( + bind_addr: Union[Tuple[str], Tuple[str, int, int]], + port: int = _MDNS_PORT, + ip_version: IPVersion = IPVersion.V4Only, + apple_p2p: bool = False, +) -> Optional[socket.socket]: + log.debug( + 'Creating new socket with port %s, ip_version %s, apple_p2p %s and bind_addr %r', + port, + ip_version, + apple_p2p, + bind_addr, + ) + socket_family = socket.AF_INET if ip_version == IPVersion.V4Only else socket.AF_INET6 + s = socket.socket(socket_family, socket.SOCK_DGRAM) + + if ip_version == IPVersion.All: + disable_ipv6_only_or_raise(s) + + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + set_so_reuseport_if_available(s) + + if port == _MDNS_PORT: + set_mdns_port_socket_options_for_ip_version(s, bind_addr, ip_version) + + if apple_p2p: + # SO_RECV_ANYIF = 0x1104 + # https://opensource.apple.com/source/xnu/xnu-4570.41.2/bsd/sys/socket.h + s.setsockopt(socket.SOL_SOCKET, 0x1104, 1) + + bind_tup = (bind_addr[0], port, *bind_addr[1:]) + try: + s.bind(bind_tup) + except OSError as ex: + if ex.errno == errno.EADDRNOTAVAIL: + log.warning( + 'Address not available when binding to %s, ' 'it is expected to happen on some systems', + bind_tup, + ) + return None + raise + log.debug('Created socket %s', s) + return s + + +def add_multicast_member( + listen_socket: socket.socket, + interface: Union[str, Tuple[Tuple[str, int, int], int]], +) -> bool: + # This is based on assumptions in normalize_interface_choice + is_v6 = isinstance(interface, tuple) + err_einval = {errno.EINVAL} + if sys.platform == 'win32': + # No WSAEINVAL definition in typeshed + err_einval |= {cast(Any, errno).WSAEINVAL} # pylint: disable=no-member + log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) + try: + if is_v6: + try: + mdns_addr6_bytes = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) + except OSError: + log.info( + 'Unable to translate IPv6 address when adding %s to multicast group, ' + 'this can happen if IPv6 is disabled on the system', + interface, + ) + return False + iface_bin = struct.pack('@I', cast(int, interface[1])) + _value = mdns_addr6_bytes + iface_bin + listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value) + else: + _value = socket.inet_aton(_MDNS_ADDR) + socket.inet_aton(cast(str, interface)) + listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value) + except OSError as e: + _errno = get_errno(e) + if _errno == errno.EADDRINUSE: + log.info( + 'Address in use when adding %s to multicast group, ' + 'it is expected to happen on some systems', + interface, + ) + return False + if _errno == errno.EADDRNOTAVAIL: + log.info( + 'Address not available when adding %s to multicast ' + 'group, it is expected to happen on some systems', + interface, + ) + return False + if _errno in err_einval: + log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) + return False + if _errno == errno.ENOPROTOOPT: + log.info( + 'Failed to set socket option on %s, this can happen if ' + 'the network adapter is in a disconnected state', + 
interface,
+            )
+            return False
+        if is_v6 and _errno == errno.ENODEV:
+            log.info(
+                'No such device when adding %s to multicast group, '
+                'it is expected to happen when the device does not have ipv6',
+                interface,
+            )
+            return False
+        raise
+    return True
+
+
+def new_respond_socket(
+    interface: Union[str, Tuple[Tuple[str, int, int], int]],
+    apple_p2p: bool = False,
+) -> Optional[socket.socket]:
+    is_v6 = isinstance(interface, tuple)
+    respond_socket = new_socket(
+        ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only),
+        apple_p2p=apple_p2p,
+        bind_addr=cast(Tuple[Tuple[str, int, int], int], interface)[0] if is_v6 else (cast(str, interface),),
+    )
+    if not respond_socket:
+        return None
+    log.debug('Configuring socket %s with multicast interface %s', respond_socket, interface)
+    if is_v6:
+        iface_bin = struct.pack('@I', cast(int, interface[1]))
+        respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin)
+    else:
+        respond_socket.setsockopt(
+            socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(cast(str, interface))
+        )
+    return respond_socket
+
+
+def create_sockets(
+    interfaces: InterfacesType = InterfaceChoice.All,
+    unicast: bool = False,
+    ip_version: IPVersion = IPVersion.V4Only,
+    apple_p2p: bool = False,
+) -> Tuple[Optional[socket.socket], List[socket.socket]]:
+    if unicast:
+        listen_socket = None
+    else:
+        listen_socket = new_socket(ip_version=ip_version, apple_p2p=apple_p2p, bind_addr=('',))
+
+    normalized_interfaces = normalize_interface_choice(interfaces, ip_version)
+
+    # If we are using InterfaceChoice.Default we can use
+    # a single socket to listen and respond.
+    if not unicast and interfaces is InterfaceChoice.Default:
+        for i in normalized_interfaces:
+            add_multicast_member(cast(socket.socket, listen_socket), i)
+        return listen_socket, [cast(socket.socket, listen_socket)]
+
+    respond_sockets = []
+
+    for i in normalized_interfaces:
+        if not unicast:
+            if add_multicast_member(cast(socket.socket, listen_socket), i):
+                respond_socket = new_respond_socket(i, apple_p2p=apple_p2p)
+            else:
+                respond_socket = None
+        else:
+            respond_socket = new_socket(
+                port=0,
+                ip_version=ip_version,
+                apple_p2p=apple_p2p,
+                bind_addr=i[0] if isinstance(i, tuple) else (i,),
+            )
+
+        if respond_socket is not None:
+            respond_sockets.append(respond_socket)
+
+    return listen_socket, respond_sockets
+
+
+def get_errno(e: Exception) -> int:
+    assert isinstance(e, socket.error)
+    return cast(int, e.args[0])
+
+
+def can_send_to(ipv6_socket: bool, address: str) -> bool:
+    """Check if the address type matches the socket type.
+
+    This function does not validate whether the address is a valid
+    IPv6 or IPv4 address.
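+
+    For example::
+
+        can_send_to(True, 'fe80::1')   # True: IPv6 socket, IPv6 address
+        can_send_to(False, '1.2.3.4')  # True: IPv4 socket, IPv4 address
+        can_send_to(True, '1.2.3.4')   # False: family mismatch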
+ """ + return ":" in address if ipv6_socket else ":" not in address + + +def autodetect_ip_version(interfaces: InterfacesType) -> IPVersion: + """Auto detect the IP version when it is not provided.""" + if isinstance(interfaces, list): + has_v6 = any( + isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6) + for i in interfaces + ) + has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces) + if has_v4 and has_v6: + return IPVersion.All + if has_v6: + return IPVersion.V6Only + + return IPVersion.V4Only diff --git a/zeroconf/_utils/time.py b/zeroconf/_utils/time.py new file mode 100644 index 00000000..600d9028 --- /dev/null +++ b/zeroconf/_utils/time.py @@ -0,0 +1,42 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + + +import time + +_float = float + + +def current_time_millis() -> _float: + """Current time in milliseconds. + + The current implemention uses `time.monotonic` + but may change in the future. + + The design requires the time to match asyncio.loop.time() + """ + return time.monotonic() * 1000 + + +def millis_to_seconds(millis: _float) -> _float: + """Convert milliseconds to seconds.""" + return millis / 1000.0 diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py new file mode 100644 index 00000000..cfe3693e --- /dev/null +++ b/zeroconf/asyncio.py @@ -0,0 +1,277 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+    USA
+"""
+import asyncio
+import contextlib
+from types import TracebackType  # noqa # used in type hints
+from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
+
+from ._core import Zeroconf
+from ._dns import DNSQuestionType
+from ._services import ServiceListener
+from ._services.browser import _ServiceBrowserBase
+from ._services.info import ServiceInfo
+from ._services.types import ZeroconfServiceTypes
+from ._utils.net import InterfaceChoice, InterfacesType, IPVersion
+from .const import _BROWSER_TIME, _MDNS_PORT, _SERVICE_TYPE_ENUMERATION_NAME
+
+__all__ = [
+    "AsyncZeroconf",
+    "AsyncServiceInfo",
+    "AsyncServiceBrowser",
+    "AsyncZeroconfServiceTypes",
+]
+
+
+class AsyncServiceInfo(ServiceInfo):
+    """An async version of ServiceInfo."""
+
+
+class AsyncServiceBrowser(_ServiceBrowserBase):
+    """Used to browse for a service for specific type(s).
+
+    Constructor parameters are as follows:
+
+    * `zc`: A Zeroconf instance
+    * `type_`: fully qualified service type name
+    * `handlers`: ServiceListener or Callable that knows how to process ServiceStateChange events
+    * `listener`: ServiceListener
+    * `addr`: address to send queries (will default to multicast)
+    * `port`: port to send queries (will default to mdns 5353)
+    * `delay`: The initial delay between answering questions
+    * `question_type`: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU)
+
+    The listener object will have its add_service() and
+    remove_service() methods called when this browser
+    discovers changes in the services' availability.
+    """
+
+    def __init__(
+        self,
+        zeroconf: 'Zeroconf',
+        type_: Union[str, list],
+        handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None,
+        listener: Optional[ServiceListener] = None,
+        addr: Optional[str] = None,
+        port: int = _MDNS_PORT,
+        delay: int = _BROWSER_TIME,
+        question_type: Optional[DNSQuestionType] = None,
+    ) -> None:
+        super().__init__(zeroconf, type_, handlers, listener, addr, port, delay, question_type)
+        self._async_start()
+
+    async def async_cancel(self) -> None:
+        """Cancel the browser."""
+        self._async_cancel()
+
+    async def __aenter__(self) -> 'AsyncServiceBrowser':
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        await self.async_cancel()
+        return None
+
+
+class AsyncZeroconfServiceTypes(ZeroconfServiceTypes):
+    """An async version of ZeroconfServiceTypes."""
+
+    @classmethod
+    async def async_find(
+        cls,
+        aiozc: Optional['AsyncZeroconf'] = None,
+        timeout: Union[int, float] = 5,
+        interfaces: InterfacesType = InterfaceChoice.All,
+        ip_version: Optional[IPVersion] = None,
+    ) -> Tuple[str, ...]:
+        """
+        Return all of the advertised services on any local networks.
+
+        :param aiozc: AsyncZeroconf() instance. Pass in if you already have an
+            instance running or if non-default interfaces are needed
+        :param timeout: seconds to wait for any responses
+        :param interfaces: interfaces to listen on.
+        :param ip_version: IP protocol version to use.
+        :return: tuple of service type strings
+        """
+        local_zc = aiozc or AsyncZeroconf(interfaces=interfaces, ip_version=ip_version)
+        listener = cls()
+        async_browser = AsyncServiceBrowser(
+            local_zc.zeroconf, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener
+        )
+
+        # wait for responses
+        await asyncio.sleep(timeout)
+
+        await async_browser.async_cancel()
+
+        # close down anything we opened
+        if aiozc is None:
+            await local_zc.async_close()
+
+        return tuple(sorted(listener.found_services))
+
+
+class AsyncZeroconf:
+    """Implementation of Zeroconf Multicast DNS Service Discovery
+
+    Supports registration, unregistration, queries and browsing.
+
+    The async version is currently a wrapper around Zeroconf which
+    is now also async. It is expected that an asyncio event loop
+    is already running before creating the AsyncZeroconf object.
+    """
+
+    def __init__(
+        self,
+        interfaces: InterfacesType = InterfaceChoice.All,
+        unicast: bool = False,
+        ip_version: Optional[IPVersion] = None,
+        apple_p2p: bool = False,
+        zc: Optional[Zeroconf] = None,
+    ) -> None:
+        """Creates an instance of the Zeroconf class, establishing
+        multicast communications, and listening.
+
+        :param interfaces: :class:`InterfaceChoice` or a list of IP addresses
+            (IPv4 and IPv6) and interface indexes (IPv6 only).
+
+            IPv6 notes for non-POSIX systems:
+            * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default`
+              on Python versions before 3.8.
+
+            Also listening on loopback (``::1``) doesn't work, use a real address.
+        :param ip_version: IP versions to support. If `interfaces` is a list, the default is detected
+            from it. Otherwise defaults to V4 only for backward compatibility.
+        :param apple_p2p: use AWDL interface (only macOS)
+        """
+        self.zeroconf = zc or Zeroconf(
+            interfaces=interfaces,
+            unicast=unicast,
+            ip_version=ip_version,
+            apple_p2p=apple_p2p,
+        )
+        self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {}
+
+    async def async_register_service(
+        self,
+        info: ServiceInfo,
+        ttl: Optional[int] = None,
+        allow_name_change: bool = False,
+        cooperating_responders: bool = False,
+        strict: bool = True,
+    ) -> Awaitable:
+        """Registers service information to the network with a default TTL.
+        Zeroconf will then respond to requests for information for that
+        service. The name of the service may be changed if needed to make
+        it unique on the network. Additionally multiple cooperating responders
+        can register the same service on the network for resilience
+        (if you want this behavior set `cooperating_responders` to `True`).
+
+        The service will be broadcast in a task. This task is returned
+        and therefore can be awaited if necessary.
+        """
+        return await self.zeroconf.async_register_service(
+            info, ttl, allow_name_change, cooperating_responders, strict
+        )
+
+    async def async_unregister_all_services(self) -> None:
+        """Unregister all registered services.
+
+        Unlike async_register_service and async_unregister_service, this
+        method does not return a future and is always expected to be
+        awaited since it's only called at shutdown.
+        """
+        await self.zeroconf.async_unregister_all_services()
+
+    async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
+        """Unregister a service.
+
+        The service will be broadcast in a task. This task is returned
+        and therefore can be awaited if necessary.
+        """
+        return await self.zeroconf.async_unregister_service(info)
+
+    async def async_update_service(self, info: ServiceInfo) -> Awaitable:
+        """Updates service information on the network with a default TTL.
+        Zeroconf will then respond to requests for information for that
+        service.
+
+        The service will be broadcast in a task. This task is returned
+        and therefore can be awaited if necessary.
+        """
+        return await self.zeroconf.async_update_service(info)
+
+    async def async_close(self) -> None:
+        """Ends the background threads, and prevents this instance from
+        servicing further queries."""
+        if not self.zeroconf.done:
+            with contextlib.suppress(asyncio.TimeoutError):
+                await asyncio.wait_for(self.zeroconf.async_wait_for_start(), timeout=1)
+        await self.async_remove_all_service_listeners()
+        await self.async_unregister_all_services()
+        await self.zeroconf._async_close()  # pylint: disable=protected-access
+
+    async def async_get_service_info(
+        self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None
+    ) -> Optional[AsyncServiceInfo]:
+        """Returns network's service information for a particular
+        name and type, or None if no service matches by the timeout,
+        which defaults to 3 seconds."""
+        info = AsyncServiceInfo(type_, name)
+        if await info.async_request(self.zeroconf, timeout, question_type):
+            return info
+        return None
+
+    async def async_add_service_listener(self, type_: str, listener: ServiceListener) -> None:
+        """Adds a listener for a particular service type. This object
+        will then have its add_service and remove_service methods called when
+        services of that type become available and unavailable."""
+        await self.async_remove_service_listener(listener)
+        self.async_browsers[listener] = AsyncServiceBrowser(self.zeroconf, type_, listener)
+
+    async def async_remove_service_listener(self, listener: ServiceListener) -> None:
+        """Removes a listener from the set that is currently listening."""
+        if listener in self.async_browsers:
+            await self.async_browsers[listener].async_cancel()
+            del self.async_browsers[listener]
+
+    async def async_remove_all_service_listeners(self) -> None:
+        """Removes all listeners from the set that are currently listening."""
+        await asyncio.gather(
+            *(self.async_remove_service_listener(listener) for listener in list(self.async_browsers))
+        )
+
+    async def __aenter__(self) -> 'AsyncZeroconf':
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        await self.async_close()
+        return None
diff --git a/zeroconf/const.py b/zeroconf/const.py
new file mode 100644
index 00000000..aa64306e
--- /dev/null
+++ b/zeroconf/const.py
@@ -0,0 +1,163 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+    This module provides a framework for the use of DNS Service Discovery
+    using IP multicast.
+
+    This library is free software; you can redistribute it and/or
+    modify it under the terms of the GNU Lesser General Public
+    License as published by the Free Software Foundation; either
+    version 2.1 of the License, or (at your option) any later version.
+
+    This library is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+    Lesser General Public License for more details.
+
+    You should have received a copy of the GNU Lesser General Public
+    License along with this library; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+    USA
+"""
+
+import re
+import socket
+
+# Some timing constants
+
+_UNREGISTER_TIME = 125  # ms
+_CHECK_TIME = 175  # ms
+_REGISTER_TIME = 225  # ms
+_LISTENER_TIME = 200  # ms
+_BROWSER_TIME = 10000  # ms
+_DUPLICATE_PACKET_SUPPRESSION_INTERVAL = 1000  # ms
+_DUPLICATE_QUESTION_INTERVAL = 999  # ms  # Must be 1ms less than _DUPLICATE_PACKET_SUPPRESSION_INTERVAL
+_CACHE_CLEANUP_INTERVAL = 10  # s
+_LOADED_SYSTEM_TIMEOUT = 10  # s
+_STARTUP_TIMEOUT = 9  # s must be lower than _LOADED_SYSTEM_TIMEOUT
+_ONE_SECOND = 1000  # ms
+
+# If the system is loaded or the event
+# loop was blocked by another task that was doing I/O in the loop
+# (shouldn't happen but it does in practice) we need to give
+# a buffer timeout to ensure a coroutine can finish before
+# the future times out
+
+# Some DNS constants
+
+_MDNS_ADDR = '224.0.0.251'
+_MDNS_ADDR6 = 'ff02::fb'
+_MDNS_PORT = 5353
+_DNS_PORT = 53
+_DNS_HOST_TTL = 120  # two minutes for host records (A, SRV etc) as-per RFC6762
+_DNS_OTHER_TTL = 4500  # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762
+# Currently we enforce a minimum TTL for PTR records to avoid
+# ServiceBrowsers generating excessive refresh queries.
+# Apple uses a 15s minimum TTL, however we do not have the same
+# level of rate limit and safe guards so we use 1/4 of the recommended value
+_DNS_PTR_MIN_TTL = _DNS_OTHER_TTL / 4
+
+_DNS_PACKET_HEADER_LEN = 12
+
+_MAX_MSG_TYPICAL = 1460  # unused
+_MAX_MSG_ABSOLUTE = 8966
+
+_FLAGS_QR_MASK = 0x8000  # query response mask
+_FLAGS_QR_QUERY = 0x0000  # query
+_FLAGS_QR_RESPONSE = 0x8000  # response
+
+_FLAGS_AA = 0x0400  # Authoritative answer
+_FLAGS_TC = 0x0200  # Truncated
+_FLAGS_RD = 0x0100  # Recursion desired
+_FLAGS_RA = 0x0080  # Recursion available
+
+_FLAGS_Z = 0x0040  # Zero
+_FLAGS_AD = 0x0020  # Authentic data
+_FLAGS_CD = 0x0010  # Checking disabled
+
+_CLASS_IN = 1
+_CLASS_CS = 2
+_CLASS_CH = 3
+_CLASS_HS = 4
+_CLASS_NONE = 254
+_CLASS_ANY = 255
+_CLASS_MASK = 0x7FFF
+_CLASS_UNIQUE = 0x8000
+_CLASS_IN_UNIQUE = _CLASS_IN | _CLASS_UNIQUE
+
+_TYPE_A = 1
+_TYPE_NS = 2
+_TYPE_MD = 3
+_TYPE_MF = 4
+_TYPE_CNAME = 5
+_TYPE_SOA = 6
+_TYPE_MB = 7
+_TYPE_MG = 8
+_TYPE_MR = 9
+_TYPE_NULL = 10
+_TYPE_WKS = 11
+_TYPE_PTR = 12
+_TYPE_HINFO = 13
+_TYPE_MINFO = 14
+_TYPE_MX = 15
+_TYPE_TXT = 16
+_TYPE_AAAA = 28
+_TYPE_SRV = 33
+_TYPE_NSEC = 47
+_TYPE_ANY = 255
+
+# Mapping constants to names
+
+_CLASSES = {
+    _CLASS_IN: "in",
+    _CLASS_CS: "cs",
+    _CLASS_CH: "ch",
+    _CLASS_HS: "hs",
+    _CLASS_NONE: "none",
+    _CLASS_ANY: "any",
+}
+
+_TYPES = {
+    _TYPE_A: "a",
+    _TYPE_NS: "ns",
+    _TYPE_MD: "md",
+    _TYPE_MF: "mf",
+    _TYPE_CNAME: "cname",
+    _TYPE_SOA: "soa",
+    _TYPE_MB: "mb",
+    _TYPE_MG: "mg",
+    _TYPE_MR: "mr",
+    _TYPE_NULL: "null",
+    _TYPE_WKS: "wks",
+    _TYPE_PTR: "ptr",
+    _TYPE_HINFO: "hinfo",
+    _TYPE_MINFO: "minfo",
+    _TYPE_MX: "mx",
+    _TYPE_TXT: "txt",
+    _TYPE_AAAA: "quada",
+    _TYPE_SRV: "srv",
+    _TYPE_ANY: "any",
+    _TYPE_NSEC: "nsec",
+}
+
+_ADDRESS_RECORD_TYPES = {_TYPE_A, _TYPE_AAAA}
+
+_HAS_A_TO_Z = re.compile(r'[A-Za-z]')
+_HAS_ONLY_A_TO_Z_NUM_HYPHEN = re.compile(r'^[A-Za-z0-9\-]+$')
+_HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$')
+_HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]')
+
+_EXPIRE_REFRESH_TIME_PERCENT = 75
+
+_LOCAL_TRAILER = '.local.'
+_TCP_PROTOCOL_LOCAL_TRAILER = '._tcp.local.'
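+# Per RFC 6763, services that do not run over TCP use the '_udp' label
+# regardless of their actual transport, hence "NONTCP" below.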
+_NONTCP_PROTOCOL_LOCAL_TRAILER = '._udp.local.' + +# https://datatracker.ietf.org/doc/html/rfc6763#section-9 +_SERVICE_TYPE_ENUMERATION_NAME = "_services._dns-sd._udp.local." + +try: + _IPPROTO_IPV6 = socket.IPPROTO_IPV6 +except AttributeError: + # Sigh: https://bugs.python.org/issue29515 + _IPPROTO_IPV6 = 41 diff --git a/zeroconf/py.typed b/zeroconf/py.typed new file mode 100644 index 00000000..e69de29b