diff --git a/law/decorator.py b/law/decorator.py index ed9b28b9..fd6e619d 100644 --- a/law/decorator.py +++ b/law/decorator.py @@ -44,7 +44,7 @@ def run(self): from law.target.local import LocalFileTarget from law.util import ( no_value, uncolored, make_list, multi_match, human_duration, open_compat, join_generators, - TeeStream, perf_counter, + TeeStream, perf_counter, empty_context, ) from law.logger import get_logger @@ -486,44 +486,65 @@ def localize(fn, opts, task, *args, **kwargs): that are passed as keyword arguments to the respective localization method. Does **not** accept generator functions. """ - # get actual input and outputs - input_struct = task.input() if opts["input"] else None - output_struct = task.output() if opts["output"] else None - # store original input and output methods - input_orig = task.input - output_orig = task.output - - # input and output kwargs - input_kwargs = opts["input_kwargs"] or {} - output_kwargs = opts["output_kwargs"] or {} - - # default modes - input_kwargs.setdefault("mode", "r") - output_kwargs.setdefault("mode", "w") + input_orig = task.__getattribute__("input", proxy=False) if opts["input"] else None + output_orig = task.__getattribute__("output", proxy=False) if opts["output"] else None + + # wrap input context + input_context = empty_context + if opts["input"]: + def input_context(): + input_struct = task.input() + input_kwargs = opts["input_kwargs"] or {} + input_kwargs.setdefault("mode", "r") + return localize_file_targets(input_struct, **input_kwargs) + + # wrap output context + output_context = empty_context + if opts["output"]: + def output_context(): + output_struct = task.output() + output_kwargs = opts["output_kwargs"] or {} + output_kwargs.setdefault("mode", "w") + return localize_file_targets(output_struct, **output_kwargs) try: - # localize both target structs - with localize_file_targets(input_struct, **input_kwargs) as localized_inputs, \ - localize_file_targets(output_struct, **output_kwargs) as localized_outputs: + # localize both target contexts + with input_context() as localized_inputs, output_context() as localized_outputs: # patch the input method to always return the localized inputs if opts["input"]: def input_patched(self): return localized_inputs - task.input = input_patched.__get__(task) + + task.input = _patch_localized_method(task, input_patched) # patch the output method to always return the localized outputs if opts["output"]: def output_patched(self): return localized_outputs - task.output = output_patched.__get__(task) + + task.output = _patch_localized_method(task, output_patched) return fn(task, *args, **kwargs) finally: # restore the methods - task.input = input_orig - task.output = output_orig + if input_orig is not None: + task.input = input_orig + if output_orig is not None: + task.output = output_orig + + +def _patch_localized_method(task, func): + # add a flag to func + func._patched_localized_method = True + + # bind to task + return func.__get__(task) + + +def _is_patched_localized_method(func): + return getattr(func, "_patched_localized_method", False) is True @factory(sandbox=None, accept_generator=True) diff --git a/law/sandbox/base.py b/law/sandbox/base.py index 42527c66..1c05b17a 100644 --- a/law/sandbox/base.py +++ b/law/sandbox/base.py @@ -605,6 +605,8 @@ def is_root_task(self): return is_root and _sandbox_is_root_task def _staged_input(self): + from law.decorator import _is_patched_localized_method + if not _sandbox_stagein_dir: raise Exception( "LAW_SANDBOX_STAGEIN_DIR must not be empty in a sandbox when target " @@ -612,15 +614,21 @@ def _staged_input(self): ) # get the original inputs - inputs = self.__getattribute__("input", proxy=False)() + input_func = self.__getattribute__("input", proxy=False) + inputs = input_func() - # create the struct of staged inputs - staged_inputs = create_staged_target_struct(_sandbox_stagein_dir, inputs) + # when input_func is a patched method from a localization decorator, just return the inputs + # since the decorator already triggered the stage-in + if _is_patched_localized_method(input_func): + return inputs - # apply the stage-in mask + # create the struct of staged inputs and apply the stage-in mask + staged_inputs = create_staged_target_struct(_sandbox_stagein_dir, inputs) return mask_struct(self.sandbox_stagein(), staged_inputs, inputs) def _staged_output(self): + from law.decorator import _is_patched_localized_method + if not _sandbox_stageout_dir: raise Exception( "LAW_SANDBOX_STAGEOUT_DIR must not be empty in a sandbox when target " @@ -628,12 +636,16 @@ def _staged_output(self): ) # get the original outputs - outputs = self.__getattribute__("output", proxy=False)() + output_func = self.__getattribute__("output", proxy=False) + outputs = output_func() - # create the struct of staged outputs - staged_outputs = create_staged_target_struct(_sandbox_stageout_dir, outputs) + # when output_func is a patched method from a localization decorator, just return the + # outputs since the decorator already triggered the stage-out + if _is_patched_localized_method(output_func): + return outputs - # apply the stage-out mask + # create the struct of staged outputs and apply the stage-out mask + staged_outputs = create_staged_target_struct(_sandbox_stageout_dir, outputs) return mask_struct(self.sandbox_stageout(), staged_outputs, outputs) @property