Skip to content

Commit

Permalink
Fix EndpointDescriptor
Browse files Browse the repository at this point in the history
Fix creation of EndpointDescriptor and add unit tests for both possible
lengths.
  • Loading branch information
twam committed Feb 19, 2022
1 parent a983073 commit f16173f
Showing 1 changed file with 71 additions and 10 deletions.
81 changes: 71 additions & 10 deletions usb_protocol/types/descriptors/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,23 @@ class DeviceCapabilityTypes(IntEnum):
)


EndpointDescriptor = DescriptorFormat(
EndpointDescriptorLength = construct.Rebuild(construct.Int8ul, 7 if (this.bRefresh is None) and (this.bSynchAddress is None) else 9)

EndpointDescriptor = DescriptorFormat(
# [USB2.0: 9.6; USB Audio Device Class Definition 1.0: 4.6.1.1, 4.6.2.1]
# Interfaces of the Audio 1.0 class extend their subordinate endpoint descriptors with
# 2 additional bytes (extending it from 7 to 9 bytes). Thankfully, this is the only extension that
# changes the length of a standard descriptor type, but we do have to handle this case in Construct.
"bLength" / construct.Default(construct.OneOf(construct.Int8ul, [7, 9]), 7),
"bLength" / EndpointDescriptorLength,
"bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.ENDPOINT),
"bEndpointAddress" / DescriptorField("Endpoint Address"),
"bmAttributes" / DescriptorField("Attributes", default=2),
"wMaxPacketSize" / DescriptorField("Maximum Packet Size", default=64),
"bInterval" / DescriptorField("Polling interval", default=255),

# 2 bytes that are only present on endpoint descriptors for Audio 1.0 class interfaces.
("bRefresh" / construct.Optional(construct.Int8ul)) * "Refresh Rate",
("bSynchAddress" / construct.Optional(construct.Int8ul)) * "Synch Endpoint Address",
("bRefresh" / construct.If(this.bLength == 9, construct.Int8ul)) * "Refresh Rate",
("bSynchAddress" / construct.If(this.bLength == 9, construct.Int8ul)) * "Synch Endpoint Address",
)


Expand Down Expand Up @@ -198,7 +199,6 @@ class DeviceCapabilityTypes(IntEnum):
)



class DescriptorParserCases(unittest.TestCase):

STRING_DESCRIPTOR = bytes([
Expand All @@ -225,7 +225,6 @@ class DescriptorParserCases(unittest.TestCase):
ord('s'), 0x00,
])


def test_string_descriptor_parse(self):

# Parse the relevant string...
Expand All @@ -236,23 +235,20 @@ def test_string_descriptor_parse(self):
self.assertEqual(parsed.bDescriptorType, 3)
self.assertEqual(parsed.bString, "Great Scott Gadgets")


def test_string_descriptor_build(self):
data = StringDescriptor.build({
'bString': "Great Scott Gadgets"
})

self.assertEqual(data, self.STRING_DESCRIPTOR)


def test_string_language_descriptor_build(self):
data = StringLanguageDescriptor.build({
'wLANGID': (LanguageIDs.ENGLISH_US,)
})

self.assertEqual(data, b"\x04\x03\x09\x04")


def test_device_descriptor(self):

device_descriptor = [
Expand Down Expand Up @@ -291,7 +287,6 @@ def test_device_descriptor(self):
self.assertEqual(parsed.iSerialNumber, 3)
self.assertEqual(parsed.bNumConfigurations, 1)


def test_bcd_constructor(self):

emitter = BCDFieldAdapter(construct.Int16ul)
Expand All @@ -300,5 +295,71 @@ def test_bcd_constructor(self):
self.assertEqual(result, b"\x40\x01")


def test_parse_endpoint_descriptor(self):
# Parse the relevant descriptor ...
parsed = EndpointDescriptor.parse([
0x07, # Length
0x05, # Type
0x81, # Endpoint address
0x02, # Attributes
0x40, 0x00, # Maximum packet size
0xFF, # Interval
])

# ... and check the descriptor's fields.
self.assertEqual(parsed.bLength, 7)
self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT)
self.assertEqual(parsed.bEndpointAddress, 0x81)
self.assertEqual(parsed.bmAttributes, 2)
self.assertEqual(parsed.wMaxPacketSize, 64)
self.assertEqual(parsed.bInterval, 255)

def test_parse_endpoint_descriptor_audio(self):
# Parse the relevant descriptor ...
parsed = EndpointDescriptor.parse([
0x09, # Length
0x05, # Type
0x81, # Endpoint address
0x02, # Attributes
0x40, 0x00, # Maximum packet size
0xFF, # Interval
0x20, # Refresh rate
0x05, # Synch endpoint address
])

# ... and check the descriptor's fields.
self.assertEqual(parsed.bLength, 9)
self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT)
self.assertEqual(parsed.bEndpointAddress, 0x81)
self.assertEqual(parsed.bmAttributes, 2)
self.assertEqual(parsed.wMaxPacketSize, 64)
self.assertEqual(parsed.bInterval, 255)
self.assertEqual(parsed.bRefresh, 32)
self.assertEqual(parsed.bSynchAddress, 0x05)

def test_build_endpoint_descriptor_audio(self):
# Build the relevant descriptor
data = EndpointDescriptor.build({
'bEndpointAddress': 0x81,
'bmAttributes': 2,
'wMaxPacketSize': 64,
'bInterval': 255,
'bRefresh': 32,
'bSynchAddress': 0x05,
})

# ... and check the binary output
self.assertEqual(data, bytes([
0x09, # Length
0x05, # Type
0x81, # Endpoint address
0x02, # Attributes
0x40, 0x00, # Maximum packet size
0xFF, # Interval
0x20, # Refresh rate
0x05, # Synch endpoint address
]))


if __name__ == "__main__":
unittest.main()

0 comments on commit f16173f

Please sign in to comment.