diff --git a/usb_protocol/types/descriptors/standard.py b/usb_protocol/types/descriptors/standard.py index bdf53f1..a97cd36 100644 --- a/usb_protocol/types/descriptors/standard.py +++ b/usb_protocol/types/descriptors/standard.py @@ -128,13 +128,14 @@ class DeviceCapabilityTypes(IntEnum): ) -EndpointDescriptor = DescriptorFormat( +EndpointDescriptorLength = construct.Rebuild(construct.Int8ul, 7 if (this.bRefresh is None) and (this.bSynchAddress is None) else 9) +EndpointDescriptor = DescriptorFormat( # [USB2.0: 9.6; USB Audio Device Class Definition 1.0: 4.6.1.1, 4.6.2.1] # Interfaces of the Audio 1.0 class extend their subordinate endpoint descriptors with # 2 additional bytes (extending it from 7 to 9 bytes). Thankfully, this is the only extension that # changes the length of a standard descriptor type, but we do have to handle this case in Construct. - "bLength" / construct.Default(construct.OneOf(construct.Int8ul, [7, 9]), 7), + "bLength" / EndpointDescriptorLength, "bDescriptorType" / DescriptorNumber(StandardDescriptorNumbers.ENDPOINT), "bEndpointAddress" / DescriptorField("Endpoint Address"), "bmAttributes" / DescriptorField("Attributes", default=2), @@ -142,8 +143,8 @@ class DeviceCapabilityTypes(IntEnum): "bInterval" / DescriptorField("Polling interval", default=255), # 2 bytes that are only present on endpoint descriptors for Audio 1.0 class interfaces. - ("bRefresh" / construct.Optional(construct.Int8ul)) * "Refresh Rate", - ("bSynchAddress" / construct.Optional(construct.Int8ul)) * "Synch Endpoint Address", + ("bRefresh" / construct.If(this.bLength == 9, construct.Int8ul)) * "Refresh Rate", + ("bSynchAddress" / construct.If(this.bLength == 9, construct.Int8ul)) * "Synch Endpoint Address", ) @@ -198,7 +199,6 @@ class DeviceCapabilityTypes(IntEnum): ) - class DescriptorParserCases(unittest.TestCase): STRING_DESCRIPTOR = bytes([ @@ -225,7 +225,6 @@ class DescriptorParserCases(unittest.TestCase): ord('s'), 0x00, ]) - def test_string_descriptor_parse(self): # Parse the relevant string... @@ -236,7 +235,6 @@ def test_string_descriptor_parse(self): self.assertEqual(parsed.bDescriptorType, 3) self.assertEqual(parsed.bString, "Great Scott Gadgets") - def test_string_descriptor_build(self): data = StringDescriptor.build({ 'bString': "Great Scott Gadgets" @@ -244,7 +242,6 @@ def test_string_descriptor_build(self): self.assertEqual(data, self.STRING_DESCRIPTOR) - def test_string_language_descriptor_build(self): data = StringLanguageDescriptor.build({ 'wLANGID': (LanguageIDs.ENGLISH_US,) @@ -252,7 +249,6 @@ def test_string_language_descriptor_build(self): self.assertEqual(data, b"\x04\x03\x09\x04") - def test_device_descriptor(self): device_descriptor = [ @@ -291,7 +287,6 @@ def test_device_descriptor(self): self.assertEqual(parsed.iSerialNumber, 3) self.assertEqual(parsed.bNumConfigurations, 1) - def test_bcd_constructor(self): emitter = BCDFieldAdapter(construct.Int16ul) @@ -300,5 +295,71 @@ def test_bcd_constructor(self): self.assertEqual(result, b"\x40\x01") + def test_parse_endpoint_descriptor(self): + # Parse the relevant descriptor ... + parsed = EndpointDescriptor.parse([ + 0x07, # Length + 0x05, # Type + 0x81, # Endpoint address + 0x02, # Attributes + 0x40, 0x00, # Maximum packet size + 0xFF, # Interval + ]) + + # ... and check the descriptor's fields. + self.assertEqual(parsed.bLength, 7) + self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT) + self.assertEqual(parsed.bEndpointAddress, 0x81) + self.assertEqual(parsed.bmAttributes, 2) + self.assertEqual(parsed.wMaxPacketSize, 64) + self.assertEqual(parsed.bInterval, 255) + + def test_parse_endpoint_descriptor_audio(self): + # Parse the relevant descriptor ... + parsed = EndpointDescriptor.parse([ + 0x09, # Length + 0x05, # Type + 0x81, # Endpoint address + 0x02, # Attributes + 0x40, 0x00, # Maximum packet size + 0xFF, # Interval + 0x20, # Refresh rate + 0x05, # Synch endpoint address + ]) + + # ... and check the descriptor's fields. + self.assertEqual(parsed.bLength, 9) + self.assertEqual(parsed.bDescriptorType, StandardDescriptorNumbers.ENDPOINT) + self.assertEqual(parsed.bEndpointAddress, 0x81) + self.assertEqual(parsed.bmAttributes, 2) + self.assertEqual(parsed.wMaxPacketSize, 64) + self.assertEqual(parsed.bInterval, 255) + self.assertEqual(parsed.bRefresh, 32) + self.assertEqual(parsed.bSynchAddress, 0x05) + + def test_build_endpoint_descriptor_audio(self): + # Build the relevant descriptor + data = EndpointDescriptor.build({ + 'bEndpointAddress': 0x81, + 'bmAttributes': 2, + 'wMaxPacketSize': 64, + 'bInterval': 255, + 'bRefresh': 32, + 'bSynchAddress': 0x05, + }) + + # ... and check the binary output + self.assertEqual(data, bytes([ + 0x09, # Length + 0x05, # Type + 0x81, # Endpoint address + 0x02, # Attributes + 0x40, 0x00, # Maximum packet size + 0xFF, # Interval + 0x20, # Refresh rate + 0x05, # Synch endpoint address + ])) + + if __name__ == "__main__": unittest.main()