Skip to content

Commit

Permalink
feat(core): add post_id query string option for GET posts endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Mar 24, 2024
1 parent 911efa1 commit 088454f
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 5 deletions.
30 changes: 30 additions & 0 deletions birdxplorer/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import csv
import io
from urllib.parse import parse_qs as parse_query_string
from urllib.parse import urlencode as encode_query_string

from fastapi import FastAPI
from pydantic.alias_generators import to_snake
from starlette.types import ASGIApp, Receive, Scope, Send

from .logger import get_logger
from .routers.data import gen_router as gen_data_router
Expand All @@ -7,10 +14,33 @@
from .storage import gen_storage


class QueryStringFlatteningMiddleware:
def __init__(self, app: ASGIApp) -> None:
self._app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
query_string = scope.get("query_string")
if not isinstance(query_string, bytes):
query_string = b""
query_string = query_string.decode("utf-8")
if scope["type"] == "http" and query_string:
parsed = parse_query_string(query_string)
flattened = {}
for name, values in parsed.items():
flattened[to_snake(name)] = [c for value in values for r in csv.reader(io.StringIO(value)) for c in r]

scope["query_string"] = encode_query_string(flattened, doseq=True).encode("utf-8")

await self._app(scope, receive, send)
else:
await self._app(scope, receive, send)


def gen_app(settings: GlobalSettings) -> FastAPI:
_ = get_logger(level=settings.logger_settings.level)
storage = gen_storage(settings=settings)
app = FastAPI()
app.add_middleware(QueryStringFlatteningMiddleware)
app.include_router(gen_system_router(), prefix="/api/v1/system")
app.include_router(gen_data_router(storage=storage), prefix="/api/v1/data")
return app
10 changes: 6 additions & 4 deletions birdxplorer/routers/data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List
from typing import List, Union

from fastapi import APIRouter
from fastapi import APIRouter, Query

from ..models import BaseModel, ParticipantId, Post, Topic, UserEnrollment
from ..models import BaseModel, ParticipantId, Post, PostId, Topic, UserEnrollment
from ..storage import Storage


Expand All @@ -29,7 +29,9 @@ def get_topics() -> TopicListResponse:
return TopicListResponse(data=list(storage.get_topics()))

@router.get("/posts", response_model=PostListResponse)
def get_posts() -> PostListResponse:
def get_posts(post_id: Union[List[PostId], None] = Query(default=None)) -> PostListResponse:
if post_id is not None:
return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id)))
return PostListResponse(data=list(storage.get_posts()))

return router
5 changes: 4 additions & 1 deletion birdxplorer/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ParticipantId,
)
from .models import Post as PostModel
from .models import SummaryString
from .models import PostId, SummaryString
from .models import Topic as TopicModel
from .models import (
TopicId,
Expand Down Expand Up @@ -153,6 +153,9 @@ def get_posts(self) -> Generator[PostModel, None, None]:
impression_count=post_record.impression_count,
)

def get_posts_by_ids(self, post_ids: List[PostId]) -> Generator[PostModel, None, None]:
raise NotImplementedError


def gen_storage(settings: GlobalSettings) -> Storage:
engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Note,
ParticipantId,
Post,
PostId,
Topic,
TwitterTimestamp,
UserEnrollment,
Expand Down Expand Up @@ -135,6 +136,15 @@ def _get_posts() -> Generator[Post, None, None]:

mock.get_posts.side_effect = _get_posts

def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]:
for i in post_ids:
for post in post_samples:
if post.post_id == i:
yield post
break

mock.get_posts_by_ids.side_effect = _get_posts_by_ids

yield mock


Expand Down
9 changes: 9 additions & 0 deletions tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,12 @@ def test_posts_get(client: TestClient, post_samples: List[Post]) -> None:
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(d.model_dump_json()) for d in post_samples]}


def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Post]) -> None:
response = client.get(f"/api/v1/data/posts/?postId={post_samples[0].post_id},{post_samples[2].post_id}")
assert response.status_code == 200
res_json = response.json()
assert res_json == {
"data": [json.loads(post_samples[0].model_dump_json()), json.loads(post_samples[2].model_dump_json())]
}

0 comments on commit 088454f

Please sign in to comment.