rework pagination and tests to ignore page size when using elastic
John Tordoff committed Nov 4, 2024
1 parent 2f7487b commit 822317e
Showing 6 changed files with 229 additions and 9 deletions.
16 changes: 16 additions & 0 deletions api/base/elasticsearch_dsl_views.py
@@ -11,6 +11,11 @@

from api.base.filters import FilterMixin
from api.base.views import JSONAPIBaseView
from api.metrics.renderers import (
MetricsReportsCsvRenderer,
MetricsReportsTsvRenderer,
)
from api.base.pagination import ElasticsearchListViewPagination, JSONAPIPagination


class ElasticsearchListView(FilterMixin, JSONAPIBaseView, generics.ListAPIView, abc.ABC):
@@ -39,6 +44,10 @@ def get_default_search(self) -> edsl.Search | None:
# override FilterMixin to disable all operators besides 'eq' and 'ne'
MATCHABLE_FIELDS = ()
COMPARABLE_FIELDS = ()
FILE_RENDERER_CLASSES = {
MetricsReportsCsvRenderer,
MetricsReportsTsvRenderer,
}
DEFAULT_OPERATOR_OVERRIDES = {}
# (if you want to add fulltext-search or range-filter support, remove the override
# and update `__add_search_filter` to handle those operators -- tho note that the
@@ -52,6 +61,13 @@ def get_default_search(self) -> edsl.Search | None:
# it works fine with default pagination

# override rest_framework.generics.GenericAPIView
@property
def pagination_class(self):
if any(self.request.accepted_renderer.format == renderer.format for renderer in self.FILE_RENDERER_CLASSES):
return ElasticsearchListViewPagination
else:
return JSONAPIPagination

def get_queryset(self):
_search = self.get_default_search()
if _search is None:
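For orientation, here is a minimal, dependency-free sketch of the selection rule the new `pagination_class` property performs. The renderer stand-ins and return strings are illustrative only; the `.format` comparison mirrors the diff above.

```python
# Dependency-free sketch of the paginator selection above. The renderer classes
# are stand-ins; only their `.format` attribute matters for the comparison.
class _CsvRenderer:
    format = 'csv'

class _TsvRenderer:
    format = 'tsv'

class _JsonApiRenderer:
    format = 'jsonapi'

FILE_RENDERER_CLASSES = {_CsvRenderer, _TsvRenderer}

def pick_pagination_class(accepted_renderer):
    """Return the 'fetch everything' paginator for file downloads,
    and the default JSON:API paginator otherwise."""
    if any(accepted_renderer.format == renderer.format for renderer in FILE_RENDERER_CLASSES):
        return 'ElasticsearchListViewPagination'  # page_size pinned to MAX_SIZE_OF_ES_QUERY
    return 'JSONAPIPagination'                    # normal paged responses

assert pick_pagination_class(_CsvRenderer()) == 'ElasticsearchListViewPagination'
assert pick_pagination_class(_JsonApiRenderer()) == 'JSONAPIPagination'
```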
9 changes: 8 additions & 1 deletion api/base/pagination.py
@@ -10,7 +10,7 @@
replace_query_param, remove_query_param,
)
from api.base.serializers import is_anonymized
from api.base.settings import MAX_PAGE_SIZE
from api.base.settings import MAX_PAGE_SIZE, MAX_SIZE_OF_ES_QUERY
from api.base.utils import absolute_reverse

from osf.models import AbstractNode, Comment, Preprint, Guid, DraftRegistration
@@ -172,6 +172,13 @@ class MaxSizePagination(JSONAPIPagination):
max_page_size = None
page_size_query_param = None


class ElasticsearchListViewPagination(JSONAPIPagination):
page_size = MAX_SIZE_OF_ES_QUERY
max_page_size = MAX_SIZE_OF_ES_QUERY
page_size_query_param = None


class NoMaxPageSizePagination(JSONAPIPagination):
max_page_size = None

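A short sketch of why the new paginator ignores `page[size]`: with `page_size_query_param = None`, DRF's `PageNumberPagination.get_page_size()` never reads the query string and always returns the fixed `page_size`. This assumes `JSONAPIPagination` keeps that base behaviour; the subclass below is built directly on DRF so it stays self-contained.

```python
# Sketch only: built on DRF's PageNumberPagination so it can run outside the
# OSF codebase. settings.configure() is just enough Django setup for a
# standalone run; inside OSF the real settings are already loaded.
from types import SimpleNamespace

from django.conf import settings

if not settings.configured:
    settings.configure()

from rest_framework.pagination import PageNumberPagination

MAX_SIZE_OF_ES_QUERY = 10000  # mirrors api/base/settings/defaults.py


class ElasticsearchListViewPagination(PageNumberPagination):
    page_size = MAX_SIZE_OF_ES_QUERY
    max_page_size = MAX_SIZE_OF_ES_QUERY
    page_size_query_param = None  # get_page_size() never consults ?page[size]=...


paginator = ElasticsearchListViewPagination()
fake_request = SimpleNamespace(query_params={'page[size]': '10'})
assert paginator.get_page_size(fake_request) == MAX_SIZE_OF_ES_QUERY
```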
1 change: 1 addition & 0 deletions api/base/settings/defaults.py
@@ -359,6 +359,7 @@

MAX_SIZE_OF_ES_QUERY = 10000
DEFAULT_ES_NULL_VALUE = 'N/A'
REPORT_FILENAME_FORMAT = '{view_name}_{date_created}.{format_type}'

CI_ENV = False

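As a quick illustration of the new setting, the placeholder names below are exactly the ones the renderer and the tests pass in:

```python
# Illustrative use of the new setting; the keyword arguments are exactly the
# placeholders in the format string.
REPORT_FILENAME_FORMAT = '{view_name}_{date_created}.{format_type}'

filename = REPORT_FILENAME_FORMAT.format(
    view_name='institution-user-metrics',
    date_created='2024-11',
    format_type='csv',
)
assert filename == 'institution-user-metrics_2024-11.csv'
```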
6 changes: 6 additions & 0 deletions api/institutions/views.py
@@ -35,6 +35,7 @@
)
from api.base.settings import DEFAULT_ES_NULL_VALUE
from api.metrics.permissions import IsInstitutionalMetricsUser
from api.metrics.renderers import MetricsReportsCsvRenderer, MetricsReportsTsvRenderer
from api.nodes.serializers import NodeSerializer
from api.nodes.filters import NodesFilterMixin
from api.users.serializers import UserSerializer
@@ -554,6 +555,11 @@ class _NewInstitutionUserMetricsList(InstitutionMixin, ElasticsearchListView):

view_category = 'institutions'
view_name = 'institution-user-metrics'
renderer_classes = (
*api_settings.DEFAULT_RENDERER_CLASSES,
MetricsReportsCsvRenderer,
MetricsReportsTsvRenderer,
)

serializer_class = NewInstitutionUserMetricsSerializer

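From a client's point of view, the effect is roughly the following; the URL path and credentials are placeholders, and the tests below exercise the same behaviour through the test app (`app.get(f'{url}?format=csv', auth=...)`).

```python
# Hypothetical client call -- the host, path, and credentials are placeholders.
# Requesting ?format=csv (or tsv) negotiates the new file renderers, which in
# turn switch the view onto the no-page-size paginator.
import requests

resp = requests.get(
    'https://api.example.org/v2/institutions/<institution_id>/metrics/users/',
    params={'format': 'csv'},
    auth=('institutional-admin', 'password'),
)
print(resp.headers['Content-Type'])         # text/csv; charset=utf-8
print(resp.headers['Content-Disposition'])  # attachment; filename="institution-user-metrics_<YYYY-MM>.csv"
```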
36 changes: 30 additions & 6 deletions api/metrics/renderers.py
@@ -1,5 +1,7 @@
import csv
import io
import csv
import datetime
from api.base.settings.defaults import REPORT_FILENAME_FORMAT

from django.http import Http404

@@ -42,10 +44,19 @@ def get_csv_row(keys_list, report_attrs):
]


class MetricsReportsCsvRenderer(renderers.BaseRenderer):
media_type = 'text/csv'
format = 'csv'
CSV_DIALECT = csv.excel
class MetricsReportsRenderer(renderers.BaseRenderer):

def get_filename(self, renderer_context: dict, format_type: str) -> str:
"""Generate the filename for the file based on format_type REPORT_FILENAME_FORMAT and current date."""
if renderer_context and 'view' in renderer_context:
current_date = datetime.datetime.now().strftime('%Y-%m')
return REPORT_FILENAME_FORMAT.format(
view_name=renderer_context['view'].view_name,
date_created=current_date,
format_type=format_type,
)
else:
raise NotImplementedError('Missing format filename')

def render(self, json_response, accepted_media_type=None, renderer_context=None):
serialized_reports = (
@@ -64,10 +75,23 @@ def render(self, json_response, accepted_media_type=None, renderer_context=None)
csv_writer.writerow(
get_csv_row(csv_fieldnames, serialized_report),
)

response = renderer_context['response']
filename = self.get_filename(renderer_context, self.extension)
response['Content-Disposition'] = f'attachment; filename="{filename}"'

return csv_filecontent.getvalue()


class MetricsReportsTsvRenderer(MetricsReportsCsvRenderer):
class MetricsReportsTsvRenderer(MetricsReportsRenderer):
format = 'tsv'
extension = 'tsv'
media_type = 'text/tab-separated-values'
CSV_DIALECT = csv.excel_tab


class MetricsReportsCsvRenderer(MetricsReportsRenderer):
format = 'csv'
extension = 'csv'
media_type = 'text/csv'
CSV_DIALECT = csv.excel
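The split into a shared `MetricsReportsRenderer` base works because the only per-format differences are the dialect, extension, and media type. A standalone sketch of that shared write path:

```python
# Standalone sketch of the shared write path: the same rows serialized with the
# two dialects used above (csv.excel vs csv.excel_tab).
import csv
import io


def write_rows(rows, dialect):
    buf = io.StringIO()
    writer = csv.writer(buf, dialect=dialect)
    writer.writerows(rows)
    return buf.getvalue()


rows = [
    ['user_name', 'department'],
    ['Jason Kelce', 'Center, \t Greatest Ever'],
]
print(write_rows(rows, csv.excel))      # comma-delimited; the comma-bearing cell is quoted
print(write_rows(rows, csv.excel_tab))  # tab-delimited; the embedded tab forces quoting instead
```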
170 changes: 168 additions & 2 deletions api_tests/institutions/views/test_institution_user_metric_list.py
@@ -1,13 +1,13 @@
import datetime
import csv
import datetime
from io import StringIO
from random import random
from urllib.parse import urlencode

import pytest
from waffle.testutils import override_flag

from api.base.settings.defaults import API_BASE, DEFAULT_ES_NULL_VALUE
from api.base.settings.defaults import API_BASE, DEFAULT_ES_NULL_VALUE, REPORT_FILENAME_FORMAT
import osf.features
from osf_tests.factories import (
InstitutionFactory,
@@ -404,6 +404,172 @@ def test_paginate_reports(self, app, url, institutional_admin, institution, repo
assert _resp.status_code == 200
assert list(_user_ids(_resp)) == _expected_user_id_list

@pytest.mark.parametrize('format_type, delimiter, content_type', [
('csv', ',', 'text/csv; charset=utf-8'),
('tsv', '\t', 'text/tab-separated-values; charset=utf-8')
])
def test_get_report_formats_csv_tsv(self, app, url, institutional_admin, institution, format_type, delimiter,
content_type):
_report_factory(
'2024-08',
institution,
user_id='u_orcomma',
account_creation_date='2018-02',
user_name='Jason Kelce',
orcid_id='4444-3333-2222-1111',
department_name='Center, \t Greatest Ever',
storage_byte_count=736662999298,
embargoed_registration_count=1,
published_preprint_count=1,
public_registration_count=2,
public_project_count=3,
public_file_count=4,
private_project_count=5,
month_last_active='2018-02',
month_last_login='2018-02',
)

resp = app.get(f'{url}?format={format_type}', auth=institutional_admin.auth)
assert resp.status_code == 200
assert resp.headers['Content-Type'] == content_type

current_date = datetime.datetime.now().strftime('%Y-%m')
expected_filename = REPORT_FILENAME_FORMAT.format(
view_name='institution-user-metrics',
date_created=current_date,
format_type=format_type
)
assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"'

response_body = resp.text
expected_response = [
[
'account_creation_date',
'department',
'embargoed_registration_count',
'month_last_active',
'month_last_login',
'orcid_id',
'private_projects',
'public_file_count',
'public_projects',
'public_registration_count',
'published_preprint_count',
'storage_byte_count',
'user_name'
],
[
'2018-02',
'Center, \t Greatest Ever',
'1',
'2018-02',
'2018-02',
'4444-3333-2222-1111',
'5',
'4',
'3',
'2',
'1',
'736662999298',
'Jason Kelce'
]
]

if delimiter:
with StringIO(response_body) as file:
reader = csv.reader(file, delimiter=delimiter)
response_rows = list(reader)
assert response_rows[0] == expected_response[0]
assert sorted(response_rows[1:]) == sorted(expected_response[1:])

@pytest.mark.parametrize('format_type, delimiter, content_type', [
('csv', ',', 'text/csv; charset=utf-8'),
('tsv', '\t', 'text/tab-separated-values; charset=utf-8')
])
def test_csv_tsv_ignores_pagination(self, app, url, institutional_admin, institution, format_type, delimiter,
content_type):
# Create 15 records, exceeding the default page size of 10
num_records = 15
expected_data = []
for i in range(num_records):
_report_factory(
'2024-08',
institution,
user_id=f'u_orcomma_{i}',
account_creation_date='2018-02',
user_name=f'Jalen Hurts #{i}',
orcid_id=f'4444-3333-2222-111{i}',
department_name='QBatman',
storage_byte_count=736662999298 + i,
embargoed_registration_count=1,
published_preprint_count=1,
public_registration_count=2,
public_project_count=3,
public_file_count=4,
private_project_count=5,
month_last_active='2018-02',
month_last_login='2018-02',
)
expected_data.append([
'2018-02',
'QBatman',
'1',
'2018-02',
'2018-02',
f'4444-3333-2222-111{i}',
'5',
'4',
'3',
'2',
'1',
str(736662999298 + i),
f'Jalen Hurts #{i}',
])

# Request the report in csv/tsv format; the default page size of 10 should be ignored
resp = app.get(f'{url}?format={format_type}', auth=institutional_admin.auth)
assert resp.status_code == 200
assert resp.headers['Content-Type'] == content_type

current_date = datetime.datetime.now().strftime('%Y-%m')
expected_filename = REPORT_FILENAME_FORMAT.format(
view_name='institution-user-metrics',
date_created=current_date,
format_type=format_type
)
assert resp.headers['Content-Disposition'] == f'attachment; filename="{expected_filename}"'

# Validate the CSV content contains all 15 records, ignoring the default pagination of 10
response_body = resp.text
rows = response_body.splitlines()

assert len(rows) == num_records + 1 == 16 # 1 header + 15 records

if delimiter:
with StringIO(response_body) as file:
reader = csv.reader(file, delimiter=delimiter)
response_rows = list(reader)
# Validate header row
expected_header = [
'account_creation_date',
'department',
'embargoed_registration_count',
'month_last_active',
'month_last_login',
'orcid_id',
'private_projects',
'public_file_count',
'public_projects',
'public_registration_count',
'published_preprint_count',
'storage_byte_count',
'user_name'
]
assert response_rows[0] == expected_header
# Sort both expected and actual rows (ignoring the header) before comparison
assert sorted(response_rows[1:]) == sorted(expected_data)


def _user_ids(api_response):
for _datum in api_response.json['data']:
yield _datum['relationships']['user']['data']['id']
