diff --git a/api/base/settings/defaults.py b/api/base/settings/defaults.py index ebf48d07433..a6ecb6b8e3f 100644 --- a/api/base/settings/defaults.py +++ b/api/base/settings/defaults.py @@ -359,6 +359,7 @@ MAX_SIZE_OF_ES_QUERY = 10000 DEFAULT_ES_NULL_VALUE = 'N/A' +USER_INSTITUTION_REPORT_FILENAME = 'institution_user_report_{institution_id}_{date_created}.{format_type}' CI_ENV = False diff --git a/api/metrics/renderers.py b/api/metrics/renderers.py index fd4bdc78da2..d38a57d0de8 100644 --- a/api/metrics/renderers.py +++ b/api/metrics/renderers.py @@ -1,5 +1,7 @@ import csv import io +from api.base.settings.defaults import USER_INSTITUTION_REPORT_FILENAME +import datetime from django.http import Http404 @@ -42,16 +44,28 @@ def get_csv_row(keys_list, report_attrs): ] -class MetricsReportsCsvRenderer(renderers.BaseRenderer): - media_type = 'text/csv' - format = 'csv' - CSV_DIALECT = csv.excel +class MetricsReportsBaseRenderer(renderers.BaseRenderer): + media_type: str + format: str + CSV_DIALECT: csv.Dialect + extension: str + + def get_filename(self, renderer_context: dict, format_type: str) -> str: + """Generate the filename for the CSV/TSV file based on institution and current date.""" + if renderer_context and 'view' in renderer_context: + current_date = datetime.datetime.now().strftime('%Y-%m') # Format as 'YYYY-MM' + return USER_INSTITUTION_REPORT_FILENAME.format( + date_created=current_date, + institution_id=renderer_context['view'].kwargs['institution_id'], + format_type=format_type + ) + else: + raise NotImplementedError('Missing format filename') + + def render(self, json_response: dict, accepted_media_type: str = None, renderer_context: dict = None) -> str: + """Render the response as CSV or TSV format.""" + serialized_reports = (jsonapi_resource['attributes'] for jsonapi_resource in json_response['data']) - def render(self, json_response, accepted_media_type=None, renderer_context=None): - serialized_reports = ( - jsonapi_resource['attributes'] - for jsonapi_resource in json_response['data'] - ) try: first_row = next(serialized_reports) except StopIteration: @@ -61,13 +75,24 @@ def render(self, json_response, accepted_media_type=None, renderer_context=None) csv_writer = csv.writer(csv_filecontent, dialect=self.CSV_DIALECT) csv_writer.writerow(csv_fieldnames) for serialized_report in (first_row, *serialized_reports): - csv_writer.writerow( - get_csv_row(csv_fieldnames, serialized_report), - ) + csv_writer.writerow(get_csv_row(csv_fieldnames, serialized_report)) + + response = renderer_context['response'] + filename = self.get_filename(renderer_context, self.extension) + response['Content-Disposition'] = f'attachment; filename="{filename}"' + return csv_filecontent.getvalue() -class MetricsReportsTsvRenderer(MetricsReportsCsvRenderer): - format = 'tsv' +class MetricsReportsCsvRenderer(MetricsReportsBaseRenderer): + media_type = 'text/csv' + format = 'csv' + CSV_DIALECT = csv.excel + extension = 'csv' + + +class MetricsReportsTsvRenderer(MetricsReportsBaseRenderer): media_type = 'text/tab-separated-values' + format = 'tsv' CSV_DIALECT = csv.excel_tab + extension = 'tsv' diff --git a/api_tests/institutions/views/test_institution_user_metric_list.py b/api_tests/institutions/views/test_institution_user_metric_list.py index cf8004a023b..7e2fb17885d 100644 --- a/api_tests/institutions/views/test_institution_user_metric_list.py +++ b/api_tests/institutions/views/test_institution_user_metric_list.py @@ -7,7 +7,7 @@ import pytest from waffle.testutils import override_flag -from api.base.settings.defaults import API_BASE, DEFAULT_ES_NULL_VALUE +from api.base.settings.defaults import API_BASE, DEFAULT_ES_NULL_VALUE, USER_INSTITUTION_REPORT_FILENAME import osf.features from osf_tests.factories import ( InstitutionFactory, @@ -451,6 +451,14 @@ def test_get_report_formats(self, app, url, institutional_admin, institution, fo assert resp.status_code == 200 assert resp.headers['Content-Type'] == content_type + current_date = datetime.datetime.now().strftime('%Y-%m') + expected_filename = USER_INSTITUTION_REPORT_FILENAME.format( + date_created=current_date, + institution_id=institution._id, + format_type=format_type + ) + assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"' + response_body = resp.text expected_response = [ [ # Column Headers