diff --git a/bento_reference_service/routers/refget.py b/bento_reference_service/routers/refget.py index ab60131..1c3a405 100644 --- a/bento_reference_service/routers/refget.py +++ b/bento_reference_service/routers/refget.py @@ -1,5 +1,7 @@ import io import math +import orjson +import typing from bento_lib.service_info.helpers import build_service_type, build_service_info_from_pydantic_config from bento_lib.service_info.types import GA4GHServiceInfo @@ -7,6 +9,7 @@ from fastapi import APIRouter, HTTPException, Request, Response, status from fastapi.responses import StreamingResponse from pydantic import BaseModel +from typing import Literal from .. import models, streaming as s, __version__ from ..authz import authz_middleware @@ -25,29 +28,53 @@ REFGET_VERSION = "2.0.0" REFGET_SERVICE_TYPE = build_service_type("org.ga4gh", "refget", REFGET_VERSION) +REFGET_CHARSET = "us-ascii" + REFGET_HEADER_TEXT = f"text/vnd.ga4gh.refget.v{REFGET_VERSION}+plain" -REFGET_HEADER_TEXT_WITH_CHARSET = f"{REFGET_HEADER_TEXT}; charset=us-ascii" +REFGET_HEADER_TEXT_WITH_CHARSET = f"{REFGET_HEADER_TEXT}; charset={REFGET_CHARSET}" REFGET_HEADER_JSON = f"application/vnd.ga4gh.refget.v{REFGET_VERSION}+json" -REFGET_HEADER_JSON_WITH_CHARSET = f"{REFGET_HEADER_JSON}; charset=us-ascii" +REFGET_HEADER_JSON_WITH_CHARSET = f"{REFGET_HEADER_JSON}; charset={REFGET_CHARSET}" + + +class RefGetJSONResponse(Response): + media_type = REFGET_HEADER_JSON + charset = REFGET_CHARSET + + def render(self, content: typing.Any) -> bytes: + return orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS) + refget_router = APIRouter(prefix="/sequence") -@refget_router.get("/service-info", dependencies=[authz_middleware.dep_public_endpoint()]) -async def refget_service_info( - config: ConfigDependency, logger: LoggerDependency, request: Request, response: Response -) -> dict: - accept_header: str | None = request.headers.get("Accept", None) - if accept_header and accept_header not in ( - REFGET_HEADER_JSON_WITH_CHARSET, - REFGET_HEADER_JSON, - "application/json", - "application/*", - "*/*", - ): +def check_accept_header(accept_header: str | None, mode: Literal["text", "json"]) -> None: + valid_header_values = ( + ( + REFGET_HEADER_TEXT_WITH_CHARSET, + REFGET_HEADER_TEXT, + "text/plain", + "text/*", + "*/*", + ) + if mode == "text" + else ( + REFGET_HEADER_JSON_WITH_CHARSET, + REFGET_HEADER_JSON, + "application/json", + "application/*", + "*/*", + ) + ) + + if accept_header and accept_header not in valid_header_values: raise HTTPException(status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="Not Acceptable") - response.headers["Content-Type"] = REFGET_HEADER_JSON_WITH_CHARSET + +@refget_router.get("/service-info", dependencies=[authz_middleware.dep_public_endpoint()]) +async def refget_service_info( + config: ConfigDependency, logger: LoggerDependency, request: Request +) -> RefGetJSONResponse: + check_accept_header(request.headers.get("Accept"), mode="json") genome_service_info: GA4GHServiceInfo = await build_service_info_from_pydantic_config( config, logger, {}, REFGET_SERVICE_TYPE, __version__ @@ -55,17 +82,19 @@ async def refget_service_info( del genome_service_info["bento"] - return { - **genome_service_info, - "refget": { - "circular_supported": False, - # I don't like that they used the word 'subsequence' here... that's not what that means exactly. - # It's a substring! - "subsequence_limit": config.response_substring_limit, - "algorithms": ["md5", "ga4gh"], - "identifier_types": [], - }, - } + return RefGetJSONResponse( + { + **genome_service_info, + "refget": { + "circular_supported": False, + # I don't like that they used the word 'subsequence' here... that's not what that means exactly. + # It's a substring! + "subsequence_limit": config.response_substring_limit, + "algorithms": ["md5", "ga4gh"], + "identifier_types": [], + }, + } + ) REFGET_BAD_REQUEST = Response(status_code=status.HTTP_400_BAD_REQUEST, content=b"Bad Request") @@ -87,16 +116,11 @@ async def refget_sequence( ): headers = {"Content-Type": REFGET_HEADER_TEXT_WITH_CHARSET, "Accept-Ranges": "bytes"} - accept_header: str | None = request.headers.get("Accept", None) - if accept_header and accept_header not in ( - REFGET_HEADER_TEXT_WITH_CHARSET, - REFGET_HEADER_TEXT, - "text/plain", - "text/*", - "*/*", - ): - logger.error(f"not acceptable: bad Accept header value") - return Response(status_code=status.HTTP_406_NOT_ACCEPTABLE, content=b"Not Acceptable") + try: + check_accept_header(request.headers.get("Accept"), mode="text") + except HTTPException as e: + logger.error(f"not acceptable: bad Accept header value") # don't log actual value to prevent log injection + return Response(status_code=e.status_code, content=e.detail.encode("ascii")) # Don't use FastAPI's auto-Header tool for the Range header # 'cause I don't want to shadow Python's range() function @@ -219,18 +243,22 @@ class RefGetSequenceMetadataResponse(BaseModel): metadata: RefGetSequenceMetadata -@refget_router.get("/{sequence_checksum}/metadata", dependencies=[authz_middleware.dep_public_endpoint()]) +@refget_router.get( + "/{sequence_checksum}/metadata", + dependencies=[authz_middleware.dep_public_endpoint()], + responses={ + status.HTTP_200_OK: {REFGET_HEADER_JSON: {"schema": RefGetSequenceMetadataResponse.model_json_schema()}} + }, +) async def refget_sequence_metadata( - db: DatabaseDependency, - response: Response, - sequence_checksum: str, -) -> RefGetSequenceMetadataResponse: + db: DatabaseDependency, request: Request, sequence_checksum: str +) -> RefGetJSONResponse: + check_accept_header(request.headers.get("Accept"), mode="json") + res: tuple[str, models.ContigWithRefgetURI] | None = await db.get_genome_and_contig_by_checksum_str( sequence_checksum ) - response.headers["Content-Type"] = REFGET_HEADER_JSON_WITH_CHARSET - if res is None: # TODO: proper 404 for refget spec # TODO: proper content type for exception - RefGet error class? @@ -240,11 +268,13 @@ async def refget_sequence_metadata( ) contig = res[1] - return RefGetSequenceMetadataResponse( - metadata=RefGetSequenceMetadata( - md5=contig.md5, - ga4gh=contig.ga4gh, - length=contig.length, - aliases=contig.aliases, - ), + return RefGetJSONResponse( + RefGetSequenceMetadataResponse( + metadata=RefGetSequenceMetadata( + md5=contig.md5, + ga4gh=contig.ga4gh, + length=contig.length, + aliases=contig.aliases, + ), + ).model_dump(mode="json"), ) diff --git a/tests/conftest.py b/tests/conftest.py index 8960efa..e1324ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,9 @@ from bento_reference_service.logger import get_logger from bento_reference_service.main import app +from .shared_data import TEST_GENOME_SARS_COV_2 +from .shared_functions import create_genome_with_permissions + @pytest.fixture() def config() -> Config: @@ -99,3 +102,9 @@ def test_client(db: Database): def aioresponse(): with aioresponses() as m: yield m + + +@pytest.fixture +def sars_cov_2_genome(test_client: TestClient, aioresponse: aioresponses, db_cleanup): + create_genome_with_permissions(test_client, aioresponse, TEST_GENOME_SARS_COV_2) + return TEST_GENOME_SARS_COV_2 diff --git a/tests/test_genome_routes.py b/tests/test_genome_routes.py index 3b31863..6b6c4d1 100644 --- a/tests/test_genome_routes.py +++ b/tests/test_genome_routes.py @@ -98,10 +98,7 @@ async def test_genome_create(test_client: TestClient, aioresponse: aioresponses, assert len(res.json()) == 2 -async def test_genome_detail_endpoints(test_client: TestClient, aioresponse: aioresponses, db_cleanup): - # setup: create genome TODO: fixture - create_covid_genome_with_permissions(test_client, aioresponse) - +async def test_genome_detail_endpoints(test_client: TestClient, sars_cov_2_genome): # tests res = test_client.get(f"/genomes/{SARS_COV_2_GENOME_ID}") @@ -194,10 +191,7 @@ async def test_genome_without_gff3_and_then_patch(test_client: TestClient, aiore assert res.status_code == status.HTTP_200_OK -async def test_genome_delete(test_client: TestClient, aioresponse: aioresponses, db_cleanup): - # setup: create genome TODO: fixture - create_covid_genome_with_permissions(test_client, aioresponse) - +async def test_genome_delete(test_client: TestClient, sars_cov_2_genome, aioresponse: aioresponses): aioresponse.post("https://authz.local/policy/evaluate", payload={"result": [[True]]}) res = test_client.delete(f"/genomes/{SARS_COV_2_GENOME_ID}", headers=AUTHORIZATION_HEADER) assert res.status_code == status.HTTP_204_NO_CONTENT diff --git a/tests/test_refget.py b/tests/test_refget.py index e3b7621..1a12de7 100644 --- a/tests/test_refget.py +++ b/tests/test_refget.py @@ -1,11 +1,9 @@ import pysam -from aioresponses import aioresponses from fastapi import status from fastapi.testclient import TestClient -from .shared_data import SARS_COV_2_FASTA_PATH, TEST_GENOME_SARS_COV_2 -from .shared_functions import create_genome_with_permissions +from .shared_data import SARS_COV_2_FASTA_PATH REFGET_2_0_0_TYPE = {"group": "org.ga4gh", "artifact": "refget", "version": "2.0.0"} @@ -18,6 +16,7 @@ def test_refget_service_info(test_client: TestClient, db_cleanup): rd = res.json() assert res.status_code == status.HTTP_200_OK + assert res.headers["content-type"] == "application/vnd.ga4gh.refget.v2.0.0+json" assert "id" in rd assert "name" in rd @@ -37,59 +36,54 @@ def test_refget_sequence_not_found(test_client: TestClient, db_cleanup): assert res.status_code == status.HTTP_404_NOT_FOUND -def test_refget_sequence_invalid_requests(test_client: TestClient, aioresponse: aioresponses, db_cleanup): - # TODO: fixture - create_genome_with_permissions(test_client, aioresponse, TEST_GENOME_SARS_COV_2) - test_contig = TEST_GENOME_SARS_COV_2["contigs"][0] +def test_refget_sequence_invalid_requests(test_client: TestClient, sars_cov_2_genome): + test_contig = sars_cov_2_genome["contigs"][0] + seq_url = f"/sequence/{test_contig['md5']}" # ------------------------------------------------------------------------------------------------------------------ # cannot return HTML - res = test_client.get(f"/sequence/{test_contig['md5']}", headers={"Accept": "text/html"}) + res = test_client.get(seq_url, headers={"Accept": "text/html"}) assert res.status_code == status.HTTP_406_NOT_ACCEPTABLE assert res.content == b"Not Acceptable" # cannot have start > end - res = test_client.get( - f"/sequence/{test_contig['md5']}", params={"start": 5, "end": 1}, headers=HEADERS_ACCEPT_PLAIN - ) + res = test_client.get(seq_url, params={"start": 5, "end": 1}, headers=HEADERS_ACCEPT_PLAIN) assert res.status_code == status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE assert res.content == b"Range Not Satisfiable" # start > contig length (by 1 base, since it's 0-based) - res = test_client.get( - f"/sequence/{test_contig['md5']}", params={"start": test_contig["length"]}, headers=HEADERS_ACCEPT_PLAIN - ) + res = test_client.get(seq_url, params={"start": test_contig["length"]}, headers=HEADERS_ACCEPT_PLAIN) assert res.status_code == status.HTTP_400_BAD_REQUEST assert res.content == b"Bad Request" # end > contig length (by 1 base, since it's 0-based exclusive) - res = test_client.get( - f"/sequence/{test_contig['md5']}", params={"end": test_contig["length"] + 1}, headers=HEADERS_ACCEPT_PLAIN - ) + res = test_client.get(seq_url, params={"end": test_contig["length"] + 1}, headers=HEADERS_ACCEPT_PLAIN) assert res.status_code == status.HTTP_400_BAD_REQUEST assert res.content == b"Bad Request" # bad range header - res = test_client.get(f"/sequence/{test_contig['md5']}", headers={"Range": "dajkshfasd", **HEADERS_ACCEPT_PLAIN}) + res = test_client.get(seq_url, headers={"Range": "dajkshfasd", **HEADERS_ACCEPT_PLAIN}) assert res.status_code == status.HTTP_400_BAD_REQUEST assert res.content == b"Bad Request" # cannot have range header and start/end res = test_client.get( - f"/sequence/{test_contig['md5']}", + seq_url, params={"start": "0", "end": "11"}, headers={"Range": "bytes=0-10", **HEADERS_ACCEPT_PLAIN}, ) assert res.status_code == status.HTTP_400_BAD_REQUEST assert res.content == b"Bad Request" + # cannot have overlaps in range header + res = test_client.get(seq_url, headers={"Range": "bytes=0-10, 5-15", **HEADERS_ACCEPT_PLAIN}) + assert res.status_code == status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE + assert res.content == b"Range Not Satisfiable" -def test_refget_sequence_full(test_client: TestClient, aioresponse: aioresponses, db_cleanup): - # TODO: fixture - create_genome_with_permissions(test_client, aioresponse, TEST_GENOME_SARS_COV_2) - test_contig = TEST_GENOME_SARS_COV_2["contigs"][0] +def test_refget_sequence_full(test_client: TestClient, sars_cov_2_genome): + test_contig = sars_cov_2_genome["contigs"][0] # Load COVID contig bytes rf = pysam.FastaFile(str(SARS_COV_2_FASTA_PATH)) @@ -114,11 +108,8 @@ def test_refget_sequence_full(test_client: TestClient, aioresponse: aioresponses assert res.content == seq -def test_refget_sequence_partial(test_client, aioresponse: aioresponses, db_cleanup): - # TODO: fixture - create_genome_with_permissions(test_client, aioresponse, TEST_GENOME_SARS_COV_2) - - test_contig = TEST_GENOME_SARS_COV_2["contigs"][0] +def test_refget_sequence_partial(test_client, sars_cov_2_genome): + test_contig = sars_cov_2_genome["contigs"][0] seq_url = f"/sequence/{test_contig['md5']}" # Load COVID contig bytes @@ -161,3 +152,33 @@ def _check_first_10(r, sc, ar="none"): res = test_client.get(seq_url, headers={"Range": "bytes=-10", **HEADERS_ACCEPT_PLAIN}) assert res.status_code == status.HTTP_206_PARTIAL_CONTENT assert res.content == seq[-10:] + + +def test_refget_metadata(test_client: TestClient, sars_cov_2_genome): + test_contig = sars_cov_2_genome["contigs"][0] + seq_m_url = f"/sequence/{test_contig['md5']}/metadata" + + # ------------------------------------------------------------------------------------------------------------------ + + res = test_client.get(seq_m_url) + assert res.status_code == status.HTTP_200_OK + assert res.headers["content-type"] == "application/vnd.ga4gh.refget.v2.0.0+json" + assert res.json() == { + "metadata": { + "md5": test_contig["md5"], + "ga4gh": test_contig["ga4gh"], + "length": test_contig["length"], + "aliases": test_contig["aliases"], + } + } + + +def test_refget_metadata_406(test_client: TestClient, sars_cov_2_genome): + res = test_client.get(f"/sequence/{sars_cov_2_genome['contigs'][0]['md5']}/metadata", headers=HEADERS_ACCEPT_PLAIN) + assert res.status_code == status.HTTP_406_NOT_ACCEPTABLE + + +def test_refget_metadata_404(test_client: TestClient): + res = test_client.get("/sequence/does-not-exist/metadata") + # TODO: proper content type for exception - RefGet error class? + assert res.status_code == status.HTTP_404_NOT_FOUND