From 8ec74c411c2d867a03ec685d6bc7045871f5fa0a Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Wed, 12 Jun 2024 17:01:15 +0100 Subject: [PATCH] Update opensearch tests for pytest --- tests/{opensearch.py => test_opensearch.py} | 356 +++++++++----------- 1 file changed, 165 insertions(+), 191 deletions(-) rename tests/{opensearch.py => test_opensearch.py} (62%) diff --git a/tests/opensearch.py b/tests/test_opensearch.py similarity index 62% rename from tests/opensearch.py rename to tests/test_opensearch.py index 23a56dd4..a9907c03 100644 --- a/tests/opensearch.py +++ b/tests/test_opensearch.py @@ -2,64 +2,55 @@ import logging import os from time import sleep -import unittest from ase.atoms import Atoms from ase.io import read from opensearchpy.exceptions import ConnectionError +import pytest from abcd import ABCD from abcd.backends.atoms_opensearch import AtomsModel, OpenSearchDatabase +NOT_GTHUB_ACTIONS = True +if os.getenv("GITHUB_ACTIONS") == "true": + NOT_GTHUB_ACTIONS = False -class OpenSearch(unittest.TestCase): - """ - Testing live OpenSearch database functions. - """ +@pytest.mark.skipif(NOT_GTHUB_ACTIONS, reason="Not running via GitHub Actions") +class TestOpenSearch: + """Testing live OpenSearch database functions.""" - @classmethod - def setUpClass(cls): - """ - Set up OpenSearch database connection. - """ - if os.getenv("GITHUB_ACTIONS") != "true": - raise unittest.SkipTest("Only runs via GitHub Actions") - cls.security_enabled = os.getenv("security_enabled") == "true" - cls.port = int(os.environ["port"]) - cls.host = "localhost" + @pytest.fixture(autouse=True) + def abcd(self): + """Set up OpenSearch database connection.""" + security_enabled = os.getenv("security_enabled") == "true" + self.port = int(os.environ["port"]) + self.host = "localhost" if os.environ["opensearch-version"] == "latest": - cls.credential = "admin:myStrongPassword123!" + credential = "admin:myStrongPassword123!" else: - cls.credential = "admin:admin" + credential = "admin:admin" logging.basicConfig(level=logging.INFO) - url = f"opensearch://{cls.credential}@{cls.host}:{cls.port}" + url = f"opensearch://{credential}@{self.host}:{self.port}" try: - abcd = ABCD.from_url( + abcd_opensearch = ABCD.from_url( url, index_name="test_index", - use_ssl=cls.security_enabled, + use_ssl=security_enabled, ) except (ConnectionError, ConnectionResetError): sleep(10) - abcd = ABCD.from_url( + abcd_opensearch = ABCD.from_url( url, index_name="test_index", - use_ssl=cls.security_enabled, + use_ssl=security_enabled, ) - assert isinstance(abcd, OpenSearchDatabase) - cls.abcd = abcd + assert isinstance(abcd_opensearch, OpenSearchDatabase) + return abcd_opensearch - @classmethod - def tearDownClass(cls): - """ - Delete index from OpenSearch database. - """ - cls.abcd.destroy() - - def push_data(self): + def push_data(self, abcd): """ Helper function to upload an example xyz file to the database. """ @@ -74,17 +65,17 @@ def push_data(self): atoms = read(xyz, format="extxyz") assert isinstance(atoms, Atoms) atoms.set_cell([1, 1, 1]) - self.abcd.push(atoms) - self.abcd.refresh() + abcd.push(atoms) + abcd.refresh() - def test_info(self): + def test_info(self, abcd): """ Test printing database info. """ - self.abcd.destroy() - self.abcd.create() - self.abcd.refresh() - self.abcd.print_info() + abcd.destroy() + abcd.create() + abcd.refresh() + abcd.print_info() info = { "host": self.host, @@ -94,36 +85,36 @@ def test_info(self): "number of confs": 0, "type": "opensearch", } - self.assertEqual(self.abcd.info(), info) + assert abcd.info() == info - def test_destroy(self): + def test_destroy(self, abcd): """ Test destroying database index. """ - self.abcd.destroy() - self.abcd.create() - self.abcd.refresh() - self.assertTrue(self.abcd.client.indices.exists("test_index")) + abcd.destroy() + abcd.create() + abcd.refresh() + assert abcd.client.indices.exists("test_index") is True - self.abcd.destroy() - self.assertFalse(self.abcd.client.indices.exists("test_index")) + abcd.destroy() + assert abcd.client.indices.exists("test_index") is False - def test_create(self): + def test_create(self, abcd): """ Test creating database index. """ - self.abcd.destroy() - self.abcd.create() - self.abcd.refresh() - self.assertTrue(self.abcd.client.indices.exists("test_index")) - self.assertFalse(self.abcd.client.indices.exists("fake_index")) + abcd.destroy() + abcd.create() + abcd.refresh() + assert abcd.client.indices.exists("test_index") is True + assert abcd.client.indices.exists("fake_index") is False - def test_push(self): + def test_push(self, abcd): """ Test pushing atoms objects to database individually. """ - self.abcd.destroy() - self.abcd.create() + abcd.destroy() + abcd.create() xyz_1 = StringIO( """2 Properties=species:S:1:pos:R:3 s="sadf" _vtk_test="t _ e s t" pbc="F F F" @@ -134,7 +125,7 @@ def test_push(self): atoms_1 = read(xyz_1, format="extxyz") assert isinstance(atoms_1, Atoms) atoms_1.set_cell([1, 1, 1]) - self.abcd.push(atoms_1) + abcd.push(atoms_1) xyz_2 = StringIO( """2 @@ -147,34 +138,34 @@ def test_push(self): assert isinstance(atoms_2, Atoms) atoms_2.set_cell([1, 1, 1]) - self.abcd.refresh() + abcd.refresh() result = AtomsModel( None, None, - self.abcd.client.search(index="test_index")["hits"]["hits"][0]["_source"], + abcd.client.search(index="test_index")["hits"]["hits"][0]["_source"], ).to_ase() - self.assertEqual(atoms_1, result) - self.assertNotEqual(atoms_2, result) + assert atoms_1 == result + assert atoms_2 != result - def test_delete(self): + def test_delete(self, abcd): """ Test deleting all documents from database. """ - self.push_data() - self.push_data() + self.push_data(abcd) + self.push_data(abcd) - self.assertEqual(self.abcd.count(), 2) - self.abcd.delete() - self.assertTrue(self.abcd.client.indices.exists("test_index")) - self.abcd.refresh() - self.assertEqual(self.abcd.count(), 0) + assert abcd.count() == 2 + abcd.delete() + assert abcd.client.indices.exists("test_index") is True + abcd.refresh() + assert abcd.count() == 0 - def test_bulk(self): + def test_bulk(self, abcd): """ Test pushing atoms object to database together. """ - self.abcd.destroy() - self.abcd.create() + abcd.destroy() + abcd.create() xyz_1 = StringIO( """2 Properties=species:S:1:pos:R:3 s="sadf" _vtk_test="t _ e s t" pbc="F F F" @@ -199,39 +190,39 @@ def test_bulk(self): atoms_list = [] atoms_list.append(atoms_1) atoms_list.append(atoms_2) - self.abcd.push(atoms_list) + abcd.push(atoms_list) - self.abcd.refresh() - self.assertEqual(self.abcd.count(), 2) + abcd.refresh() + assert abcd.count() == 2 result_1 = AtomsModel( None, None, - self.abcd.client.search(index="test_index")["hits"]["hits"][0]["_source"], + abcd.client.search(index="test_index")["hits"]["hits"][0]["_source"], ).to_ase() result_2 = AtomsModel( None, None, - self.abcd.client.search(index="test_index")["hits"]["hits"][1]["_source"], + abcd.client.search(index="test_index")["hits"]["hits"][1]["_source"], ).to_ase() - self.assertEqual(atoms_1, result_1) - self.assertEqual(atoms_2, result_2) + assert atoms_1 == result_1 + assert atoms_2 == result_2 - def test_count(self): + def test_count(self, abcd): """ Test counting the number of documents in the database. """ - self.abcd.destroy() - self.abcd.create() - self.push_data() - self.push_data() - self.assertEqual(self.abcd.count(), 2) + abcd.destroy() + abcd.create() + self.push_data(abcd) + self.push_data(abcd) + assert abcd.count() == 2 - def test_property(self): + def test_property(self, abcd): """ Test getting values of a property from the database. """ - self.abcd.destroy() - self.abcd.create() + abcd.destroy() + abcd.create() xyz_1 = StringIO( """2 @@ -244,7 +235,7 @@ def test_property(self): atoms_1 = read(xyz_1, format="extxyz") assert isinstance(atoms_1, Atoms) atoms_1.set_cell([1, 1, 1]) - self.abcd.push(atoms_1, store_calc=False) + abcd.push(atoms_1, store_calc=False) xyz_2 = StringIO( """2 @@ -257,26 +248,26 @@ def test_property(self): atoms_2 = read(xyz_2, format="extxyz") assert isinstance(atoms_2, Atoms) atoms_2.set_cell([1, 1, 1]) - self.abcd.push(atoms_2, store_calc=False) + abcd.push(atoms_2, store_calc=False) - self.abcd.refresh() - prop = self.abcd.property("prop_1") + abcd.refresh() + prop = abcd.property("prop_1") expected_prop = ["test_1"] - self.assertEqual(prop, expected_prop) + assert prop == expected_prop - prop = self.abcd.property("energy") + prop = abcd.property("energy") expected_prop = [-5.0, -10.0] - self.assertEqual(prop[0], expected_prop[0]) - self.assertEqual(prop[1], expected_prop[1]) + assert prop[0] == expected_prop[0] + assert prop[1] == expected_prop[1] - def test_properties(self): + def test_properties(self, abcd): """ Test getting all properties from the database. """ - self.abcd.destroy() - self.abcd.create() - self.push_data() - props = self.abcd.properties() + abcd.destroy() + abcd.create() + self.push_data(abcd) + props = abcd.properties() expected_props = { "info": ["_vtk_test", "cell", "formula", "n_atoms", "pbc", "s", "volume"], "derived": [ @@ -290,14 +281,14 @@ def test_properties(self): ], "arrays": ["numbers", "positions"], } - self.assertEqual(props, expected_props) + assert props == expected_props - def test_count_property(self): + def test_count_property(self, abcd): """ Test counting values of specified properties from the database. """ - self.abcd.destroy() - self.abcd.create() + abcd.destroy() + abcd.create() xyz_1 = StringIO( """2 @@ -310,7 +301,7 @@ def test_count_property(self): atoms_1 = read(xyz_1, format="extxyz") assert isinstance(atoms_1, Atoms) atoms_1.set_cell([1, 1, 1]) - self.abcd.push(atoms_1) + abcd.push(atoms_1) xyz_2 = StringIO( """1 @@ -322,19 +313,19 @@ def test_count_property(self): atoms_2 = read(xyz_2, format="extxyz") assert isinstance(atoms_2, Atoms) atoms_2.set_cell([1, 1, 1]) - self.abcd.push(atoms_2) + abcd.push(atoms_2) - self.abcd.refresh() - self.assertEqual(self.abcd.count_property("prop_1"), {1: 1}) - self.assertEqual(self.abcd.count_property("n_atoms"), {1: 1, 2: 1}) - self.assertEqual(self.abcd.count_property("volume"), {1.0: 2}) + abcd.refresh() + assert abcd.count_property("prop_1") == {1: 1} + assert abcd.count_property("n_atoms") == {1: 1, 2: 1} + assert abcd.count_property("volume") == {1.0: 2} - def test_count_properties(self): + def test_count_properties(self, abcd): """ Test counting appearences of each property in documents in the database. """ - self.abcd.destroy() - self.abcd.create() + abcd.destroy() + abcd.create() xyz_1 = StringIO( """2 @@ -347,7 +338,7 @@ def test_count_properties(self): atoms_1 = read(xyz_1, format="extxyz") assert isinstance(atoms_1, Atoms) atoms_1.set_cell([1, 1, 1]) - self.abcd.push(atoms_1) + abcd.push(atoms_1) xyz_2 = StringIO( """2 @@ -360,10 +351,10 @@ def test_count_properties(self): atoms_2 = read(xyz_2, format="extxyz") assert isinstance(atoms_2, Atoms) atoms_2.set_cell([1, 1, 1]) - self.abcd.push(atoms_2) + abcd.push(atoms_2) - self.abcd.refresh() - props = self.abcd.count_properties() + abcd.refresh() + props = abcd.count_properties() expected_counts = { "prop_1": {"count": 1, "category": "info", "dtype": "scalar(str)"}, "prop_2": {"count": 1, "category": "info", "dtype": "scalar(str)"}, @@ -391,74 +382,66 @@ def test_count_properties(self): "volume": {"count": 2, "category": "derived", "dtype": "scalar(float)"}, } - self.assertEqual(props, expected_counts) + assert props == expected_counts - def test_add_property(self): + def test_add_property(self, abcd): """ Test adding a property to documents in the database. """ - self.abcd.destroy() - self.abcd.create() - self.push_data() - self.abcd.add_property({"TEST_PROPERTY": "TEST_VALUE"}) + abcd.destroy() + abcd.create() + self.push_data(abcd) + abcd.add_property({"TEST_PROPERTY": "TEST_VALUE"}) - self.abcd.refresh() - data = self.abcd.client.search(index="test_index") - self.assertEqual( - data["hits"]["hits"][0]["_source"]["TEST_PROPERTY"], "TEST_VALUE" - ) - self.assertIn( - "TEST_PROPERTY", data["hits"]["hits"][0]["_source"]["derived"]["info_keys"] - ) + abcd.refresh() + data = abcd.client.search(index="test_index") + assert data["hits"]["hits"][0]["_source"]["TEST_PROPERTY"] == "TEST_VALUE" + assert "TEST_PROPERTY" in data["hits"]["hits"][0]["_source"]["derived"]["info_keys"] - def test_rename_property(self): + def test_rename_property(self, abcd): """ Test renaming a property for documents in the database. """ - self.abcd.destroy() - self.abcd.create() - self.push_data() - self.abcd.add_property({"TEST_PROPERTY": "TEST_VALUE"}) - self.abcd.refresh() - self.abcd.rename_property("TEST_PROPERTY", "NEW_PROPERTY") - self.abcd.refresh() - - data = self.abcd.client.search(index="test_index") - self.assertEqual( - data["hits"]["hits"][0]["_source"]["NEW_PROPERTY"], "TEST_VALUE" - ) + abcd.destroy() + abcd.create() + self.push_data(abcd) + abcd.add_property({"TEST_PROPERTY": "TEST_VALUE"}) + abcd.refresh() + abcd.rename_property("TEST_PROPERTY", "NEW_PROPERTY") + abcd.refresh() + + data = abcd.client.search(index="test_index") + assert data["hits"]["hits"][0]["_source"]["NEW_PROPERTY"] == "TEST_VALUE" - def test_delete_property(self): + def test_delete_property(self, abcd): """ Test deleting a property from documents in the database. """ - self.abcd.destroy() - self.abcd.create() - self.push_data() + abcd.destroy() + abcd.create() + self.push_data(abcd) - self.abcd.add_property({"TEST_PROPERTY": "TEST_VALUE"}) - self.abcd.refresh() - data = self.abcd.client.search(index="test_index") - self.assertEqual( - data["hits"]["hits"][0]["_source"]["TEST_PROPERTY"], "TEST_VALUE" - ) + abcd.add_property({"TEST_PROPERTY": "TEST_VALUE"}) + abcd.refresh() + data = abcd.client.search(index="test_index") + assert data["hits"]["hits"][0]["_source"]["TEST_PROPERTY"] == "TEST_VALUE" - self.abcd.delete_property("TEST_PROPERTY") - self.abcd.refresh() - data = self.abcd.client.search(index="test_index") + abcd.delete_property("TEST_PROPERTY") + abcd.refresh() + data = abcd.client.search(index="test_index") with self.assertRaises(KeyError): data["hits"]["hits"][0]["_source"]["TEST_PROPERTY"] self.assertNotIn( "TEST_PROPERTY", data["hits"]["hits"][0]["_source"]["derived"]["info_keys"] ) - def test_get_items(self): + def test_get_items(self, abcd): """ Test getting a dictionary of values from documents in the database. """ - self.abcd.destroy() - self.abcd.create() - self.push_data() + abcd.destroy() + abcd.create() + self.push_data(abcd) expected_items = { "_id": None, @@ -501,8 +484,8 @@ def test_get_items(self): "username": None, } - self.abcd.refresh() - items = list(self.abcd.get_items())[0] + abcd.refresh() + items = list(abcd.get_items())[0] for key in expected_items: if key not in [ @@ -516,33 +499,28 @@ def test_get_items(self): if isinstance(expected_items[key], dict): for dict_key in expected_items[key]: if isinstance(expected_items[key][dict_key], list): - self.assertEqual( - set(expected_items[key][dict_key]), - set(items[key][dict_key]), - ) + assert set(expected_items[key][dict_key]) == set(items[key][dict_key]) else: - self.assertEqual( - expected_items[key][dict_key], items[key][dict_key] - ) + assert expected_items[key][dict_key] == items[key][dict_key] else: - self.assertEqual(expected_items[key], items[key]) + assert expected_items[key] == items[key] - def test_get_atoms(self): + def test_get_atoms(self, abcd): """ Test getting values from documents in the database as Atoms objects. """ - self.abcd.destroy() - self.abcd.create() - self.push_data() + abcd.destroy() + abcd.create() + self.push_data(abcd) expected_atoms = Atoms(symbols="Si2", pbc=False, cell=[1.0, 1.0, 1.0]) - self.assertEqual(expected_atoms, list(self.abcd.get_atoms())[0]) + assert expected_atoms == list(abcd.get_atoms())[0] - def test_query(self): + def test_query(self, abcd): """ Test querying documents in the database. """ - self.abcd.destroy() - self.abcd.create() + abcd.destroy() + abcd.create() xyz_1 = StringIO( """2 @@ -555,7 +533,7 @@ def test_query(self): atoms_1 = read(xyz_1, format="extxyz") assert isinstance(atoms_1, Atoms) atoms_1.set_cell([1, 1, 1]) - self.abcd.push(atoms_1) + abcd.push(atoms_1) xyz_2 = StringIO( """2 @@ -568,18 +546,14 @@ def test_query(self): atoms_2 = read(xyz_2, format="extxyz") assert isinstance(atoms_2, Atoms) atoms_2.set_cell([1, 1, 1]) - self.abcd.push(atoms_2) - self.abcd.refresh() + abcd.push(atoms_2) + abcd.refresh() query_dict = {"match": {"n_atoms": 2}} query_all = "volume: [0 TO 10]" query_1 = "prop_1: *" query_2 = "prop_2: *" - self.assertEqual(self.abcd.count(query_dict), 2) - self.assertEqual(self.abcd.count(query_all), 2) - self.assertEqual(self.abcd.count(query_1), 1) - self.assertEqual(self.abcd.count(query_2), 1) - - -if __name__ == "__main__": - unittest.main(verbosity=1, exit=False) + assert abcd.count(query_dict) == 2 + assert abcd.count(query_all) == 2 + assert abcd.count(query_1) == 1 + assert abcd.count(query_2) == 1