From dd67204088adfff5ef7af43f0048c2a91c786538 Mon Sep 17 00:00:00 2001 From: ElliottKasoar Date: Tue, 23 Apr 2024 17:38:12 +0000 Subject: [PATCH] Speed up property function --- abcd/backends/atoms_opensearch.py | 44 ++++++++++++++++++++----------- tests/opensearch.py | 9 +++++-- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/abcd/backends/atoms_opensearch.py b/abcd/backends/atoms_opensearch.py index e7721975..0add0c4a 100644 --- a/abcd/backends/atoms_opensearch.py +++ b/abcd/backends/atoms_opensearch.py @@ -9,7 +9,7 @@ from ase import Atoms from ase.io import iread -from opensearchpy import OpenSearch, helpers, AuthenticationException, ConnectionTimeout +from opensearchpy import OpenSearch, helpers, AuthenticationException, ConnectionTimeout, RequestError from abcd.backends import utils from abcd.database import AbstractABCD @@ -533,8 +533,7 @@ def count(self, query: Union[dict, str, None] = None) -> int: def property(self, name, query: Union[dict, str, None] = None) -> list: """ Gets all values of a specified property for matching documents in the - database. This method is very slow, so it is preferable to use - alternative methods where possible, such as count_property. + database. Alternative methods, such as count_property, may be faster. Parameters ---------- @@ -543,24 +542,39 @@ def property(self, name, query: Union[dict, str, None] = None) -> list: Returns ------- - List of values for the specified property for all matching documents. + list + List of values for the specified property for all matching documents. """ query = self.parser(query) query = { "query": query, } - return [ - hit["_source"][format(name)] - for hit in helpers.scan( - self.client, - index=self.index_name, - query=query, - stored_fields=format(name), - _source=format(name), - ) - if format(name) in hit["_source"] - ] + try: + return [ + hit["fields"][format(name)][0] + for hit in helpers.scan( + self.client, + index=self.index_name, + query=query, + _source=False, + stored_fields="_none_", + docvalue_fields=[format(name)], + ) + if "fields" in hit and format(name) in hit["fields"] + ] + except RequestError: + return [ + hit["_source"][format(name)] + for hit in helpers.scan( + self.client, + index=self.index_name, + query=query, + stored_fields=format(name), + _source=format(name), + ) + if format(name) in hit["_source"] + ] def count_property(self, name, query: Union[dict, str, None] = None) -> dict: """ diff --git a/tests/opensearch.py b/tests/opensearch.py index 259536dc..63ee01bf 100644 --- a/tests/opensearch.py +++ b/tests/opensearch.py @@ -235,7 +235,7 @@ def test_property(self): xyz_1 = StringIO( """2 - Properties=species:S:1:pos:R:3 s="sadf" prop_1="test_1" pbc="F F F" + Properties=species:S:1:pos:R:3 energy=-5.0 prop_1="test_1" Si 0.00000000 0.00000000 0.00000000 Si 0.00000000 0.00000000 0.00000000 """ @@ -248,7 +248,7 @@ def test_property(self): xyz_2 = StringIO( """2 - Properties=species:S:1:pos:R:3 s="sadf" prop_2="test_2" pbc="F F F" + Properties=species:S:1:pos:R:3 energy=-10.0 prop_2="test_2" Si 0.00000000 0.00000000 0.00000000 Si 0.00000000 0.00000000 0.00000000 """ @@ -264,6 +264,11 @@ def test_property(self): expected_prop = ["test_1"] self.assertEqual(prop, expected_prop) + prop = self.abcd.property("energy") + expected_prop = [-5.0, -10.0] + self.assertEqual(prop[0], expected_prop[0]) + self.assertEqual(prop[1], expected_prop[1]) + def test_properties(self): """ Test getting all properties from the database.