Skip to content

Commit

Permalink
Updated the read_media function to be able to extract information eve…
Browse files Browse the repository at this point in the history
…n without stakeholder_id
  • Loading branch information
bannontan committed Aug 6, 2024
1 parent aa28499 commit bbaed2d
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions python-server/app/qdrant_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,32 @@ def rank_ids_qdrant(query_vector, media_ids, limit=5):
return search_result

# function to read media
def read_media(stakeholder_id, query: str=None):
def read_media(stakeholder_id: int=None, query: str=None):
with Session(media_engine) as db:
media_ids = get_media_ids_for_stakeholder(db, int(float(stakeholder_id)))
#Apply filter if query is defined
if query:
if stakeholder_id is not None:
media_ids = get_media_ids_for_stakeholder(db, int(float(stakeholder_id)))
#Apply filter if query is defined
if query:
query_vector = vectorize_query(query)
hits = rank_ids_qdrant(query_vector, media_ids, limit=2)
# media_ids = [hit.id for hit in hits]

elif stakeholder_id is None:
query_vector = vectorize_query(query)
hits = rank_ids_qdrant(query_vector, media_ids, limit=2)
media_ids = [hit.id for hit in hits]
hits = qdrant_client.search(
collection_name="media_collection",
query_vector=query_vector,
with_payload=True,
limit=2
)

media_ids = [hit.id for hit in hits]
# Get content from media ids
articles = get_content_from_media_ids(db, media_ids)
response = '\n'.join(articles)
return response
response = '\n'.join(articles)
return response

def derive_rs_from_media(stakeholder_id, query: str=None):
def derive_rs_from_media(stakeholder_id: int=None, query: str=None):

page_content=read_media(stakeholder_id, query)

Expand All @@ -56,5 +68,6 @@ def derive_rs_from_media(stakeholder_id, query: str=None):
if __name__ == "__main__":
# model = ChatVertexAI(model="gemini-1.5-flash", max_retries=2)
query = "Joe Biden supporters"
ls = derive_rs_from_media(stakeholder_id=28235, query=query)
# stakeholder_id=28235
ls = derive_rs_from_media(query=query)
print(ls)

0 comments on commit bbaed2d

Please sign in to comment.