Skip to content

Commit

Permalink
refactor: remove time range related filter functions from the storage…
Browse files Browse the repository at this point in the history
… class
  • Loading branch information
osoken committed Aug 18, 2024
1 parent a7178c8 commit 3c757ef
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 20 deletions.
21 changes: 4 additions & 17 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,26 +282,13 @@ def get_posts(
query = sess.query(PostRecord)
if post_ids is not None:
query = query.filter(PostRecord.post_id.in_(post_ids))
if start is not None:
query = query.filter(PostRecord.created_at >= start)
if end is not None:
query = query.filter(PostRecord.created_at < end)
for post_record in query.all():
yield self._post_record_to_model(post_record)

def get_posts_by_created_at_range(
self, start: TwitterTimestamp, end: TwitterTimestamp
) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.created_at.between(start, end)).all():
yield self._post_record_to_model(post_record)

def get_posts_by_created_at_start(self, start: TwitterTimestamp) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.created_at >= start).all():
yield self._post_record_to_model(post_record)

def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.created_at < end).all():
yield self._post_record_to_model(post_record)

def get_posts_by_note_ids(self, note_ids: List[NoteId]) -> Generator[PostModel, None, None]:
query = (
select(PostRecord)
Expand Down
6 changes: 3 additions & 3 deletions common/tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_get_posts_by_created_at_range(
start = TwitterTimestamp.from_int(1153921700000)
end = TwitterTimestamp.from_int(1153921800000)
expected = [post_samples[i] for i in (1,)]
actual = list(storage.get_posts_by_created_at_range(start, end))
actual = list(storage.get_posts(start=start, end=end))
assert expected == actual


Expand All @@ -93,7 +93,7 @@ def test_get_posts_by_created_at_start(
storage = Storage(engine=engine_for_test)
start = TwitterTimestamp.from_int(1153921700000)
expected = [post_samples[i] for i in (1, 2)]
actual = list(storage.get_posts_by_created_at_start(start))
actual = list(storage.get_posts(start=start))
assert expected == actual


Expand All @@ -107,7 +107,7 @@ def test_get_posts_by_created_at_end(
storage = Storage(engine=engine_for_test)
end = TwitterTimestamp.from_int(1153921700000)
expected = [post_samples[i] for i in (0,)]
actual = list(storage.get_posts_by_created_at_end(end))
actual = list(storage.get_posts(end=end))
assert expected == actual


Expand Down

0 comments on commit 3c757ef

Please sign in to comment.