diff --git a/easy_slurm/run/submit.py b/easy_slurm/run/submit.py index b5ced92..dc1e9b0 100644 --- a/easy_slurm/run/submit.py +++ b/easy_slurm/run/submit.py @@ -187,32 +187,41 @@ def parse_args(argv=None): parser = argparse.ArgumentParser() for argument in ARGUMENTS: - kwargs = {k: v for k, v in argument.items() if k != "args"} - parser.add_argument(*argument["args"], **kwargs) + parser.add_argument( + *argument["args"], + **{k: v for k, v in argument.items() if k != "args"}, + ) args = parser.parse_args(argv) return args -def main(argv=None): - args = parse_args(argv) +def load_job_config(args): + # Load from CLI. + job_config = {k: v for k, v in vars(args).items() if v is not None} + # Load from file. if args.job: with open(args.job) as f: - job_config = yaml.safe_load(f) + job_config.update(yaml.safe_load(f)) + # Extra config for formatting purposes. if args.config: with open(args.config) as f: job_config["config"] = yaml.safe_load(f) - job_config = { - **{k: v for k, v in vars(args).items() if v is not None}, - **job_config, - } + # Discard unexpected keys. + for k in job_config: + if k not in JOB_CONFIG_KEYS: + del job_config[k] + + return job_config - job_config = {k: v for k, v in job_config.items() if k in JOB_CONFIG_KEYS} +def main(argv=None): + args = parse_args(argv) + job_config = load_job_config(args) submit_job(**job_config)