Skip to content

Commit

Permalink
Fix sandbox stage-in vs. localization interplaye, fix #168.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Oct 13, 2023
1 parent e2fc8fc commit d301b94
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 30 deletions.
65 changes: 43 additions & 22 deletions law/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def run(self):
from law.target.local import LocalFileTarget
from law.util import (
no_value, uncolored, make_list, multi_match, human_duration, open_compat, join_generators,
TeeStream, perf_counter,
TeeStream, perf_counter, empty_context,
)
from law.logger import get_logger

Expand Down Expand Up @@ -486,44 +486,65 @@ def localize(fn, opts, task, *args, **kwargs):
that are passed as keyword arguments to the respective localization method. Does **not** accept
generator functions.
"""
# get actual input and outputs
input_struct = task.input() if opts["input"] else None
output_struct = task.output() if opts["output"] else None

# store original input and output methods
input_orig = task.input
output_orig = task.output

# input and output kwargs
input_kwargs = opts["input_kwargs"] or {}
output_kwargs = opts["output_kwargs"] or {}

# default modes
input_kwargs.setdefault("mode", "r")
output_kwargs.setdefault("mode", "w")
input_orig = task.__getattribute__("input", proxy=False) if opts["input"] else None
output_orig = task.__getattribute__("output", proxy=False) if opts["output"] else None

# wrap input context
input_context = empty_context
if opts["input"]:
def input_context():
input_struct = task.input()
input_kwargs = opts["input_kwargs"] or {}
input_kwargs.setdefault("mode", "r")
return localize_file_targets(input_struct, **input_kwargs)

# wrap output context
output_context = empty_context
if opts["output"]:
def output_context():
output_struct = task.output()
output_kwargs = opts["output_kwargs"] or {}
output_kwargs.setdefault("mode", "w")
return localize_file_targets(output_struct, **output_kwargs)

try:
# localize both target structs
with localize_file_targets(input_struct, **input_kwargs) as localized_inputs, \
localize_file_targets(output_struct, **output_kwargs) as localized_outputs:
# localize both target contexts
with input_context() as localized_inputs, output_context() as localized_outputs:
# patch the input method to always return the localized inputs
if opts["input"]:
def input_patched(self):
return localized_inputs
task.input = input_patched.__get__(task)

task.input = _patch_localized_method(task, input_patched)

# patch the output method to always return the localized outputs
if opts["output"]:
def output_patched(self):
return localized_outputs
task.output = output_patched.__get__(task)

task.output = _patch_localized_method(task, output_patched)

return fn(task, *args, **kwargs)

finally:
# restore the methods
task.input = input_orig
task.output = output_orig
if input_orig is not None:
task.input = input_orig
if output_orig is not None:
task.output = output_orig


def _patch_localized_method(task, func):
# add a flag to func
func._patched_localized_method = True

# bind to task
return func.__get__(task)


def _is_patched_localized_method(func):
return getattr(func, "_patched_localized_method", False) is True


@factory(sandbox=None, accept_generator=True)
Expand Down
28 changes: 20 additions & 8 deletions law/sandbox/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,35 +605,47 @@ def is_root_task(self):
return is_root and _sandbox_is_root_task

def _staged_input(self):
from law.decorator import _is_patched_localized_method

if not _sandbox_stagein_dir:
raise Exception(
"LAW_SANDBOX_STAGEIN_DIR must not be empty in a sandbox when target "
"stage-in is required",
)

# get the original inputs
inputs = self.__getattribute__("input", proxy=False)()
input_func = self.__getattribute__("input", proxy=False)
inputs = input_func()

# create the struct of staged inputs
staged_inputs = create_staged_target_struct(_sandbox_stagein_dir, inputs)
# when input_func is a patched method from a localization decorator, just return the inputs
# since the decorator already triggered the stage-in
if _is_patched_localized_method(input_func):
return inputs

# apply the stage-in mask
# create the struct of staged inputs and apply the stage-in mask
staged_inputs = create_staged_target_struct(_sandbox_stagein_dir, inputs)
return mask_struct(self.sandbox_stagein(), staged_inputs, inputs)

def _staged_output(self):
from law.decorator import _is_patched_localized_method

if not _sandbox_stageout_dir:
raise Exception(
"LAW_SANDBOX_STAGEOUT_DIR must not be empty in a sandbox when target "
"stage-out is required",
)

# get the original outputs
outputs = self.__getattribute__("output", proxy=False)()
output_func = self.__getattribute__("output", proxy=False)
outputs = output_func()

# create the struct of staged outputs
staged_outputs = create_staged_target_struct(_sandbox_stageout_dir, outputs)
# when output_func is a patched method from a localization decorator, just return the
# outputs since the decorator already triggered the stage-out
if _is_patched_localized_method(output_func):
return outputs

# apply the stage-out mask
# create the struct of staged outputs and apply the stage-out mask
staged_outputs = create_staged_target_struct(_sandbox_stageout_dir, outputs)
return mask_struct(self.sandbox_stageout(), staged_outputs, outputs)

@property
Expand Down

0 comments on commit d301b94

Please sign in to comment.