Skip to content

Commit

Permalink
store worker wrapper in Job
Browse files Browse the repository at this point in the history
  • Loading branch information
michelwi committed Jan 26, 2024
1 parent 399b8d9 commit 138aead
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
3 changes: 2 additions & 1 deletion sisyphus/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def __new__(cls: Type[T], *args, **kwargs) -> T:

# Init
def _sis_init(self, args, kwargs, parsed_args):

for key, arg in parsed_args.items():
if isinstance(arg, Job):
logging.warning(
Expand All @@ -211,6 +210,7 @@ def _sis_init(self, args, kwargs, parsed_args):
self._sis_outputs = {}
self._sis_keep_value = None
self._sis_hold_job = False
self._sis_worker_wrapper = gs.worker_wrapper

self._sis_blocks = set()
self._sis_kwargs = parsed_args
Expand Down Expand Up @@ -316,6 +316,7 @@ def __getstate__(self):
"current_block",
"_sis_cleanable_cache",
"_sis_cleaned_or_not_cleanable",
"_sis_worker_wrapper",
]:
if key in d:
del d[key]
Expand Down
5 changes: 4 additions & 1 deletion sisyphus/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,5 +481,8 @@ def get_worker_call(self, task_id=None):
call += [gs.CMD_WORKER, os.path.relpath(self.path()), self.name()]
if task_id is not None:
call.append(str(task_id))
call = gs.worker_wrapper(getattr(self, "_job", None), self.name(), call)
if hasattr(self, "_job"):
call = self._job._sis_worker_wrapper(self._job, self.name(), call)
else:
call = gs.worker_wrapper(None, self.name(), call)
return call

0 comments on commit 138aead

Please sign in to comment.