Skip to content

Commit

Permalink
feat: implement get_number_of_posts
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Aug 18, 2024
1 parent 48c2967 commit 64e1dfd
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
16 changes: 15 additions & 1 deletion common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,21 @@ def get_number_of_posts(
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
) -> int:
raise NotImplementedError
with Session(self.engine) as sess:
query = sess.query(PostRecord)
if post_ids is not None:
query = query.filter(PostRecord.post_id.in_(post_ids))
if note_ids is not None:
query = query.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id).filter(
NoteRecord.note_id.in_(note_ids)
)
if start is not None:
query = query.filter(PostRecord.created_at >= start)
if end is not None:
query = query.filter(PostRecord.created_at < end)
if search_text is not None:
query = query.filter(PostRecord.text.like(f"%{search_text}%"))
return query.count()


def gen_storage(settings: GlobalSettings) -> Storage:
Expand Down
28 changes: 28 additions & 0 deletions common/tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,34 @@ def test_get_post(
assert expected == actual


@pytest.mark.parametrize(
["filter_args", "expected_indices"],
[
[dict(), [0, 1, 2]],
[dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]],
[dict(post_ids=[]), []],
[dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]],
[dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2]],
[dict(end=TwitterTimestamp.from_int(1153921700000)), [0]],
[dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]],
[dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]],
],
)
def test_get_number_of_posts(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
topic_records_sample: List[TopicRecord],
note_records_sample: List[NoteRecord],
filter_args: Dict[str, Any],
expected_indices: List[int],
) -> None:
storage = Storage(engine=engine_for_test)
actual = storage.get_number_of_posts(**filter_args)
expected = len(expected_indices)
assert expected == actual


def test_get_notes_by_ids(
engine_for_test: Engine,
note_samples: List[Note],
Expand Down

0 comments on commit 64e1dfd

Please sign in to comment.