diff --git a/dataimporter/cli/main.py b/dataimporter/cli/main.py index 9e9f704..6ac41c4 100644 --- a/dataimporter/cli/main.py +++ b/dataimporter/cli/main.py @@ -39,7 +39,7 @@ def get_status(config: Config): console.log("Queue size:", view.count()) try: - database = importer.get_splitgill_database(view) + database = importer.get_database(view) except ValueError: console.log(Rule()) continue diff --git a/dataimporter/importer.py b/dataimporter/importer.py index 6728811..bb188ff 100644 --- a/dataimporter/importer.py +++ b/dataimporter/importer.py @@ -2,7 +2,7 @@ from functools import partial from itertools import groupby from pathlib import Path -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Union from splitgill.manager import SplitgillClient, SplitgillDatabase from splitgill.model import Record @@ -141,16 +141,20 @@ def get_view(self, name: str) -> Optional[View]: return view return None - def get_splitgill_database(self, view: View) -> SplitgillDatabase: + def get_database(self, view: Union[str, View]) -> Optional[SplitgillDatabase]: """ Returns a new SplitgillDatabase instance for the given view. If the view doesn't - have an associated SplitgillDatabase name, then a ValueError is raised. + have an associated SplitgillDatabase name, then None is returned. - :param view: a view - :return: a SplitgillDatabase instance + :param view: a View instance or a view's name + :return: a SplitgillDatabase instance or None """ + if isinstance(view, str): + view = self.get_view(view) + if view is None: + return None if not view.has_database: - raise ValueError("View does not have a sg_name") + return None return SplitgillDatabase(view.sg_name, self.client) def queue_changes(self, records: Iterable[SourceRecord], store_name: str): @@ -266,7 +270,7 @@ def add_to_mongo(self, view_name: str, everything: bool = False) -> Optional[int self.release_records(now()) view = self.get_view(view_name) - database = self.get_splitgill_database(view) + database = self.get_database(view) if everything: changed_records = view.iter_all() @@ -303,7 +307,7 @@ def sync_to_elasticsearch(self, view_name: str, resync: bool = False): haven't changed """ view = self.get_view(view_name) - database = self.get_splitgill_database(view) + database = self.get_database(view) database.sync(resync=resync) def force_merge(self, view_name: str) -> dict: @@ -315,7 +319,7 @@ def force_merge(self, view_name: str) -> dict: :return: """ view = self.get_view(view_name) - database = self.get_splitgill_database(view) + database = self.get_database(view) client = self.client.elasticsearch return client.options(request_timeout=None).indices.forcemerge( index=database.indices.wildcard, diff --git a/tests/test_importer.py b/tests/test_importer.py index 29bf7af..4044b81 100644 --- a/tests/test_importer.py +++ b/tests/test_importer.py @@ -291,7 +291,7 @@ def test_add_to_mongo_and_sync_to_elasticsearch_artefact(self, config: Config): importer.add_to_mongo(name) - database = importer.get_splitgill_database(importer.get_view(name)) + database = importer.get_database(importer.get_view(name)) assert database.get_committed_version() == to_timestamp( datetime(2023, 10, 20, 11, 4, 31) ) @@ -361,7 +361,7 @@ def test_add_to_mongo_and_sync_to_elasticsearch_indexlot(self, config: Config): importer.add_to_mongo(name) - database = importer.get_splitgill_database(importer.get_view(name)) + database = importer.get_database(importer.get_view(name)) assert database.get_committed_version() == to_timestamp( datetime(2023, 10, 20, 11, 4, 31) ) @@ -437,7 +437,7 @@ def test_add_to_mongo_and_sync_to_elasticsearch_specimen(self, config: Config): importer.add_to_mongo(name) - database = importer.get_splitgill_database(importer.get_view(name)) + database = importer.get_database(importer.get_view(name)) assert database.get_committed_version() == to_timestamp( datetime(2023, 10, 20, 11, 4, 31) @@ -492,7 +492,7 @@ def test_add_to_mongo_and_sync_to_elasticsearch_mss(self, config: Config): importer.add_to_mongo(name) - database = importer.get_splitgill_database(importer.get_view(name)) + database = importer.get_database(importer.get_view(name)) assert database.get_committed_version() == to_timestamp( datetime(2023, 10, 20, 11, 4, 31) @@ -560,7 +560,7 @@ def test_add_to_mongo_and_sync_to_elasticsearch_preparation(self, config: Config importer.add_to_mongo(name) - database = importer.get_splitgill_database(importer.get_view(name)) + database = importer.get_database(importer.get_view(name)) assert database.get_committed_version() == to_timestamp( datetime(2023, 10, 20, 11, 4, 31) )