Skip to content

Commit

Permalink
fix some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yu23ki14 committed Feb 6, 2025
1 parent 81056ba commit 91702f0
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 107 deletions.
52 changes: 37 additions & 15 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
UserId,
)
from birdxplorer_common.storage import Storage
from urllib.parse import urlencode

PostsPaginationMetaWithExamples: TypeAlias = Annotated[
PaginationMeta,
Expand Down Expand Up @@ -275,7 +276,11 @@ def str_to_twitter_timestamp(s: str) -> TwitterTimestamp:


def ensure_twitter_timestamp(t: Union[str, TwitterTimestamp]) -> TwitterTimestamp:
return str_to_twitter_timestamp(t) if isinstance(t, str) else t
try:
timestamp = str_to_twitter_timestamp(t) if isinstance(t, str) else t
return timestamp
except:
raise ValueError(f"Timestamp out of range")


def gen_router(storage: Storage) -> APIRouter:
Expand Down Expand Up @@ -314,10 +319,13 @@ def get_notes(
language: Union[LanguageIdentifier, None] = Query(default=None, **V1DataNotesDocs.params["language"]),
search_text: Union[None, str] = Query(default=None, **V1DataNotesDocs.params["search_text"]),
) -> NoteListResponse:
if created_at_from is not None and isinstance(created_at_from, str):
created_at_from = ensure_twitter_timestamp(created_at_from)
if created_at_to is not None and isinstance(created_at_to, str):
created_at_to = ensure_twitter_timestamp(created_at_to)
try:
if created_at_from is not None and isinstance(created_at_from, str):
created_at_from = ensure_twitter_timestamp(created_at_from)
if created_at_to is not None and isinstance(created_at_to, str):
created_at_to = ensure_twitter_timestamp(created_at_to)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))

notes = list(
storage.get_notes(
Expand Down Expand Up @@ -374,10 +382,14 @@ def get_posts(
search_url: Union[None, HttpUrl] = Query(default=None, **V1DataPostsDocs.params["search_url"]),
media: bool = Query(default=True, **V1DataPostsDocs.params["media"]),
) -> PostListResponse:
if created_at_from is not None and isinstance(created_at_from, str):
created_at_from = ensure_twitter_timestamp(created_at_from)
if created_at_to is not None and isinstance(created_at_to, str):
created_at_to = ensure_twitter_timestamp(created_at_to)
try:
if created_at_from is not None and isinstance(created_at_from, str):
created_at_from = ensure_twitter_timestamp(created_at_from)
if created_at_to is not None and isinstance(created_at_to, str):
created_at_to = ensure_twitter_timestamp(created_at_to)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))

posts = list(
storage.get_posts(
post_ids=post_ids,
Expand Down Expand Up @@ -450,10 +462,13 @@ def search(
limit: int = Query(default=100, gt=0, le=1000, **V1DataSearchDocs.params["limit"]),
) -> SearchResponse:
# Convert timestamp strings to TwitterTimestamp objects
if note_created_at_from is not None and isinstance(note_created_at_from, str):
note_created_at_from = ensure_twitter_timestamp(note_created_at_from)
if note_created_at_to is not None and isinstance(note_created_at_to, str):
note_created_at_to = ensure_twitter_timestamp(note_created_at_to)
try:
if note_created_at_from is not None and isinstance(note_created_at_from, str):
note_created_at_from = ensure_twitter_timestamp(note_created_at_from)
if note_created_at_to is not None and isinstance(note_created_at_to, str):
note_created_at_to = ensure_twitter_timestamp(note_created_at_to)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))

# Get search results using the optimized storage method
results = []
Expand Down Expand Up @@ -512,14 +527,21 @@ def search(

# Generate pagination URLs
base_url = str(request.url).split("?")[0]
query_params = dict(request.query_params)
next_offset = offset + limit
prev_offset = max(offset - limit, 0)

next_url = None
if next_offset < total_count:
next_url = f"{base_url}?offset={next_offset}&limit={limit}"
query_params["offset"] = next_offset
query_params["limit"] = limit
next_url = f"{base_url}?{urlencode(query_params)}"

prev_url = None
if offset > 0:
prev_url = f"{base_url}?offset={prev_offset}&limit={limit}"
query_params["offset"] = prev_offset
query_params["limit"] = limit
prev_url = f"{base_url}?{urlencode(query_params)}"

return SearchResponse(data=results, meta=PaginationMeta(next=next_url, prev=prev_url))

Expand Down
2 changes: 1 addition & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ dev=[
"httpx",
]
prod=[
"birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@feature/138#subdirectory=common",
"birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@main#subdirectory=common",
"psycopg2",
"gunicorn",
]
Expand Down
198 changes: 107 additions & 91 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Generator, List, Tuple, Union
from typing import Any, Generator, List, Tuple, Union

from psycopg2.extensions import AsIs, register_adapter
from pydantic import AnyUrl, HttpUrl
from sqlalchemy import ForeignKey, create_engine, func, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.orm.query import RowReturningQuery
from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String, Uuid

from .models import BinaryBool, LanguageIdentifier
Expand Down Expand Up @@ -507,6 +508,72 @@ def get_number_of_posts(
)
return query.count()

def _apply_filters(
self,
query: RowReturningQuery[Tuple[Any, ...]],
note_includes_text: Union[str, None] = None,
note_excludes_text: Union[str, None] = None,
post_includes_text: Union[str, None] = None,
post_excludes_text: Union[str, None] = None,
language: Union[LanguageIdentifier, None] = None,
topic_ids: Union[List[TopicId], None] = None,
note_status: Union[List[str], None] = None,
note_created_at_from: Union[TwitterTimestamp, None] = None,
note_created_at_to: Union[TwitterTimestamp, None] = None,
x_user_names: Union[List[str], None] = None,
x_user_followers_count_from: Union[int, None] = None,
x_user_follow_count_from: Union[int, None] = None,
post_like_count_from: Union[int, None] = None,
post_repost_count_from: Union[int, None] = None,
post_impression_count_from: Union[int, None] = None,
post_includes_media: Union[bool, None] = None,
) -> RowReturningQuery[Tuple[Any, ...]]:
# Apply note filters
if note_includes_text:
query = query.filter(NoteRecord.summary.like(f"%{note_includes_text}%"))
if note_excludes_text:
query = query.filter(~NoteRecord.summary.like(f"%{note_excludes_text}%"))
if language:
query = query.filter(NoteRecord.language == language)
if topic_ids:
subq = (
select(NoteTopicAssociation.note_id)
.filter(NoteTopicAssociation.topic_id.in_(topic_ids))
.group_by(NoteTopicAssociation.note_id)
.subquery()
)
query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
if note_status:
query = query.filter(NoteRecord.current_status.in_(note_status))
if note_created_at_from:
query = query.filter(NoteRecord.created_at >= note_created_at_from)
if note_created_at_to:
query = query.filter(NoteRecord.created_at <= note_created_at_to)

# Apply post filters
if post_includes_text:
query = query.filter(PostRecord.text.like(f"%{post_includes_text}%"))
if post_excludes_text:
query = query.filter(~PostRecord.text.like(f"%{post_excludes_text}%"))
if x_user_names:
query = query.filter(XUserRecord.name.in_(x_user_names))
if x_user_followers_count_from:
query = query.filter(XUserRecord.followers_count >= x_user_followers_count_from)
if x_user_follow_count_from:
query = query.filter(XUserRecord.following_count >= x_user_follow_count_from)
if post_like_count_from:
query = query.filter(PostRecord.like_count >= post_like_count_from)
if post_repost_count_from:
query = query.filter(PostRecord.repost_count >= post_repost_count_from)
if post_impression_count_from:
query = query.filter(PostRecord.impression_count >= post_impression_count_from)
if post_includes_media:
query = query.filter(PostRecord.media_details.any())
elif post_includes_media is False:
query = query.filter(~PostRecord.media_details.any())

return query

def search_notes_with_posts(
self,
note_includes_text: Union[str, None] = None,
Expand All @@ -529,61 +596,34 @@ def search_notes_with_posts(
limit: int = 100,
) -> Generator[Tuple[NoteModel, PostModel | None], None, None]:
with Session(self.engine) as sess:
# Base query joining notes, posts and users
query = (
sess.query(NoteRecord, PostRecord)
.outerjoin(PostRecord, NoteRecord.post_id == PostRecord.post_id)
.outerjoin(XUserRecord, PostRecord.user_id == XUserRecord.user_id)
)

# Apply note filters
if note_includes_text:
query = query.filter(NoteRecord.summary.like(f"%{note_includes_text}%"))
if note_excludes_text:
query = query.filter(~NoteRecord.summary.like(f"%{note_excludes_text}%"))
if language:
query = query.filter(NoteRecord.language == language)
if topic_ids:
subq = (
select(NoteTopicAssociation.note_id)
.filter(NoteTopicAssociation.topic_id.in_(topic_ids))
.group_by(NoteTopicAssociation.note_id)
.subquery()
)
query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
if note_status:
query = query.filter(NoteRecord.current_status.in_(note_status))
if note_created_at_from:
query = query.filter(NoteRecord.created_at >= note_created_at_from)
if note_created_at_to:
query = query.filter(NoteRecord.created_at <= note_created_at_to)

# Apply post filters
if post_includes_text:
query = query.filter(PostRecord.text.like(f"%{post_includes_text}%"))
if post_excludes_text:
query = query.filter(~PostRecord.text.like(f"%{post_excludes_text}%"))
if x_user_names:
query = query.filter(XUserRecord.name.in_(x_user_names))
if x_user_followers_count_from:
query = query.filter(XUserRecord.followers_count >= x_user_followers_count_from)
if x_user_follow_count_from:
query = query.filter(XUserRecord.following_count >= x_user_follow_count_from)
if post_like_count_from:
query = query.filter(PostRecord.like_count >= post_like_count_from)
if post_repost_count_from:
query = query.filter(PostRecord.repost_count >= post_repost_count_from)
if post_impression_count_from:
query = query.filter(PostRecord.impression_count >= post_impression_count_from)
if post_includes_media:
query = query.filter(PostRecord.media_details.any())
if post_includes_media is False:
query = query.filter(~PostRecord.media_details.any())

# Pagination
query = self._apply_filters(
query,
note_includes_text,
note_excludes_text,
post_includes_text,
post_excludes_text,
language,
topic_ids,
note_status,
note_created_at_from,
note_created_at_to,
x_user_names,
x_user_followers_count_from,
x_user_follow_count_from,
post_like_count_from,
post_repost_count_from,
post_impression_count_from,
post_includes_media,
)

query = query.offset(offset).limit(limit)

# Execute query and yield results
for note_record, post_record in query.all():
note = NoteModel(
note_id=note_record.note_id,
Expand Down Expand Up @@ -627,49 +667,25 @@ def count_search_results(
.outerjoin(XUserRecord, PostRecord.user_id == XUserRecord.user_id)
)

# Apply note filters
if note_includes_text:
query = query.filter(NoteRecord.summary.like(f"%{note_includes_text}%"))
if note_excludes_text:
query = query.filter(~NoteRecord.summary.like(f"%{note_excludes_text}%"))
if language:
query = query.filter(NoteRecord.language == language)
if topic_ids:
subq = (
select(NoteTopicAssociation.note_id)
.filter(NoteTopicAssociation.topic_id.in_(topic_ids))
.group_by(NoteTopicAssociation.note_id)
.subquery()
)
query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
if note_status:
query = query.filter(NoteRecord.current_status.in_(note_status))
if note_created_at_from:
query = query.filter(NoteRecord.created_at >= note_created_at_from)
if note_created_at_to:
query = query.filter(NoteRecord.created_at <= note_created_at_to)

# Apply post filters
if post_includes_text:
query = query.filter(PostRecord.text.like(f"%{post_includes_text}%"))
if post_excludes_text:
query = query.filter(~PostRecord.text.like(f"%{post_excludes_text}%"))
if x_user_names:
query = query.filter(XUserRecord.name.in_(x_user_names))
if x_user_followers_count_from:
query = query.filter(XUserRecord.followers_count >= x_user_followers_count_from)
if x_user_follow_count_from:
query = query.filter(XUserRecord.following_count >= x_user_follow_count_from)
if post_like_count_from:
query = query.filter(PostRecord.like_count >= post_like_count_from)
if post_repost_count_from:
query = query.filter(PostRecord.repost_count >= post_repost_count_from)
if post_impression_count_from:
query = query.filter(PostRecord.impression_count >= post_impression_count_from)
if post_includes_media:
query = query.filter(PostRecord.media_details.any())
elif post_includes_media is False:
query = query.filter(~PostRecord.media_details.any())
query = self._apply_filters(
query,
note_includes_text,
note_excludes_text,
post_includes_text,
post_excludes_text,
language,
topic_ids,
note_status,
note_created_at_from,
note_created_at_to,
x_user_names,
x_user_followers_count_from,
x_user_follow_count_from,
post_like_count_from,
post_repost_count_from,
post_impression_count_from,
post_includes_media,
)

return query.scalar() or 0

Expand Down

0 comments on commit 91702f0

Please sign in to comment.