From c88b811bf3f4e79488b9a9965d050e88eb658a03 Mon Sep 17 00:00:00 2001 From: Brendan Quinn Date: Wed, 18 Dec 2024 22:41:29 +0000 Subject: [PATCH] Full test coverage for tools --- chat/test/agent/test_tools.py | 164 ++++++++++++++++++++++++++++------ 1 file changed, 139 insertions(+), 25 deletions(-) diff --git a/chat/test/agent/test_tools.py b/chat/test/agent/test_tools.py index 81217de..18e5811 100644 --- a/chat/test/agent/test_tools.py +++ b/chat/test/agent/test_tools.py @@ -1,51 +1,165 @@ from unittest import TestCase +from unittest.mock import patch, MagicMock +import json import sys sys.path.append('./src') -from agent.tools import get_keyword_fields +from agent.tools import discover_fields, search, aggregate, get_keyword_fields from test.fixtures.opensearch import TOP_PROPERTIES + class TestTools(TestCase): + @patch('agent.tools.opensearch_vector_store') + def test_discover_fields(self, mock_opensearch): + # Mock the OpenSearch response + mock_client = MagicMock() + mock_client.indices.get_mapping.return_value = { + 'index_name': { + 'mappings': { + 'properties': { + 'field1': {'type': 'keyword'}, + 'field3': { + 'properties': { + 'subfield1': {'type': 'keyword'} + } + } + } + } + } + } + mock_opensearch.return_value.client = mock_client + + # Pass required parameters based on the tool's schema + response = discover_fields.invoke({"query": ""}) # Assuming query is the required parameter + self.assertIsInstance(response, str) + parsed_response = json.loads(response) + self.assertEqual(parsed_response, ["field1", "field3.subfield1"]) + + @patch('agent.tools.opensearch_vector_store') + def test_search(self, mock_opensearch): + mock_results = [{"id": "doc1", "text": "example result"}] + mock_opensearch.return_value.similarity_search.return_value = mock_results + + response = search.invoke("test query") + self.assertIsInstance(response, str) # Remove .content check + parsed_response = json.loads(response) + self.assertEqual(parsed_response, mock_results) + + @patch('agent.tools.opensearch_vector_store') + def test_aggregate(self, mock_opensearch): + mock_response = { + "aggregations": { + "example_agg": { + "buckets": [] + } + } + } + mock_opensearch.return_value.aggregations_search.return_value = mock_response + + # Pass parameters directly instead of as JSON string + response = aggregate.invoke({ + "agg_field": "field1", + "term_field": "term_field", + "term": "term" + }) + self.assertIsInstance(response, str) + parsed_response = json.loads(response) + self.assertEqual(parsed_response, mock_response) + + @patch('agent.tools.opensearch_vector_store') + def test_aggregate_no_term(self, mock_opensearch): + mock_response = { + "aggregations": { + "all_docs": { + "buckets": [] + } + } + } + mock_opensearch.return_value.aggregations_search.return_value = mock_response + + response = aggregate.invoke({ + "agg_field": "field1", + "term_field": "", + "term": "" + }) + self.assertIsInstance(response, str) + parsed_response = json.loads(response) + self.assertEqual(parsed_response, mock_response) + def test_get_keyword_fields(self): properties = { - "field1": {"type": "keyword"}, - "field2": {"type": "text"}, - "field3": { - "type": "object", - "properties": { - "subfield1": {"type": "keyword"}, - "subfield2": {"type": "text"} + 'field1': {'type': 'keyword'}, + 'field2': {'type': 'text'}, + 'field3': { + 'fields': { + 'raw': {'type': 'keyword'} } } } result = get_keyword_fields(properties) - self.assertEqual(result, ["field1", "field3.subfield1"]) - + self.assertEqual(set(result), {'field1', 'field3.raw'}) + def test_nested_get_keyword_fields(self): properties = { - "field1": {"type": "keyword"}, - "field2": {"type": "text"}, - "field3": { - "type": "object", - "properties": { - "subfield1": {"type": "keyword"}, - "subfield2": {"type": "text"}, - "subfield3": { - "type": "object", - "properties": { - "subsubfield1": {"type": "keyword"}, - "subsubfield2": {"type": "text"} + 'field1': { + 'properties': { + 'nested1': {'type': 'keyword'}, + 'nested2': {'type': 'text'} + } + } + } + result = get_keyword_fields(properties) + self.assertEqual(result, ['field1.nested1']) + + def test_complex_mapping_get_keyword_fields(self): + properties = { + 'field1': { + 'properties': { + 'nested1': { + 'type': 'keyword' + }, + 'nested2': { + 'properties': { + 'subnested1': {'type': 'keyword'} } } } + }, + 'field2': { + 'fields': { + 'raw': {'type': 'keyword'} + } } } result = get_keyword_fields(properties) - self.assertEqual(result, ["field1", "field3.subfield1", "field3.subfield3.subsubfield1"]) + self.assertEqual(set(result), {'field1.nested1', 'field1.nested2.subnested1', 'field2.raw'}) - def test_complex_mapping_get_keyword_fields(self): + def test_meadow_mapping_get_keyword_fields(self): result = get_keyword_fields(TOP_PROPERTIES) + expected_result = ['accession_number', 'all_controlled_terms', 'all_ids', 'api_link', 'api_model', 'ark', 'batch_ids', 'box_name', 'box_number', 'canonical_link', 'catalog_key', 'collection.id', 'collection.title.keyword', 'contributor.facet', 'contributor.id', 'contributor.label', 'contributor.label_with_role', 'contributor.role', 'contributor.variants', 'creator.facet', 'creator.id', 'creator.label', 'creator.variants', 'csv_metadata_update_jobs', 'date_created', 'date_created_edtf', 'embedding_model', 'file_sets.accession_number', 'file_sets.download_url', 'file_sets.id', 'file_sets.label', 'file_sets.mime_type', 'file_sets.original_filename', 'file_sets.representative_image_url', 'file_sets.role', 'file_sets.streaming_url', 'file_sets.webvtt', 'folder_name', 'folder_number', 'genre.facet', 'genre.id', 'genre.label', 'genre.label_with_role', 'genre.role', 'genre.variants', 'id', 'identifier', 'iiif_manifest', 'ingest_project.id', 'ingest_project.title', 'ingest_sheet.id', 'ingest_sheet.title', 'keywords', 'language.facet', 'language.id', 'language.label', 'language.variants', 'legacy_identifier', 'library_unit', 'license.id', 'license.label', 'location.facet', 'location.id', 'location.label', 'location.variants', 'notes.note', 'notes.type', 'physical_description_material', 'physical_description_size', 'preservation_level', 'project.cycle', 'project.desc', 'project.manager', 'project.name', 'project.proposer', 'project.task_number', 'provenance', 'publisher', 'related_material', 'related_url.label', 'related_url.url', 'representative_file_set.id', 'representative_file_set.url', 'rights_holder', 'rights_statement.id', 'rights_statement.label', 'scope_and_contents', 'series', 'source', 'status', 'style_period.facet', 'style_period.id', 'style_period.label', 'style_period.variants', 'subject.facet', 'subject.id', 'subject.label', 'subject.label_with_role', 'subject.role', 'subject.variants', 'technique.facet', 'technique.id', 'technique.label', 'technique.variants', 'terms_of_use', 'thumbnail', 'title.keyword', 'visibility', 'work_type'] self.assertEqual( result, - ['accession_number', 'all_controlled_terms', 'all_ids', 'api_link', 'api_model', 'ark', 'batch_ids', 'box_name', 'box_number', 'canonical_link', 'catalog_key', 'collection.id', 'collection.title.keyword', 'contributor.facet', 'contributor.id', 'contributor.label', 'contributor.label_with_role', 'contributor.role', 'contributor.variants', 'creator.facet', 'creator.id', 'creator.label', 'creator.variants', 'csv_metadata_update_jobs', 'date_created', 'date_created_edtf', 'embedding_model', 'file_sets.accession_number', 'file_sets.download_url', 'file_sets.id', 'file_sets.label', 'file_sets.mime_type', 'file_sets.original_filename', 'file_sets.representative_image_url', 'file_sets.role', 'file_sets.streaming_url', 'file_sets.webvtt', 'folder_name', 'folder_number', 'genre.facet', 'genre.id', 'genre.label', 'genre.label_with_role', 'genre.role', 'genre.variants', 'id', 'identifier', 'iiif_manifest', 'ingest_project.id', 'ingest_project.title', 'ingest_sheet.id', 'ingest_sheet.title', 'keywords', 'language.facet', 'language.id', 'language.label', 'language.variants', 'legacy_identifier', 'library_unit', 'license.id', 'license.label', 'location.facet', 'location.id', 'location.label', 'location.variants', 'notes.note', 'notes.type', 'physical_description_material', 'physical_description_size', 'preservation_level', 'project.cycle', 'project.desc', 'project.manager', 'project.name', 'project.proposer', 'project.task_number', 'provenance', 'publisher', 'related_material', 'related_url.label', 'related_url.url', 'representative_file_set.id', 'representative_file_set.url', 'rights_holder', 'rights_statement.id', 'rights_statement.label', 'scope_and_contents', 'series', 'source', 'status', 'style_period.facet', 'style_period.id', 'style_period.label', 'style_period.variants', 'subject.facet', 'subject.id', 'subject.label', 'subject.label_with_role', 'subject.role', 'subject.variants', 'technique.facet', 'technique.id', 'technique.label', 'technique.variants', 'terms_of_use', 'thumbnail', 'title.keyword', 'visibility', 'work_type']) \ No newline at end of file + expected_result) + + @patch('agent.tools.opensearch_vector_store') + def test_aggregate_exception(self, mock_opensearch): + # Configure the mock to raise an exception + mock_client = MagicMock() + mock_client.aggregations_search.side_effect = Exception("Test error") + mock_opensearch.return_value = mock_client + + # Call aggregate with some parameters + response = aggregate.invoke({ + "agg_field": "field1", + "term_field": "term_field", + "term": "term" + }) + + # Verify the response contains the error message + self.assertIsInstance(response, str) + parsed_response = json.loads(response) + self.assertEqual(parsed_response, {"error": "Test error"}) + + # Verify the mock was called with expected parameters + mock_client.aggregations_search.assert_called_once() \ No newline at end of file