Skip to content

Commit

Permalink
Speed up property function
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Apr 23, 2024
1 parent 45abff6 commit 79b4c66
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
44 changes: 29 additions & 15 deletions abcd/backends/atoms_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ase import Atoms
from ase.io import iread
from opensearchpy import OpenSearch, helpers, AuthenticationException, ConnectionTimeout
from opensearchpy import OpenSearch, helpers, AuthenticationException, ConnectionTimeout, RequestError

from abcd.backends import utils
from abcd.database import AbstractABCD
Expand Down Expand Up @@ -533,8 +533,7 @@ def count(self, query: Union[dict, str, None] = None) -> int:
def property(self, name, query: Union[dict, str, None] = None) -> list:
"""
Gets all values of a specified property for matching documents in the
database. This method is very slow, so it is preferable to use
alternative methods where possible, such as count_property.
database. Alternative methods, such as count_property, may be faster.
Parameters
----------
Expand All @@ -543,24 +542,39 @@ def property(self, name, query: Union[dict, str, None] = None) -> list:
Returns
-------
List of values for the specified property for all matching documents.
list
List of values for the specified property for all matching documents.
"""
query = self.parser(query)
query = {
"query": query,
}

return [
hit["_source"][format(name)]
for hit in helpers.scan(
self.client,
index=self.index_name,
query=query,
stored_fields=format(name),
_source=format(name),
)
if format(name) in hit["_source"]
]
try:
return [
hit["field"][format(name)][0]
for hit in helpers.scan(
self.client,
index=self.index_name,
query=query,
_source=False,
stored_fields="_none_",
docvalue_fields=[format(name)],
)
if format(name) in hit["field"]
]
except RequestError:
return [
hit["_source"][format(name)]
for hit in helpers.scan(
self.client,
index=self.index_name,
query=query,
stored_fields=format(name),
_source=format(name),
)
if format(name) in hit["_source"]
]

def count_property(self, name, query: Union[dict, str, None] = None) -> dict:
"""
Expand Down
6 changes: 3 additions & 3 deletions tests/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_property(self):

xyz_1 = StringIO(
"""2
Properties=species:S:1:pos:R:3 s="sadf" prop_1="test_1" pbc="F F F"
Properties=species:S:1:pos:R:3 s="sadf" prop_1=1 pbc="F F F"
Si 0.00000000 0.00000000 0.00000000
Si 0.00000000 0.00000000 0.00000000
"""
Expand All @@ -248,7 +248,7 @@ def test_property(self):

xyz_2 = StringIO(
"""2
Properties=species:S:1:pos:R:3 s="sadf" prop_2="test_2" pbc="F F F"
Properties=species:S:1:pos:R:3 s="sadf" prop_2=2 pbc="F F F"
Si 0.00000000 0.00000000 0.00000000
Si 0.00000000 0.00000000 0.00000000
"""
Expand All @@ -261,7 +261,7 @@ def test_property(self):

self.abcd.refresh()
prop = self.abcd.property("prop_1")
expected_prop = ["test_1"]
expected_prop = [1]
self.assertEqual(prop, expected_prop)

def test_properties(self):
Expand Down

0 comments on commit 79b4c66

Please sign in to comment.