From 4f3a41e7d5be30256a9459c7bbb89acb5cb7f01c Mon Sep 17 00:00:00 2001 From: Josh Humphries Date: Thu, 15 Aug 2024 11:36:43 +0100 Subject: [PATCH] refactor: use exceptions when failing to get stores, views, and databases I've gone back and forth on this, but I think this is overall cleaner. At the very least, at least it is now consistent. --- dataimporter/cli/main.py | 8 ++++--- dataimporter/importer.py | 48 ++++++++++++++++++++++++++++------------ 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/dataimporter/cli/main.py b/dataimporter/cli/main.py index 72e6e90..82d58ea 100644 --- a/dataimporter/cli/main.py +++ b/dataimporter/cli/main.py @@ -9,7 +9,8 @@ from dataimporter.cli.maintenance import maintenance_group from dataimporter.cli.portal import portal_group from dataimporter.cli.utils import with_config, console -from dataimporter.importer import DataImporter +from dataimporter.cli.view import view_group +from dataimporter.importer import DataImporter, ViewDoesNotHaveDatabase from dataimporter.lib.config import Config @@ -38,8 +39,9 @@ def get_status(config: Config): console.log("Backing store", view.store.name) console.log("Queue size:", view.count()) - database = importer.get_database(view) - if database is None: + try: + database = importer.get_database(view) + except ViewDoesNotHaveDatabase: console.log(Rule()) continue diff --git a/dataimporter/importer.py b/dataimporter/importer.py index d987263..69b6844 100644 --- a/dataimporter/importer.py +++ b/dataimporter/importer.py @@ -30,6 +30,24 @@ from dataimporter.lib.view import View +class StoreNotFound(Exception): + def __init__(self, name: str): + super().__init__(f"Store {name} not found") + self.name = name + + +class ViewNotFound(Exception): + def __init__(self, name: str): + super().__init__(f"View {name} not found") + self.name = name + + +class ViewDoesNotHaveDatabase(Exception): + def __init__(self, view: View): + super().__init__(f"View {view.name} does not have a Splitgill database") + self.view = view.name + + class DataImporter: """ Main manager class for the data importer. @@ -117,44 +135,46 @@ def __init__(self, config: Config): # this is where store the last date we have fully imported from EMu self.emu_status = EMuStatus(config.data_path / "emu_last_date.txt") - def get_store(self, name: str) -> Optional[Store]: + def get_store(self, name: str) -> Store: """ - Get the store with the given name. If the store doesn't exist, None is returned. + Get the store with the given name. If the store doesn't exist, a StoreNotFound + exception is raised. :param name: the name of the store - :return: the Store instance or None + :return: the Store instance """ for store in self.stores: if store.name == name: return store - return None + raise StoreNotFound(name) - def get_view(self, name: str) -> Optional[View]: + def get_view(self, name: str) -> View: """ - Get the view with the given name. If the view doesn't exist, None is returned. + Get the view with the given name. If the view doesn't exist, a ViewNotFound + exception is raised. :param name: the name of the view - :return: the View instance or None + :return: the View instance """ for view in self.views: if view.name == name: return view - return None + raise ViewNotFound(name) - def get_database(self, view: Union[str, View]) -> Optional[SplitgillDatabase]: + def get_database(self, view: Union[str, View]) -> SplitgillDatabase: """ Returns a new SplitgillDatabase instance for the given view. If the view doesn't - have an associated SplitgillDatabase name, then None is returned. + have an associated SplitgillDatabase name, then a ViewDoesNotHaveDatabase + exception is raised. If the view parameter is passed as a str and the view does + not exist, a ViewNotFound exception is raised. :param view: a View instance or a view's name - :return: a SplitgillDatabase instance or None + :return: a SplitgillDatabase instance """ if isinstance(view, str): view = self.get_view(view) - if view is None: - return None if not view.has_database: - return None + raise ViewDoesNotHaveDatabase(view) return SplitgillDatabase(view.sg_name, self.client) def queue_changes(self, records: Iterable[SourceRecord], store_name: str):