Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning for the length of the group name #2122

Merged
merged 12 commits into from
Jan 28, 2025
54 changes: 26 additions & 28 deletions channels/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,34 +144,32 @@ def match_type_and_length(self, name):
invalid_name_error = (
"{} name must be a valid unicode string "
+ "with length < {} ".format(MAX_NAME_LENGTH)
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
+ "not {}"
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods."
)

def valid_channel_name(self, name, receive=False):
if self.match_type_and_length(name):
if bool(self.channel_name_regex.match(name)):
# Check cases for special channels
if "!" in name and not name.endswith("!") and receive:
raise TypeError(
"Specific channel names in receive() must end at the !"
)
return True
raise TypeError(self.invalid_name_error.format("Channel", name))

def valid_group_name(self, name):
if self.match_type_and_length(name):
if bool(self.group_name_regex.match(name)):
return True
raise TypeError(self.invalid_name_error.format("Group", name))
def require_valid_channel_name(self, name, receive=False):
if not self.match_type_and_length(name):
raise TypeError(self.invalid_name_error.format("Channel"))
if not bool(self.channel_name_regex.match(name)):
raise TypeError(self.invalid_name_error.format("Channel"))
if "!" in name and not name.endswith("!") and receive:
raise TypeError("Specific channel names in receive() must end at the !")
return True

def require_valid_group_name(self, name):
if not self.match_type_and_length(name):
raise TypeError(self.invalid_name_error.format("Group"))
if not bool(self.group_name_regex.match(name)):
raise TypeError(self.invalid_name_error.format("Group"))
return True
IronJam11 marked this conversation as resolved.
Show resolved Hide resolved

def valid_channel_names(self, names, receive=False):
_non_empty_list = True if names else False
_names_type = isinstance(names, list)
assert _non_empty_list and _names_type, "names must be a non-empty list"

assert all(
self.valid_channel_name(channel, receive=receive) for channel in names
all(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The all doesn't have semantic value, can you refactor this as a normal for loop?

self.require_valid_channel_name(channel, receive=receive)
for channel in names
)
return True

Expand Down Expand Up @@ -243,7 +241,7 @@ async def send(self, channel, message):
"""
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
self.require_valid_channel_name(channel)
# If it's a process-local channel, strip off local part and stick full
# name in message
assert "__asgi_channel__" not in message
Expand All @@ -263,7 +261,7 @@ async def receive(self, channel):
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""
assert self.valid_channel_name(channel)
self.require_valid_channel_name(channel)
self._clean_expired()

queue = self.channels.setdefault(
Expand Down Expand Up @@ -341,16 +339,16 @@ async def group_add(self, group, channel):
Adds the channel name to a group.
"""
# Check the inputs
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
self.require_valid_group_name(group)
self.require_valid_channel_name(channel)
# Add to group dict
self.groups.setdefault(group, {})
self.groups[group][channel] = time.time()

async def group_discard(self, group, channel):
# Both should be text and valid
assert self.valid_channel_name(channel), "Invalid channel name"
assert self.valid_group_name(group), "Invalid group name"
self.require_valid_channel_name(channel)
self.require_valid_group_name(group)
# Remove from group set
group_channels = self.groups.get(group, None)
if group_channels:
Expand All @@ -363,7 +361,7 @@ async def group_discard(self, group, channel):
async def group_send(self, group, message):
# Check types
assert isinstance(message, dict), "Message is not a dict"
assert self.valid_group_name(group), "Invalid group name"
self.require_valid_group_name(group)
# Run clean
self._clean_expired()

Expand Down
41 changes: 40 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ async def test_send_receive():

@pytest.mark.parametrize(
"method",
[BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name],
[
BaseChannelLayer().require_valid_channel_name,
BaseChannelLayer().require_valid_group_name,
],
)
@pytest.mark.parametrize(
"channel_name,expected_valid",
Expand All @@ -84,3 +87,39 @@ def test_channel_and_group_name_validation(method, channel_name, expected_valid)
else:
with pytest.raises(TypeError):
method(channel_name)


@pytest.mark.parametrize(
"name",
[
"a" * 101, # Group name too long
],
)
def test_group_name_length_error_message(name):
"""
Ensure the correct error message is raised when group names
exceed the character limit or contain invalid characters.
"""
layer = BaseChannelLayer()
expected_error_message = layer.invalid_name_error.format("Group")

with pytest.raises(TypeError, match=expected_error_message):
layer.require_valid_group_name(name)


@pytest.mark.parametrize(
"name",
[
"a" * 101, # Channel name too long
],
)
def test_channel_name_length_error_message(name):
"""
Ensure the correct error message is raised when group names
exceed the character limit or contain invalid characters.
"""
layer = BaseChannelLayer()
expected_error_message = layer.invalid_name_error.format("Channel")

with pytest.raises(TypeError, match=expected_error_message):
layer.require_valid_channel_name(name)