Skip to content

Commit

Permalink
Catch invalid organization id preventing 500s (#4475)
Browse files Browse the repository at this point in the history
* eeej small files

* return -1 if org id is invalid type

---------

Co-authored-by: Katherine Fleming <[email protected]>
  • Loading branch information
perryr16 and kflemin authored Jan 18, 2024
1 parent 04f66bb commit ac67f1f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
15 changes: 9 additions & 6 deletions seed/lib/superperms/orgs/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,15 @@ def get_org_or_id(dictlike: dict) -> Union[int, None]:
# Check if there are any assigned organization values
org_id = None
for org_str in org_query_strings:
org_id = dictlike.get(org_str)
if org_id:
# Type case the organization_id as a integer
if '_id' in org_str:
org_id = int(org_id)
break
try:
org_id = dictlike.get(org_str)
if org_id:
# Type case the organization_id as a integer
if '_id' in org_str:
org_id = int(org_id)
break
except (ValueError, TypeError):
return -1
return org_id


Expand Down
21 changes: 21 additions & 0 deletions seed/tests/test_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,27 @@ def test_get_org_id(self):
result = get_org_id(mock_request)
self.assertEqual(None, result)

# invalid ids are returned as -1 (not found)
mock_request = mock_request_factory(
view_authz_org_id_kwarg=None,
parser_kwargs={'not_org_id': 1},
path='/api/v3/nope/2/',
query_params={'organization_id': 'invalid_id'},
data={'organization_id': 4}
)
result = get_org_id(mock_request)
self.assertEqual(-1, result)

mock_request = mock_request_factory(
view_authz_org_id_kwarg=None,
parser_kwargs={'not_org_id': 1},
path='/api/v3/nope/2/',
query_params={'not_org_id': 2},
data={'organization_id': 'invalid_id'}
)
result = get_org_id(mock_request)
self.assertEqual(-1, result)

def test_get_user_org(self):
"""Test getting org from user"""
fake_user = User.objects.create(username='test')
Expand Down
7 changes: 7 additions & 0 deletions seed/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,6 +1560,13 @@ def test_get_cycles(self):
self.assertEqual(cycle['id'], self.cycle.pk)
self.assertEqual(cycle['name'], self.cycle.name)

# invalid organization id returns 403 error
params['organization_id'] = 'invalid'
response = self.client.get(
reverse('api:v3:cycles-list'), params
)
self.assertEqual(403, response.status_code)

def test_postoffice(self):
# Create a template
response = self.client.post('/api/v3/postoffice/', {
Expand Down

0 comments on commit ac67f1f

Please sign in to comment.