Better aggregations
bmquinn committed Dec 11, 2024
1 parent 7f9ad28 commit 1d5490e
Showing 4 changed files with 49 additions and 14 deletions.
2 changes: 2 additions & 0 deletions chat/src/agent/agent_handler.py
@@ -45,6 +45,8 @@ def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
         match output.name:
             case "aggregate":
                 self.socket.send({"type": "aggregation_result", "ref": self.ref, "message": output.artifact.get("aggregation_result", {})})
+            case "discover_fields":
+                pass
             case "search":
                 try:
                     docs: List[Dict[str, Any]] = [doc.metadata for doc in output.artifact]
4 changes: 2 additions & 2 deletions chat/src/agent/search_agent.py
@@ -3,13 +3,13 @@
 from typing import Literal
 
 from agent.dynamodb_saver import DynamoDBSaver
-from agent.tools import search, aggregate
+from agent.tools import aggregate, discover_fields, search
 from langchain_core.messages.base import BaseMessage
 from langgraph.graph import END, START, StateGraph, MessagesState
 from langgraph.prebuilt import ToolNode
 from setup import openai_chat_client
 
-tools = [search, aggregate]
+tools = [discover_fields, search, aggregate]
 
 tool_node = ToolNode(tools)
 
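The hunk above shows only the tool registration; the rest of the agent graph is elided by this diff. As a point of reference, a minimal sketch of how such a tool list is typically wired into a LangGraph agent loop follows; the call_model node and the use of tools_condition are assumptions for illustration, not code from this commit:

    from langgraph.graph import START, MessagesState, StateGraph
    from langgraph.prebuilt import ToolNode, tools_condition

    # Hypothetical sketch: bind the same tool list to the chat model and
    # loop between the model node and the ToolNode until no tool is called.
    model = openai_chat_client().bind_tools(tools)

    def call_model(state: MessagesState):
        return {"messages": [model.invoke(state["messages"])]}

    workflow = StateGraph(MessagesState)
    workflow.add_node("agent", call_model)
    workflow.add_node("tools", ToolNode(tools))
    workflow.add_edge(START, "agent")
    workflow.add_conditional_edges("agent", tools_condition)  # routes to "tools" or END
    workflow.add_edge("tools", "agent")
    graph = workflow.compile()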
50 changes: 40 additions & 10 deletions chat/src/agent/tools.py
@@ -3,33 +3,63 @@
 from langchain_core.tools import tool
 from setup import opensearch_vector_store
 
+def get_keyword_fields(properties, prefix=''):
+    keyword_fields = []
+    for field_name, field_mapping in properties.items():
+        current_path = f"{prefix}{field_name}"
+        if field_mapping.get('type') == 'keyword':
+            keyword_fields.append(current_path)
+        if 'fields' in field_mapping:
+            for subfield_name, subfield_mapping in field_mapping['fields'].items():
+                if subfield_mapping.get('type') == 'keyword':
+                    keyword_fields.append(f"{current_path}.{subfield_name}")
+        if 'properties' in field_mapping:
+            nested_properties = field_mapping['properties']
+            keyword_fields.extend(get_keyword_fields(nested_properties, prefix=current_path + '.'))
+    return keyword_fields
+
+@tool(response_format="content_and_artifact")
+def discover_fields():
+    """
+    Discover the fields available in the OpenSearch index. This tool is useful for understanding the structure of the index and the fields available for aggregation queries.
+    """
+    # Filter out fields that are not useful for aggregation (only include keyword fields)
+    opensearch = opensearch_vector_store()
+    fields = opensearch.client.indices.get_mapping(index=opensearch.index)
+    top_properties = list(fields.values())[0]['mappings']['properties']
+    result = get_keyword_fields(top_properties)
+
+    return json.dumps(result, default=str), result
+
 @tool(response_format="content_and_artifact")
 def search(query: str):
     """Perform a semantic search of Northwestern University Library digital collections. When answering a search query, ground your answer in the context of the results with references to the document's metadata."""
     query_results = opensearch_vector_store().similarity_search(query, size=20)
     return json.dumps(query_results, default=str), query_results
 
 @tool(response_format="content_and_artifact")
-def aggregate(aggregation_query: str):
+def aggregate(agg_field: str, term_field: str, term: str):
+    """
+    Perform a quantitative aggregation on the OpenSearch index.
+    Args:
+        agg_field (str): The field to aggregate on.
+        term_field (str): The field to filter on.
+        term (str): The term to filter on.
+    Leave term_field and term empty to aggregate across the entire index.
+    Available fields:
+    api_link, api_model, ark, collection.title.keyword, contributor.label.keyword, contributor.variants,
+    create_date, creator.variants, date_created, embedding_model, embedding_text_length,
+    folder_name, folder_number, genre.variants, id, identifier, indexed_at, language.variants,
+    legacy_identifier, library_unit, location.variants, modified_date, notes.note, notes.type,
+    physical_description_material, physical_description_size, preservation_level, provenance, published, publisher,
+    related_url.url, related_url.label, representative_file_set.aspect_ratio, representative_file_set.url, rights_holder,
+    series, status, style_period.label.keyword, style_period.variants, subject.label.keyword, subject.role,
+    subject.variants, table_of_contents, technique.label.keyword, technique.variants, title.keyword, visibility, work_type
+    You must use the discover_fields tool first to obtain the list of appropriate fields for aggregation in the index.
+    Do not use any fields that do not exist in the list returned by discover_fields!
+    Examples:
+    - Number of collections: collection.title.keyword
+    - Number of works by work type: work_type
+    """
     try:
-        response = opensearch_vector_store().aggregations_search(aggregation_query)
+        response = opensearch_vector_store().aggregations_search(agg_field, term_field, term)
         return json.dumps(response, default=str), response
     except Exception as e:
         return json.dumps({"error": str(e)}), None
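As a quick sanity check of the traversal in get_keyword_fields, here is what it yields on a small, made-up mapping fragment (the field names below are illustrative, not the real index mapping):

    sample_properties = {
        "title": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
        "subject": {
            "properties": {
                "label": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},
                "variants": {"type": "keyword"},
            }
        },
    }

    # Multi-fields and nested properties are flattened into dotted paths:
    print(get_keyword_fields(sample_properties))
    # ['title.keyword', 'subject.label.keyword', 'subject.variants']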
7 changes: 5 additions & 2 deletions chat/src/handlers/opensearch_neural_search.py
@@ -55,11 +55,14 @@ def similarity_search_with_score(

         return documents_with_scores
 
-    def aggregations_search(self, field: str, **kwargs: Any) -> dict:
+    def aggregations_search(self, agg_field: str, term_field: str = None, term: str = None, **kwargs: Any) -> dict:
         """Perform a search with aggregations and return the aggregation results."""
+        query = {"match_all": {}} if (term is None or term == "") else {"match": {term_field: term}}
+
         dsl = {
             "size": 0,
-            "aggs": {"aggregation_result": {"terms": {"field": field}}},
+            "query": query,
+            "aggs": {"aggregation_result": {"terms": {"field": agg_field}}},
         }
 
         response = self.client.search(
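Given the new signature, a call such as aggregations_search("work_type", "library_unit", "Special Collections") (the term value here is illustrative, not real data) builds a request body along these lines:

    dsl = {
        "size": 0,  # return no documents, only aggregation buckets
        "query": {"match": {"library_unit": "Special Collections"}},
        "aggs": {"aggregation_result": {"terms": {"field": "work_type"}}},
    }

When term_field and term are omitted, the query falls back to match_all and the terms aggregation runs across the entire index.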
