diff --git a/acquire/acquire.py b/acquire/acquire.py index f0393bec..30435ea7 100644 --- a/acquire/acquire.py +++ b/acquire/acquire.py @@ -236,13 +236,11 @@ class Module: EXEC_ORDER = ExecutionOrder.DEFAULT @classmethod - def run(cls, target, cli_args, collector): + def run(cls, target: Target, cli_args: argparse.Namespace, collector: Collector): desc = cls.DESC or cls.__name__.lower() log.info("*** Acquiring %s", desc) - collector.bind(cls) - - try: + with collector.bind_module(cls): collector.collect(cls.SPEC) spec_ext = cls.get_spec_additions(target) @@ -250,8 +248,6 @@ def run(cls, target, cli_args, collector): collector.collect(list(spec_ext)) cls._run(target, collector) - finally: - collector.unbind() @classmethod def get_spec_additions(cls, target): @@ -1555,8 +1551,7 @@ def run(cls, target, cli_args, collector): specs = cls.get_specs(cli_args) - collector.bind(cls) - try: + with collector.bind_module(cls): start = time.time() path_hashes = collect_hashes(target, specs, path_filters=cls.DEFAULT_FILE_FILTERS) @@ -1567,8 +1562,6 @@ def run(cls, target, cli_args, collector): csv_compressed_bytes, ) log.info("Hashing is done, %s files processed in %.2f secs", rows_count, (time.time() - start)) - finally: - collector.unbind() @classmethod def get_specs(cls, cli_args): @@ -1623,8 +1616,7 @@ def run(cls, target: Target, cli_args: dict[str, any], collector: Collector): handle_types = cli_args.handle_types - collector.bind(cls) - try: + with collector.bind_module(cls): handles = collect_open_handles(handle_types) csv_compressed_handles = serialize_handles_into_csv(handles) @@ -1633,8 +1625,6 @@ def run(cls, target: Target, cli_args: dict[str, any], collector: Collector): csv_compressed_handles, ) log.info("Collecting open handles is done.") - finally: - collector.unbind() def print_disks_overview(target): diff --git a/acquire/collector.py b/acquire/collector.py index 62305ee0..19076104 100644 --- a/acquire/collector.py +++ b/acquire/collector.py @@ -195,6 +195,14 @@ def __enter__(self) -> Collector: def __exit__(self, *args, **kwargs) -> None: self.close() + @contextmanager + def bind_module(self, module: Type) -> Collector: + try: + self.bind(module) + yield self + finally: + self.unbind() + @contextmanager def file_filter(self, filter: Callable[[fsutil.TargetPath], bool]) -> Collector: try: