diff --git a/submitit/core/submission.py b/submitit/core/submission.py index 042ffaa..d37708d 100644 --- a/submitit/core/submission.py +++ b/submitit/core/submission.py @@ -46,12 +46,12 @@ def process_job(folder: Union[Path, str]) -> None: env._handle_signals(paths, delayed) result = delayed.result() with utils.temporary_save_path(paths.result_pickle) as tmppath: # save somewhere else, and move - utils.pickle_dump(("success", result), tmppath) + utils.cloudpickle_dump(("success", result), tmppath) logger.info("Job completed successfully") except Exception as error: # TODO: check pickle methods for capturing traceback; pickling and raising try: with utils.temporary_save_path(paths.result_pickle) as tmppath: - utils.pickle_dump(("error", traceback.format_exc()), tmppath) + utils.cloudpickle_dump(("error", traceback.format_exc()), tmppath) except Exception as dumperror: logger.error(f"Could not dump error:\n{error}\n\nbecause of {dumperror}") logger.error("Submitted job triggered an exception") diff --git a/submitit/core/test_core.py b/submitit/core/test_core.py index acf232f..b0b0f90 100644 --- a/submitit/core/test_core.py +++ b/submitit/core/test_core.py @@ -145,11 +145,11 @@ def test_fake_job(tmp_path: Path) -> None: f.write("blublu") assert job.stderr() == "blublu" # result - utils.pickle_dump(("success", 12), job.paths.result_pickle) + utils.cloudpickle_dump(("success", 12), job.paths.result_pickle) assert job.result() == 12 # exception assert job.exception() is None - utils.pickle_dump(("error", "blublu"), job.paths.result_pickle) + utils.cloudpickle_dump(("error", "blublu"), job.paths.result_pickle) assert isinstance(job.exception(), Exception) with pytest.raises(core.utils.FailedJobError): job.result() diff --git a/submitit/core/utils.py b/submitit/core/utils.py index 69e659c..5db4a01 100644 --- a/submitit/core/utils.py +++ b/submitit/core/utils.py @@ -273,11 +273,6 @@ def pickle_load(filename: Union[str, Path]) -> Any: return pickle.load(ifile) -def pickle_dump(obj: Any, filename: Union[str, Path]) -> None: - with open(filename, "wb") as ofile: - pickle.dump(obj, ofile, pickle.HIGHEST_PROTOCOL) - - def cloudpickle_dump(obj: Any, filename: Union[str, Path]) -> None: with open(filename, "wb") as ofile: cloudpickle.dump(obj, ofile, pickle.HIGHEST_PROTOCOL) diff --git a/submitit/local/test_local.py b/submitit/local/test_local.py index a20209e..a8b7282 100644 --- a/submitit/local/test_local.py +++ b/submitit/local/test_local.py @@ -78,6 +78,15 @@ def failing_job() -> None: assert "Failed on purpose" in traceback +def test_pickle_output_from_main(tmp_path: Path) -> None: + class MyClass: + pass + + executor = local.LocalExecutor(tmp_path) + job = executor.submit(MyClass.__call__) + assert isinstance(job.result(), MyClass) + + def test_get_first_task_error(tmp_path: Path) -> None: def flaky() -> None: job_env = job_environment.JobEnvironment()