diff --git a/channels/layers.py b/channels/layers.py index 99e7fbd6..5fc53f74 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -144,35 +144,31 @@ def match_type_and_length(self, name): invalid_name_error = ( "{} name must be a valid unicode string " + "with length < {} ".format(MAX_NAME_LENGTH) - + "containing only ASCII alphanumerics, hyphens, underscores, or periods, " - + "not {}" + + "containing only ASCII alphanumerics, hyphens, underscores, or periods." ) - def valid_channel_name(self, name, receive=False): - if self.match_type_and_length(name): - if bool(self.channel_name_regex.match(name)): - # Check cases for special channels - if "!" in name and not name.endswith("!") and receive: - raise TypeError( - "Specific channel names in receive() must end at the !" - ) - return True - raise TypeError(self.invalid_name_error.format("Channel", name)) - - def valid_group_name(self, name): - if self.match_type_and_length(name): - if bool(self.group_name_regex.match(name)): - return True - raise TypeError(self.invalid_name_error.format("Group", name)) + def require_valid_channel_name(self, name, receive=False): + if not self.match_type_and_length(name): + raise TypeError(self.invalid_name_error.format("Channel")) + if not bool(self.channel_name_regex.match(name)): + raise TypeError(self.invalid_name_error.format("Channel")) + if "!" in name and not name.endswith("!") and receive: + raise TypeError("Specific channel names in receive() must end at the !") + return True + + def require_valid_group_name(self, name): + if not self.match_type_and_length(name): + raise TypeError(self.invalid_name_error.format("Group")) + if not bool(self.group_name_regex.match(name)): + raise TypeError(self.invalid_name_error.format("Group")) + return True def valid_channel_names(self, names, receive=False): _non_empty_list = True if names else False _names_type = isinstance(names, list) assert _non_empty_list and _names_type, "names must be a non-empty list" - - assert all( - self.valid_channel_name(channel, receive=receive) for channel in names - ) + for channel in names: + self.require_valid_channel_name(channel, receive=receive) return True def non_local_name(self, name): @@ -243,7 +239,7 @@ async def send(self, channel, message): """ # Typecheck assert isinstance(message, dict), "message is not a dict" - assert self.valid_channel_name(channel), "Channel name not valid" + self.require_valid_channel_name(channel) # If it's a process-local channel, strip off local part and stick full # name in message assert "__asgi_channel__" not in message @@ -263,7 +259,7 @@ async def receive(self, channel): If more than one coroutine waits on the same channel, a random one of the waiting coroutines will get the result. """ - assert self.valid_channel_name(channel) + self.require_valid_channel_name(channel) self._clean_expired() queue = self.channels.setdefault( @@ -341,16 +337,16 @@ async def group_add(self, group, channel): Adds the channel name to a group. """ # Check the inputs - assert self.valid_group_name(group), "Group name not valid" - assert self.valid_channel_name(channel), "Channel name not valid" + self.require_valid_group_name(group) + self.require_valid_channel_name(channel) # Add to group dict self.groups.setdefault(group, {}) self.groups[group][channel] = time.time() async def group_discard(self, group, channel): # Both should be text and valid - assert self.valid_channel_name(channel), "Invalid channel name" - assert self.valid_group_name(group), "Invalid group name" + self.require_valid_channel_name(channel) + self.require_valid_group_name(group) # Remove from group set group_channels = self.groups.get(group, None) if group_channels: @@ -363,7 +359,7 @@ async def group_discard(self, group, channel): async def group_send(self, group, message): # Check types assert isinstance(message, dict), "Message is not a dict" - assert self.valid_group_name(group), "Invalid group name" + self.require_valid_group_name(group) # Run clean self._clean_expired() diff --git a/tests/test_layers.py b/tests/test_layers.py index 543a9f19..7b02c155 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -72,7 +72,10 @@ async def test_send_receive(): @pytest.mark.parametrize( "method", - [BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name], + [ + BaseChannelLayer().require_valid_channel_name, + BaseChannelLayer().require_valid_group_name, + ], ) @pytest.mark.parametrize( "channel_name,expected_valid", @@ -84,3 +87,39 @@ def test_channel_and_group_name_validation(method, channel_name, expected_valid) else: with pytest.raises(TypeError): method(channel_name) + + +@pytest.mark.parametrize( + "name", + [ + "a" * 101, # Group name too long + ], +) +def test_group_name_length_error_message(name): + """ + Ensure the correct error message is raised when group names + exceed the character limit or contain invalid characters. + """ + layer = BaseChannelLayer() + expected_error_message = layer.invalid_name_error.format("Group") + + with pytest.raises(TypeError, match=expected_error_message): + layer.require_valid_group_name(name) + + +@pytest.mark.parametrize( + "name", + [ + "a" * 101, # Channel name too long + ], +) +def test_channel_name_length_error_message(name): + """ + Ensure the correct error message is raised when group names + exceed the character limit or contain invalid characters. + """ + layer = BaseChannelLayer() + expected_error_message = layer.invalid_name_error.format("Channel") + + with pytest.raises(TypeError, match=expected_error_message): + layer.require_valid_channel_name(name)