Skip to content

Commit

Permalink
Speed up property function
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Apr 23, 2024
1 parent 45abff6 commit dd67204
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
44 changes: 29 additions & 15 deletions abcd/backends/atoms_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ase import Atoms
from ase.io import iread
from opensearchpy import OpenSearch, helpers, AuthenticationException, ConnectionTimeout
from opensearchpy import OpenSearch, helpers, AuthenticationException, ConnectionTimeout, RequestError

from abcd.backends import utils
from abcd.database import AbstractABCD
Expand Down Expand Up @@ -533,8 +533,7 @@ def count(self, query: Union[dict, str, None] = None) -> int:
def property(self, name, query: Union[dict, str, None] = None) -> list:
"""
Gets all values of a specified property for matching documents in the
database. This method is very slow, so it is preferable to use
alternative methods where possible, such as count_property.
database. Alternative methods, such as count_property, may be faster.
Parameters
----------
Expand All @@ -543,24 +542,39 @@ def property(self, name, query: Union[dict, str, None] = None) -> list:
Returns
-------
List of values for the specified property for all matching documents.
list
List of values for the specified property for all matching documents.
"""
query = self.parser(query)
query = {
"query": query,
}

return [
hit["_source"][format(name)]
for hit in helpers.scan(
self.client,
index=self.index_name,
query=query,
stored_fields=format(name),
_source=format(name),
)
if format(name) in hit["_source"]
]
try:
return [
hit["fields"][format(name)][0]
for hit in helpers.scan(
self.client,
index=self.index_name,
query=query,
_source=False,
stored_fields="_none_",
docvalue_fields=[format(name)],
)
if "fields" in hit and format(name) in hit["fields"]
]
except RequestError:
return [
hit["_source"][format(name)]
for hit in helpers.scan(
self.client,
index=self.index_name,
query=query,
stored_fields=format(name),
_source=format(name),
)
if format(name) in hit["_source"]
]

def count_property(self, name, query: Union[dict, str, None] = None) -> dict:
"""
Expand Down
9 changes: 7 additions & 2 deletions tests/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_property(self):

xyz_1 = StringIO(
"""2
Properties=species:S:1:pos:R:3 s="sadf" prop_1="test_1" pbc="F F F"
Properties=species:S:1:pos:R:3 energy=-5.0 prop_1="test_1"
Si 0.00000000 0.00000000 0.00000000
Si 0.00000000 0.00000000 0.00000000
"""
Expand All @@ -248,7 +248,7 @@ def test_property(self):

xyz_2 = StringIO(
"""2
Properties=species:S:1:pos:R:3 s="sadf" prop_2="test_2" pbc="F F F"
Properties=species:S:1:pos:R:3 energy=-10.0 prop_2="test_2"
Si 0.00000000 0.00000000 0.00000000
Si 0.00000000 0.00000000 0.00000000
"""
Expand All @@ -264,6 +264,11 @@ def test_property(self):
expected_prop = ["test_1"]
self.assertEqual(prop, expected_prop)

prop = self.abcd.property("energy")
expected_prop = [-5.0, -10.0]
self.assertEqual(prop[0], expected_prop[0])
self.assertEqual(prop[1], expected_prop[1])

def test_properties(self):
"""
Test getting all properties from the database.
Expand Down

0 comments on commit dd67204

Please sign in to comment.