diff --git a/api/institutions/views.py b/api/institutions/views.py index 2cfd357d9d5..9d1ec5dde89 100644 --- a/api/institutions/views.py +++ b/api/institutions/views.py @@ -34,7 +34,7 @@ ) from api.base.settings import DEFAULT_ES_NULL_VALUE from api.metrics.permissions import IsInstitutionalMetricsUser -from api.metrics.renderers import MetricsReportsCsvRenderer, MetricsReportsTsvRenderer +from api.metrics.renderers import MetricsReportsCsvRenderer, MetricsReportsTsvRenderer, MetricsReportsJsonRenderer from api.nodes.serializers import NodeSerializer from api.nodes.filters import NodesFilterMixin from api.users.serializers import UserSerializer @@ -558,6 +558,7 @@ class _NewInstitutionUserMetricsList(InstitutionMixin, ElasticsearchListView): *api_settings.DEFAULT_RENDERER_CLASSES, MetricsReportsCsvRenderer, MetricsReportsTsvRenderer, + MetricsReportsJsonRenderer, ) serializer_class = NewInstitutionUserMetricsSerializer diff --git a/api/metrics/renderers.py b/api/metrics/renderers.py index d38a57d0de8..fc3196db7ad 100644 --- a/api/metrics/renderers.py +++ b/api/metrics/renderers.py @@ -1,6 +1,7 @@ import csv import io -from api.base.settings.defaults import USER_INSTITUTION_REPORT_FILENAME +import json +from api.base.settings.defaults import USER_INSTITUTION_REPORT_FILENAME, MAX_SIZE_OF_ES_QUERY import datetime from django.http import Http404 @@ -18,16 +19,25 @@ def csv_fieldname_sortkey(fieldname): def get_nested_keys(report_attrs): - for attr_key in sorted(report_attrs.keys(), key=csv_fieldname_sortkey): - attr_value = report_attrs[attr_key] - if isinstance(attr_value, dict): - for subkey in get_nested_keys(attr_value): - yield f'{attr_key}.{subkey}' - else: - yield attr_key + """ + Recursively retrieves all nested keys from the report attributes. + Handles both dictionaries and lists of attributes. + """ + if isinstance(report_attrs, dict): + for attr_key in sorted(report_attrs.keys(), key=csv_fieldname_sortkey): + attr_value = report_attrs[attr_key] + if isinstance(attr_value, dict): + for subkey in get_nested_keys(attr_value): + yield f'{attr_key}.{subkey}' + else: + yield attr_key + elif isinstance(report_attrs, list): + for item in report_attrs: + yield from get_nested_keys(item) def get_key_value(nested_key, report_attrs): + report_attrs = report_attrs.to_dict() if hasattr(report_attrs, 'to_dict') else report_attrs (key, _, next_nested_key) = nested_key.partition('.') attr_value = report_attrs.get(key, {}) return ( @@ -62,21 +72,30 @@ def get_filename(self, renderer_context: dict, format_type: str) -> str: else: raise NotImplementedError('Missing format filename') - def render(self, json_response: dict, accepted_media_type: str = None, renderer_context: dict = None) -> str: - """Render the response as CSV or TSV format.""" - serialized_reports = (jsonapi_resource['attributes'] for jsonapi_resource in json_response['data']) + def get_all_data(self, view, request): + """Bypass pagination by fetching all the data.""" + view.pagination_class = None # Disable pagination + return view.get_default_search().extra(size=MAX_SIZE_OF_ES_QUERY).execute() - try: - first_row = next(serialized_reports) - except StopIteration: + def render(self, data: dict, accepted_media_type: str = None, renderer_context: dict = None) -> str: + """Render the full dataset as CSV or TSV format.""" + data = self.get_all_data(renderer_context['view'], renderer_context['request']) + hits = data.hits + if not hits: raise Http404('

none found

') - csv_fieldnames = list(get_nested_keys(first_row)) + + # Assuming each hit contains '_source' with the relevant data + first_row = hits[0].to_dict() + csv_fieldnames = list(first_row) csv_filecontent = io.StringIO(newline='') csv_writer = csv.writer(csv_filecontent, dialect=self.CSV_DIALECT) csv_writer.writerow(csv_fieldnames) - for serialized_report in (first_row, *serialized_reports): - csv_writer.writerow(get_csv_row(csv_fieldnames, serialized_report)) + # Write each hit's '_source' as a row in the CSV + for hit in hits: + csv_writer.writerow(get_csv_row(csv_fieldnames, hit.to_dict())) + + # Set response headers for file download response = renderer_context['response'] filename = self.get_filename(renderer_context, self.extension) response['Content-Disposition'] = f'attachment; filename="{filename}"' @@ -96,3 +115,32 @@ class MetricsReportsTsvRenderer(MetricsReportsBaseRenderer): format = 'tsv' CSV_DIALECT = csv.excel_tab extension = 'tsv' + + +class MetricsReportsJsonRenderer(MetricsReportsBaseRenderer): + media_type = 'application/json' + format = 'json_file' + extension = 'json' + + def default_serializer(self, obj): + """Custom serializer to handle non-serializable objects like datetime.""" + if isinstance(obj, datetime.datetime): + return obj.isoformat() # Convert datetime to ISO format string + raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable') + + def render(self, data, accepted_media_type=None, renderer_context=None): + """Render the response as JSON format and trigger browser download as a binary file.""" + data = self.get_all_data(renderer_context['view'], renderer_context['request']) + hits = data.hits + if not hits: + raise Http404('

none found

') + + serialized_hits = [hit.to_dict() for hit in hits] + + # Set response headers for file download + response = renderer_context['response'] + filename = self.get_filename(renderer_context, self.extension) + response['Content-Disposition'] = f'attachment; filename="{filename}"' + + # Use custom serializer for non-serializable types (like datetime) + return json.dumps(serialized_hits, default=self.default_serializer, indent=4).encode('utf-8') diff --git a/api_tests/institutions/views/test_institution_user_metric_list.py b/api_tests/institutions/views/test_institution_user_metric_list.py index 7e2fb17885d..84b73ef74d1 100644 --- a/api_tests/institutions/views/test_institution_user_metric_list.py +++ b/api_tests/institutions/views/test_institution_user_metric_list.py @@ -1,5 +1,6 @@ -import datetime import csv +import json +import datetime from io import StringIO from random import random from urllib.parse import urlencode @@ -406,16 +407,16 @@ def test_paginate_reports(self, app, url, institutional_admin, institution, repo @pytest.mark.parametrize('format_type, delimiter, content_type', [ ('csv', ',', 'text/csv; charset=utf-8'), - ('tsv', '\t', 'text/tab-separated-values; charset=utf-8') + ('tsv', '\t', 'text/tab-separated-values; charset=utf-8'), + ('json_file', None, 'application/json; charset=utf-8') ]) def test_get_report_formats(self, app, url, institutional_admin, institution, format_type, delimiter, content_type): - # Setting up the reports _report_factory( '2024-08', institution, - user_id='u_orcomma', + user_id=f'u_orcomma', account_creation_date='2018-02', - user_name='Jason Kelce', + user_name=f'Jason Kelce', orcid_id='4444-3333-2222-1111', department_name='Center \t Greatest Ever', storage_byte_count=736662999298, @@ -427,24 +428,6 @@ def test_get_report_formats(self, app, url, institutional_admin, institution, fo private_project_count=5, month_last_active='2018-02', month_last_login='2018-02', - ), - _report_factory( - '2024-08', - institution, - user_id='u_orcomma2', - account_creation_date='2018-02', - user_name='Brian Dawkins, Weapon X, The Wolverine', - orcid_id='4444-3333-2222-1111', - department_name='Safety', - storage_byte_count=736662999298, - embargoed_registration_count=1, - published_preprint_count=1, - public_registration_count=2, - public_project_count=3, - public_file_count=4, - private_project_count=5, - month_last_active='2018-02', - month_last_login='2018-02', ) resp = app.get(f'{url}?format={format_type}', auth=institutional_admin.auth) @@ -455,65 +438,47 @@ def test_get_report_formats(self, app, url, institutional_admin, institution, fo expected_filename = USER_INSTITUTION_REPORT_FILENAME.format( date_created=current_date, institution_id=institution._id, - format_type=format_type + format_type='json' if format_type == 'json_file' else format_type ) assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"' - response_body = resp.text - expected_response = [ - [ # Column Headers - 'account_creation_date', - 'department', - 'embargoed_registration_count', - 'month_last_active', - 'month_last_login', - 'orcid_id', - 'private_projects', - 'public_file_count', - 'public_projects', - 'public_registration_count', - 'published_preprint_count', - 'storage_byte_count', - 'user_name' - ], - [ - '2018-02', - 'Center \t Greatest Ever', - '1', - '2018-02', - '2018-02', - '4444-3333-2222-1111', - '5', - '4', - '3', - '2', - '1', - '736662999298', - 'Jason Kelce' - ], - [ - '2018-02', - 'Safety', - '1', - '2018-02', - '2018-02', - '4444-3333-2222-1111', - '5', - '4', - '3', - '2', - '1', - '736662999298', - 'Brian Dawkins, Weapon X, The Wolverine' + if format_type == 'json_file': + # Validate JSON structure and content + response_data = json.loads(resp.body.decode('utf-8')) + expected_data = [ + { + 'account_creation_date': '2018-02', + 'department_name': 'Center \t Greatest Ever', + 'embargoed_registration_count': 1, + 'month_last_active': '2018-02', + 'month_last_login': '2018-02', + 'orcid_id': '4444-3333-2222-1111', + 'private_project_count': 5, + 'public_file_count': 4, + 'public_project_count': 3, + 'public_registration_count': 2, + 'published_preprint_count': 1, + 'storage_byte_count': 736662999298, + 'user_name': 'Jason Kelce' + } + ] + assert response_data == expected_data + else: + response_body = resp.text + expected_response = [ + ['account_creation_date', 'department_name', 'embargoed_registration_count', 'month_last_active', + 'month_last_login', 'orcid_id', 'private_projects', 'public_file_count', 'public_projects', + 'public_registration_count', 'published_preprint_count', 'storage_byte_count', 'user_name'], + ['2018-02', 'Center \t Greatest Ever', '1', '2018-02', '2018-02', '4444-3333-2222-1111', '5', '4', '3', + '2', '1', '736662999298', 'Jason Kelce'], ] - ] - - with StringIO(response_body) as file: - reader = csv.reader(file, delimiter=delimiter) - response_rows = list(reader) - assert response_rows[0] == expected_response[0] - assert sorted(response_rows[1:]) == sorted(expected_response[1:]) + if delimiter: + with StringIO(response_body) as file: + reader = csv.reader(file, delimiter=delimiter) + response_rows = list(reader) + assert response_rows[0] == expected_response[0] + assert sorted(response_rows[1:]) == sorted(expected_response[1:]) def _user_ids(api_response):