Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rationalise accessing of read-only device attributes by use of a BlinkStickDevice class #110

Merged
merged 5 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions src/blinkstick/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,33 @@

from typing import TypeVar, Generic

from blinkstick.devices.device import BlinkStickDevice

T = TypeVar("T")


class BaseBackend(ABC, Generic[T]):

serial: str | None
blinkstick_device: BlinkStickDevice[T]

def __init__(self):
self.serial = None

@abstractmethod
def _refresh_device(self):
def _refresh_attached_blinkstick_device(self):
raise NotImplementedError

@staticmethod
@abstractmethod
def find_blinksticks(find_all: bool = True) -> list[T] | None:
def get_attached_blinkstick_devices(
find_all: bool = True,
) -> list[BlinkStickDevice[T]]:
raise NotImplementedError

@staticmethod
@abstractmethod
def find_by_serial(serial: str) -> list[T] | None:
def find_by_serial(serial: str) -> list[BlinkStickDevice[T]] | None:
raise NotImplementedError

@abstractmethod
Expand All @@ -39,18 +44,14 @@ def control_transfer(
):
raise NotImplementedError

@abstractmethod
def get_serial(self) -> str:
raise NotImplementedError
return self.blinkstick_device.serial

@abstractmethod
def get_manufacturer(self) -> str:
raise NotImplementedError
return self.blinkstick_device.manufacturer

@abstractmethod
def get_version_attribute(self) -> int:
raise NotImplementedError
return self.blinkstick_device.version_attribute

@abstractmethod
def get_description(self) -> str:
raise NotImplementedError
def get_description(self):
return self.blinkstick_device.description
83 changes: 45 additions & 38 deletions src/blinkstick/backends/unix_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,55 +5,70 @@

from blinkstick.constants import VENDOR_ID, PRODUCT_ID
from blinkstick.backends.base import BaseBackend
from blinkstick.devices.device import BlinkStickDevice
from blinkstick.exceptions import BlinkStickException


class UnixLikeBackend(BaseBackend[usb.core.Device]):

serial: str
device: usb.core.Device
blinkstick_device: BlinkStickDevice[usb.core.Device]

def __init__(self, device=None):
self.device = device
self.blinkstick_device = device
super().__init__()
if device:
self.open_device()
self.serial = self.get_serial()

def open_device(self) -> None:
if self.device is None:
if self.blinkstick_device is None:
raise BlinkStickException("Could not find BlinkStick...")

if self.device.is_kernel_driver_active(0):
if self.blinkstick_device.raw_device.is_kernel_driver_active(0):
try:
self.device.detach_kernel_driver(0)
self.blinkstick_device.raw_device.detach_kernel_driver(0)
except usb.core.USBError as e:
raise BlinkStickException("Could not detach kernel driver: %s" % str(e))

def _refresh_device(self):
if not self.serial:
def _refresh_attached_blinkstick_device(self):
if not self.blinkstick_device:
return False
if devices := self.find_by_serial(self.serial):
self.device = devices[0]
if devices := self.find_by_serial(self.blinkstick_device.serial):
self.blinkstick_device = devices[0]
self.open_device()
return True

@staticmethod
def find_blinksticks(find_all: bool = True) -> list[usb.core.Device] | None:
return usb.core.find(
find_all=find_all, idVendor=VENDOR_ID, idProduct=PRODUCT_ID
def get_attached_blinkstick_devices(
find_all: bool = True,
) -> list[BlinkStickDevice[usb.core.Device]]:
raw_devices = (
usb.core.find(find_all=find_all, idVendor=VENDOR_ID, idProduct=PRODUCT_ID)
or []
)
return [
# TODO: refactor this to DRY up the usb.util.get_string calls
# note that we can't use _usb_get_string here because we're not in an instance method
# and we don't have a BlinkStickDevice instance to call it on
# until then we'll just have to live with the duplication, and the fact that we're not able
# to handle USB errors in the same way as we do in the instance methods
BlinkStickDevice(
raw_device=device,
serial=str(usb.util.get_string(device, 3, 1033)),
manufacturer=str(usb.util.get_string(device, 1, 1033)),
version_attribute=device.bcdDevice,
description=str(usb.util.get_string(device, 2, 1033)),
)
for device in raw_devices
]

@staticmethod
def find_by_serial(serial: str) -> list[usb.core.Device] | None:
found_devices = UnixLikeBackend.find_blinksticks() or []
def find_by_serial(serial: str) -> list[BlinkStickDevice[usb.core.Device]] | None:
found_devices = UnixLikeBackend.get_attached_blinkstick_devices()
for d in found_devices:
try:
if usb.util.get_string(d, 3, 1033) == serial:
devices = [d]
return devices
except Exception as e:
print("{0}".format(e))
if d.serial == serial:
return [d]

return None

Expand All @@ -66,15 +81,15 @@ def control_transfer(
data_or_wLength: bytes | int,
):
try:
return self.device.ctrl_transfer(
return self.blinkstick_device.raw_device.ctrl_transfer(
bmRequestType, bRequest, wValue, wIndex, data_or_wLength
)
except usb.USBError:
# Could not communicate with BlinkStick backend
# attempt to find it again based on serial

if self._refresh_device():
return self.device.ctrl_transfer(
if self._refresh_attached_blinkstick_device():
return self.blinkstick_device.raw_device.ctrl_transfer(
bmRequestType, bRequest, wValue, wIndex, data_or_wLength
)
else:
Expand All @@ -84,27 +99,19 @@ def control_transfer(
)
)

def get_serial(self) -> str:
return self._usb_get_string(3)

def get_manufacturer(self) -> str:
return self._usb_get_string(1)

def get_version_attribute(self) -> int:
return int(self.device.bcdDevice)

def get_description(self):
return self._usb_get_string(2)

def _usb_get_string(self, index: int) -> str:
try:
return str(usb.util.get_string(self.device, index, 1033))
return str(
usb.util.get_string(self.blinkstick_device.raw_device, index, 1033)
)
except usb.USBError:
# Could not communicate with BlinkStick backend
# attempt to find it again based on serial

if self._refresh_device():
return str(usb.util.get_string(self.device, index, 1033))
if self._refresh_attached_blinkstick_device():
return str(
usb.util.get_string(self.blinkstick_device.raw_device, index, 1033)
)
else:
raise BlinkStickException(
"Could not communicate with BlinkStick {0} - it may have been removed".format(
Expand Down
71 changes: 35 additions & 36 deletions src/blinkstick/backends/win32.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,64 @@

from blinkstick.constants import VENDOR_ID, PRODUCT_ID
from blinkstick.backends.base import BaseBackend
from blinkstick.devices.device import BlinkStickDevice
from blinkstick.exceptions import BlinkStickException


class Win32Backend(BaseBackend[hid.HidDevice]):
serial: str
device: hid.HidDevice
blinkstick_device: BlinkStickDevice[hid.HidDevice]
reports: list[hid.core.HidReport]

def __init__(self, device=None):
def __init__(self, device: BlinkStickDevice[hid.HidDevice]):
super().__init__()
self.device = device
self.blinkstick_device = device
if device:
self.device.open()
self.reports = self.device.find_feature_reports()
self.blinkstick_device.raw_device.open()
self.reports = self.blinkstick_device.raw_device.find_feature_reports()
self.serial = self.get_serial()

@staticmethod
def find_by_serial(serial: str) -> list[hid.HidDevice] | None:
found_devices = Win32Backend.find_blinksticks() or []
devices = [d for d in found_devices if d.serial_number == serial]

if len(devices) > 0:
return devices
def find_by_serial(serial: str) -> list[BlinkStickDevice[hid.HidDevice]] | None:
found_devices = Win32Backend.get_attached_blinkstick_devices()
for d in found_devices:
if d.serial == serial:
return [d]

return None

def _refresh_device(self):
def _refresh_attached_blinkstick_device(self):
# TODO This is weird semantics. fix up return values to be more sensible
if not self.serial:
return False
if devices := self.find_by_serial(self.serial):
self.device = devices[0]
self.device.open()
self.reports = self.device.find_feature_reports()
self.blinkstick_device = devices[0]
self.blinkstick_device.raw_device.open()
self.reports = self.blinkstick_device.raw_device.find_feature_reports()
return True

@staticmethod
def find_blinksticks(find_all: bool = True) -> list[hid.HidDevice] | None:
def get_attached_blinkstick_devices(
find_all: bool = True,
) -> list[BlinkStickDevice[hid.HidDevice]]:
devices = hid.HidDeviceFilter(
vendor_id=VENDOR_ID, product_id=PRODUCT_ID
).get_devices()

blinkstick_devices = [
BlinkStickDevice(
raw_device=device,
serial=device.serial_number,
manufacturer=device.vendor_name,
version_attribute=device.version_number,
description=device.product_name,
)
for device in devices
]
if find_all:
return devices
elif len(devices) > 0:
return devices[0]
else:
return None
return blinkstick_devices

return blinkstick_devices[:1]

def control_transfer(
self, bmRequestType, bRequest, wValue, wIndex, data_or_wLength
Expand All @@ -68,9 +79,9 @@ def control_transfer(
*[c_ubyte(c) for c in data_or_wLength]
)
data[0] = wValue
if not self.device.send_feature_report(data):
if self._refresh_device():
self.device.send_feature_report(data)
if not self.blinkstick_device.raw_device.send_feature_report(data):
if self._refresh_attached_blinkstick_device():
self.blinkstick_device.raw_device.send_feature_report(data)
else:
raise BlinkStickException(
"Could not communicate with BlinkStick {0} - it may have been removed".format(
Expand All @@ -80,15 +91,3 @@ def control_transfer(

elif bmRequestType == 0x80 | 0x20:
return self.reports[wValue - 1].get()

def get_serial(self) -> str:
return str(self.device.serial_number)

def get_manufacturer(self) -> str:
return str(self.device.vendor_name)

def get_version_attribute(self) -> int:
return int(self.device.version_number)

def get_description(self) -> str:
return str(self.device.product_name)
13 changes: 8 additions & 5 deletions src/blinkstick/blinkstick.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ColorFormat,
)
from blinkstick.constants import VENDOR_ID, PRODUCT_ID, BlinkStickVariant
from blinkstick.devices.device import BlinkStickDevice
from blinkstick.exceptions import BlinkStickException
from blinkstick.utilities import string_to_info_block_data

Expand Down Expand Up @@ -51,7 +52,9 @@ class BlinkStick:
backend: USBBackend
bs_serial: str

def __init__(self, device=None, error_reporting: bool = True):
def __init__(
self, device: BlinkStickDevice | None = None, error_reporting: bool = True
):
"""
Constructor for the class.

Expand Down Expand Up @@ -1403,7 +1406,7 @@ def find_all() -> list[BlinkStick]:
@return: a list of BlinkStick objects or None if no devices found
"""
result: list[BlinkStick] = []
if (found_devices := USBBackend.find_blinksticks()) is None:
if (found_devices := USBBackend.get_attached_blinkstick_devices()) is None:
return result
for d in found_devices:
result.extend([BlinkStick(device=d)])
Expand All @@ -1418,10 +1421,10 @@ def find_first() -> BlinkStick | None:
@rtype: BlinkStick
@return: BlinkStick object or None if no devices are found
"""
d = USBBackend.find_blinksticks(find_all=False)
blinkstick_devices = USBBackend.get_attached_blinkstick_devices(find_all=False)

if d:
return BlinkStick(device=d)
if blinkstick_devices:
return BlinkStick(device=blinkstick_devices[0])

return None

Expand Down
15 changes: 15 additions & 0 deletions src/blinkstick/devices/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class BlinkStickDevice(Generic[T]):
"""A BlinkStick device representation"""

raw_device: T
serial: str
manufacturer: str
version_attribute: int
description: str