Skip to content

Commit

Permalink
Speed up property function
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Apr 23, 2024
1 parent 45abff6 commit 9b4ea42
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
50 changes: 35 additions & 15 deletions abcd/backends/atoms_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@

from ase import Atoms
from ase.io import iread
from opensearchpy import OpenSearch, helpers, AuthenticationException, ConnectionTimeout
from opensearchpy import (
OpenSearch,
helpers,
AuthenticationException,
ConnectionTimeout,
RequestError,
)

from abcd.backends import utils
from abcd.database import AbstractABCD
Expand Down Expand Up @@ -533,8 +539,7 @@ def count(self, query: Union[dict, str, None] = None) -> int:
def property(self, name, query: Union[dict, str, None] = None) -> list:
"""
Gets all values of a specified property for matching documents in the
database. This method is very slow, so it is preferable to use
alternative methods where possible, such as count_property.
database. Alternative methods, such as count_property, may be faster.
Parameters
----------
Expand All @@ -543,24 +548,39 @@ def property(self, name, query: Union[dict, str, None] = None) -> list:
Returns
-------
List of values for the specified property for all matching documents.
list
List of values for the specified property for all matching documents.
"""
query = self.parser(query)
query = {
"query": query,
}

return [
hit["_source"][format(name)]
for hit in helpers.scan(
self.client,
index=self.index_name,
query=query,
stored_fields=format(name),
_source=format(name),
)
if format(name) in hit["_source"]
]
try:
return [
hit["fields"][format(name)][0]
for hit in helpers.scan(
self.client,
index=self.index_name,
query=query,
_source=False,
stored_fields="_none_",
docvalue_fields=[format(name)],
)
if "fields" in hit and format(name) in hit["fields"]
]
except RequestError:
return [
hit["_source"][format(name)]
for hit in helpers.scan(
self.client,
index=self.index_name,
query=query,
stored_fields=format(name),
_source=format(name),
)
if format(name) in hit["_source"]
]

def count_property(self, name, query: Union[dict, str, None] = None) -> dict:
"""
Expand Down
13 changes: 9 additions & 4 deletions tests/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_property(self):

xyz_1 = StringIO(
"""2
Properties=species:S:1:pos:R:3 s="sadf" prop_1="test_1" pbc="F F F"
Properties=species:S:1:pos:R:3 energy=-5.0 prop_1="test_1"
Si 0.00000000 0.00000000 0.00000000
Si 0.00000000 0.00000000 0.00000000
"""
Expand All @@ -244,11 +244,11 @@ def test_property(self):
atoms_1 = read(xyz_1, format="extxyz")
assert isinstance(atoms_1, Atoms)
atoms_1.set_cell([1, 1, 1])
self.abcd.push(atoms_1)
self.abcd.push(atoms_1, store_calc=False)

xyz_2 = StringIO(
"""2
Properties=species:S:1:pos:R:3 s="sadf" prop_2="test_2" pbc="F F F"
Properties=species:S:1:pos:R:3 energy=-10.0 prop_2="test_2"
Si 0.00000000 0.00000000 0.00000000
Si 0.00000000 0.00000000 0.00000000
"""
Expand All @@ -257,13 +257,18 @@ def test_property(self):
atoms_2 = read(xyz_2, format="extxyz")
assert isinstance(atoms_2, Atoms)
atoms_2.set_cell([1, 1, 1])
self.abcd.push(atoms_2)
self.abcd.push(atoms_2, store_calc=False)

self.abcd.refresh()
prop = self.abcd.property("prop_1")
expected_prop = ["test_1"]
self.assertEqual(prop, expected_prop)

prop = self.abcd.property("energy")
expected_prop = [-5.0, -10.0]
self.assertEqual(prop[0], expected_prop[0])
self.assertEqual(prop[1], expected_prop[1])

def test_properties(self):
"""
Test getting all properties from the database.
Expand Down

0 comments on commit 9b4ea42

Please sign in to comment.