Skip to content

Commit

Permalink
include all results in downloaded data beyond default page size and a…
Browse files Browse the repository at this point in the history
…dd Json attachment renderer.
  • Loading branch information
John Tordoff committed Oct 23, 2024
1 parent e2f7e47 commit 62907cf
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 95 deletions.
3 changes: 2 additions & 1 deletion api/institutions/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from api.base.settings import DEFAULT_ES_NULL_VALUE
from api.metrics.permissions import IsInstitutionalMetricsUser
from api.metrics.renderers import MetricsReportsCsvRenderer, MetricsReportsTsvRenderer
from api.metrics.renderers import MetricsReportsCsvRenderer, MetricsReportsTsvRenderer, MetricsReportsJsonRenderer
from api.nodes.serializers import NodeSerializer
from api.nodes.filters import NodesFilterMixin
from api.users.serializers import UserSerializer
Expand Down Expand Up @@ -558,6 +558,7 @@ class _NewInstitutionUserMetricsList(InstitutionMixin, ElasticsearchListView):
*api_settings.DEFAULT_RENDERER_CLASSES,
MetricsReportsCsvRenderer,
MetricsReportsTsvRenderer,
MetricsReportsJsonRenderer,
)

serializer_class = NewInstitutionUserMetricsSerializer
Expand Down
82 changes: 65 additions & 17 deletions api/metrics/renderers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import csv
import io
from api.base.settings.defaults import USER_INSTITUTION_REPORT_FILENAME
import json
from api.base.settings.defaults import USER_INSTITUTION_REPORT_FILENAME, MAX_SIZE_OF_ES_QUERY
import datetime

from django.http import Http404
Expand All @@ -18,16 +19,25 @@ def csv_fieldname_sortkey(fieldname):


def get_nested_keys(report_attrs):
for attr_key in sorted(report_attrs.keys(), key=csv_fieldname_sortkey):
attr_value = report_attrs[attr_key]
if isinstance(attr_value, dict):
for subkey in get_nested_keys(attr_value):
yield f'{attr_key}.{subkey}'
else:
yield attr_key
"""
Recursively retrieves all nested keys from the report attributes.
Handles both dictionaries and lists of attributes.
"""
if isinstance(report_attrs, dict):
for attr_key in sorted(report_attrs.keys(), key=csv_fieldname_sortkey):
attr_value = report_attrs[attr_key]
if isinstance(attr_value, dict):
for subkey in get_nested_keys(attr_value):
yield f'{attr_key}.{subkey}'
else:
yield attr_key
elif isinstance(report_attrs, list):
for item in report_attrs:
yield from get_nested_keys(item)


def get_key_value(nested_key, report_attrs):
report_attrs = report_attrs.to_dict() if hasattr(report_attrs, 'to_dict') else report_attrs
(key, _, next_nested_key) = nested_key.partition('.')
attr_value = report_attrs.get(key, {})
return (
Expand Down Expand Up @@ -62,21 +72,30 @@ def get_filename(self, renderer_context: dict, format_type: str) -> str:
else:
raise NotImplementedError('Missing format filename')

def render(self, json_response: dict, accepted_media_type: str = None, renderer_context: dict = None) -> str:
"""Render the response as CSV or TSV format."""
serialized_reports = (jsonapi_resource['attributes'] for jsonapi_resource in json_response['data'])
def get_all_data(self, view, request):
"""Bypass pagination by fetching all the data."""
view.pagination_class = None # Disable pagination
return view.get_default_search().extra(size=MAX_SIZE_OF_ES_QUERY).execute()

try:
first_row = next(serialized_reports)
except StopIteration:
def render(self, data: dict, accepted_media_type: str = None, renderer_context: dict = None) -> str:
"""Render the full dataset as CSV or TSV format."""
data = self.get_all_data(renderer_context['view'], renderer_context['request'])
hits = data.hits
if not hits:
raise Http404('<h1>none found</h1>')
csv_fieldnames = list(get_nested_keys(first_row))

# Assuming each hit contains '_source' with the relevant data
first_row = hits[0].to_dict()
csv_fieldnames = list(first_row)
csv_filecontent = io.StringIO(newline='')
csv_writer = csv.writer(csv_filecontent, dialect=self.CSV_DIALECT)
csv_writer.writerow(csv_fieldnames)
for serialized_report in (first_row, *serialized_reports):
csv_writer.writerow(get_csv_row(csv_fieldnames, serialized_report))

# Write each hit's '_source' as a row in the CSV
for hit in hits:
csv_writer.writerow(get_csv_row(csv_fieldnames, hit.to_dict()))

# Set response headers for file download
response = renderer_context['response']
filename = self.get_filename(renderer_context, self.extension)
response['Content-Disposition'] = f'attachment; filename="{filename}"'
Expand All @@ -96,3 +115,32 @@ class MetricsReportsTsvRenderer(MetricsReportsBaseRenderer):
format = 'tsv'
CSV_DIALECT = csv.excel_tab
extension = 'tsv'


class MetricsReportsJsonRenderer(MetricsReportsBaseRenderer):
media_type = 'application/json'
format = 'json_file'
extension = 'json'

def default_serializer(self, obj):
"""Custom serializer to handle non-serializable objects like datetime."""
if isinstance(obj, datetime.datetime):
return obj.isoformat() # Convert datetime to ISO format string
raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

def render(self, data, accepted_media_type=None, renderer_context=None):
"""Render the response as JSON format and trigger browser download as a binary file."""
data = self.get_all_data(renderer_context['view'], renderer_context['request'])
hits = data.hits
if not hits:
raise Http404('<h1>none found</h1>')

serialized_hits = [hit.to_dict() for hit in hits]

# Set response headers for file download
response = renderer_context['response']
filename = self.get_filename(renderer_context, self.extension)
response['Content-Disposition'] = f'attachment; filename="{filename}"'

# Use custom serializer for non-serializable types (like datetime)
return json.dumps(serialized_hits, default=self.default_serializer, indent=4).encode('utf-8')
119 changes: 42 additions & 77 deletions api_tests/institutions/views/test_institution_user_metric_list.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import csv
import json
import datetime
from io import StringIO
from random import random
from urllib.parse import urlencode
Expand Down Expand Up @@ -406,16 +407,16 @@ def test_paginate_reports(self, app, url, institutional_admin, institution, repo

@pytest.mark.parametrize('format_type, delimiter, content_type', [
('csv', ',', 'text/csv; charset=utf-8'),
('tsv', '\t', 'text/tab-separated-values; charset=utf-8')
('tsv', '\t', 'text/tab-separated-values; charset=utf-8'),
('json_file', None, 'application/json; charset=utf-8')
])
def test_get_report_formats(self, app, url, institutional_admin, institution, format_type, delimiter, content_type):
# Setting up the reports
_report_factory(
'2024-08',
institution,
user_id='u_orcomma',
user_id=f'u_orcomma',
account_creation_date='2018-02',
user_name='Jason Kelce',
user_name=f'Jason Kelce',
orcid_id='4444-3333-2222-1111',
department_name='Center \t Greatest Ever',
storage_byte_count=736662999298,
Expand All @@ -427,24 +428,6 @@ def test_get_report_formats(self, app, url, institutional_admin, institution, fo
private_project_count=5,
month_last_active='2018-02',
month_last_login='2018-02',
),
_report_factory(
'2024-08',
institution,
user_id='u_orcomma2',
account_creation_date='2018-02',
user_name='Brian Dawkins, Weapon X, The Wolverine',
orcid_id='4444-3333-2222-1111',
department_name='Safety',
storage_byte_count=736662999298,
embargoed_registration_count=1,
published_preprint_count=1,
public_registration_count=2,
public_project_count=3,
public_file_count=4,
private_project_count=5,
month_last_active='2018-02',
month_last_login='2018-02',
)

resp = app.get(f'{url}?format={format_type}', auth=institutional_admin.auth)
Expand All @@ -455,65 +438,47 @@ def test_get_report_formats(self, app, url, institutional_admin, institution, fo
expected_filename = USER_INSTITUTION_REPORT_FILENAME.format(
date_created=current_date,
institution_id=institution._id,
format_type=format_type
format_type='json' if format_type == 'json_file' else format_type
)
assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"'

response_body = resp.text
expected_response = [
[ # Column Headers
'account_creation_date',
'department',
'embargoed_registration_count',
'month_last_active',
'month_last_login',
'orcid_id',
'private_projects',
'public_file_count',
'public_projects',
'public_registration_count',
'published_preprint_count',
'storage_byte_count',
'user_name'
],
[
'2018-02',
'Center \t Greatest Ever',
'1',
'2018-02',
'2018-02',
'4444-3333-2222-1111',
'5',
'4',
'3',
'2',
'1',
'736662999298',
'Jason Kelce'
],
[
'2018-02',
'Safety',
'1',
'2018-02',
'2018-02',
'4444-3333-2222-1111',
'5',
'4',
'3',
'2',
'1',
'736662999298',
'Brian Dawkins, Weapon X, The Wolverine'
if format_type == 'json_file':
# Validate JSON structure and content
response_data = json.loads(resp.body.decode('utf-8'))
expected_data = [
{
'account_creation_date': '2018-02',
'department_name': 'Center \t Greatest Ever',
'embargoed_registration_count': 1,
'month_last_active': '2018-02',
'month_last_login': '2018-02',
'orcid_id': '4444-3333-2222-1111',
'private_project_count': 5,
'public_file_count': 4,
'public_project_count': 3,
'public_registration_count': 2,
'published_preprint_count': 1,
'storage_byte_count': 736662999298,
'user_name': 'Jason Kelce'
}
]
assert response_data == expected_data
else:
response_body = resp.text
expected_response = [
['account_creation_date', 'department_name', 'embargoed_registration_count', 'month_last_active',
'month_last_login', 'orcid_id', 'private_projects', 'public_file_count', 'public_projects',
'public_registration_count', 'published_preprint_count', 'storage_byte_count', 'user_name'],
['2018-02', 'Center \t Greatest Ever', '1', '2018-02', '2018-02', '4444-3333-2222-1111', '5', '4', '3',
'2', '1', '736662999298', 'Jason Kelce'],
]
]

with StringIO(response_body) as file:
reader = csv.reader(file, delimiter=delimiter)
response_rows = list(reader)

assert response_rows[0] == expected_response[0]
assert sorted(response_rows[1:]) == sorted(expected_response[1:])
if delimiter:
with StringIO(response_body) as file:
reader = csv.reader(file, delimiter=delimiter)
response_rows = list(reader)
assert response_rows[0] == expected_response[0]
assert sorted(response_rows[1:]) == sorted(expected_response[1:])


def _user_ids(api_response):
Expand Down

0 comments on commit 62907cf

Please sign in to comment.