diff --git a/kronfluence/computer/computer.py b/kronfluence/computer/computer.py index a1723b5..e405d78 100644 --- a/kronfluence/computer/computer.py +++ b/kronfluence/computer/computer.py @@ -39,7 +39,7 @@ find_executable_batch_size, make_indices_partition, ) -from kronfluence.utils.exceptions import FactorsNotFoundError, UnsupportableModuleError +from kronfluence.utils.exceptions import FactorsNotFoundError, UnsupportableModuleError, TrackedModuleNotFoundError from kronfluence.utils.logger import PassThroughProfiler, Profiler, get_logger, get_time from kronfluence.utils.save import ( FACTOR_ARGUMENTS_NAME, @@ -80,8 +80,9 @@ def __init__( self.model.eval() self.task = task - tracked_module_names = get_tracked_module_names(self.model) - if len(tracked_module_names) == 0: + try: + tracked_module_names = get_tracked_module_names(self.model) + except TrackedModuleNotFoundError: error_msg = ( f"No tracked modules found in the provided model: {self.model}. " f"Please make sure to run `prepare_model` before passing it in to the "