diff --git a/ush/python/pygfs/jedi/jedi.py b/ush/python/pygfs/jedi/jedi.py index c9419a4cd2..c0f8cb197a 100644 --- a/ush/python/pygfs/jedi/jedi.py +++ b/ush/python/pygfs/jedi/jedi.py @@ -14,8 +14,8 @@ logger = getLogger(__name__.split('.')[-1]) -jedi_key_list = ['rundir', 'exe_src', 'jedi_args', 'mpi_cmd', 'jcb_base_yaml', 'jcb_algo', 'jcb_algo_yaml'] - +required_jedi_keys = ['rundir', 'exe_src', 'mpi_cmd'] +optional_jedi_keys = ['jedi_args', 'jcb_base_yaml', 'jcb_algo', 'jcb_algo_yaml'] class Jedi: """ @@ -43,12 +43,12 @@ def __init__(self, config: Dict[str, Any]) -> None: # Make sure input dictionary for Jedi class constructor has the required keys if 'yaml_name' not in config: - logger.error(f"FATAL ERROR: Key 'yaml_name' not found in the nested dictionary") - raise KeyError(f"FATAL ERROR: Key 'yaml_name' not found in the nested dictionary") - for key in jedi_key_list: + logger.error(f"FATAL ERROR: Key 'yaml_name' not found in config") + raise KeyError(f"FATAL ERROR: Key 'yaml_name' not found in config") + for key in required_jedi_keys: if key not in config: - logger.error(f"FATAL ERROR: Key '{key}' not found in the nested dictionary") - raise KeyError(f"FATAL ERROR: Key '{key}' not found in the nested dictionary") + logger.error(f"FATAL ERROR: Required key '{key}' not found in config") + raise KeyError(f"FATAL ERROR: Required key '{key}' not found in config") # Create the configuration dictionary for JEDI object local_dict = AttrDict( @@ -60,6 +60,11 @@ def __init__(self, config: Dict[str, Any]) -> None: ) self.jedi_config = AttrDict(**config, **local_dict) + # Set optional keys in jedi_config to None if not already present + for key in optional_jedi_keys: + if key not in self.jedi_config: + self.jedi_config[key] = None + # Save a copy of jedi_config self._jedi_config = self.jedi_config.deepcopy() @@ -113,7 +118,7 @@ def execute(self) -> None: exec_cmd = Executable(self.jedi_config.mpi_cmd) exec_cmd.add_default_arg(self.jedi_config.exe) - if self.jedi_config.jedi_args: + if not self.jedi_config.jedi_args is None: for arg in self.jedi_config.jedi_args: exec_cmd.add_default_arg(arg) exec_cmd.add_default_arg(self.jedi_config.yaml) @@ -147,20 +152,20 @@ def render_jcb(self, task_config: AttrDict, algorithm: Optional[str] = None) -> """ # Fill JCB base YAML template and build JCB config dictionary - if self.jedi_config.jcb_base_yaml: + if not self.jedi_config.jcb_base_yaml is None: jcb_config = parse_j2yaml(self.jedi_config.jcb_base_yaml, task_config) else: - logger.error(f"FATAL ERROR: JEDI configuration dictionary must contain jcb_base_yaml.") - raise KeyError(f"FATAL ERROR: JEDI configuration dictionary must contain jcb_base_yaml.") + logger.error(f"FATAL ERROR: JCB base YAML must be specified in order to render YAML using JCB") + raise KeyError(f"FATAL ERROR: JCB base YAML must be specified in order to render YAML using JCB") # Add JCB algorithm YAML, if it exists, to JCB config dictionary - if self.jedi_config.jcb_algo_yaml: + if not self.jedi_config.jcb_algo_yaml is None: jcb_config.update(parse_j2yaml(self.jedi_config.jcb_algo_yaml, task_config)) # Set algorithm in JCB config dictionary - if algorithm: + if not algorithm is None: jcb_config['algorithm'] = algorithm - elif self.jedi_config.jcb_algo: + elif not self.jedi_config.jcb_algo is None: jcb_config['algorithm'] = self.jedi_config.jcb_algo elif 'algorithm' in jcb_config: pass @@ -196,7 +201,7 @@ def link_exe(self) -> None: @staticmethod @logit(logger) - def get_jedi_dict(jedi_config_yaml: str, task_config: AttrDict, expected_keys: Optional[list] = None): + def get_jedi_dict(jedi_config_yaml: str, task_config: AttrDict, expected_block_names: Optional[list] = None): """Get dictionary of Jedi objects from YAML specifying their configuration dictionaries Parameters @@ -205,7 +210,8 @@ def get_jedi_dict(jedi_config_yaml: str, task_config: AttrDict, expected_keys: O path to YAML specifying configuration dictionaries for Jedi objects task_config : str attribute-dictionary of all configuration variables associated with a GDAS task - + expected_block_names (optional) : str + list of names of blocks expected to be in jedi_config_yaml YAML file Returns ---------- @@ -219,23 +225,31 @@ def get_jedi_dict(jedi_config_yaml: str, task_config: AttrDict, expected_keys: O jedi_config_dict = parse_j2yaml(jedi_config_yaml, task_config) # Loop through dictionary of Jedi configuration dictionaries - for yaml_name in jedi_config_dict: - # Make sure all required keys present or set to None - jedi_config_dict[yaml_name]['yaml_name'] = yaml_name - for key in jedi_key_list: - if key not in jedi_config_dict[yaml_name]: - jedi_config_dict[yaml_name][key] = None + for block_name in jedi_config_dict: + # yaml_name key is set to name for this block + jedi_config_dict[block_name]['yaml_name'] = block_name + + # Make sure all required keys present + for key in required_jedi_keys: + if key not in jedi_config_dict[block_name]: + logger.error(f"FATAL ERROR: Required key {key} not found in {jedi_config_yaml} for block {block_name}.") + raise KeyError(f"FATAL ERROR: Required key {key} not found in {jedi_config_yaml} for block {block_name}.") + + # Set optional keys to None + for key in optional_jedi_keys: + if key not in jedi_config_dict[block_name]: + jedi_config_dict[block_name][key] = None # Construct JEDI object - jedi_dict[yaml_name] = Jedi(jedi_config_dict[yaml_name]) - - # Make sure jedi_dict has the keys we expect - if expected_keys: - for jedi_dict_key in expected_keys: - if jedi_dict_key not in jedi_dict: - logger.error(f"FATAL ERROR: {jedi_dict_key} not present {jedi_config_yaml}") - raise Exception(f"FATAL ERROR: {jedi_dict_key} not present {jedi_config_yaml}") - if len(jedi_dict) > len(expected_keys): + jedi_dict[block_name] = Jedi(jedi_config_dict[block_name]) + + # Make sure jedi_dict has the blocks we expect + if expected_block_names: + for block_name in expected_block_names: + if block_name not in jedi_dict: + logger.error(f"FATAL ERROR: Expected block {block_name} not present {jedi_config_yaml}") + raise Exception(f"FATAL ERROR: Expected block {block_name} not present {jedi_config_yaml}") + if len(jedi_dict) > len(expected_block_names): logger.error(f"FATAL ERROR: {jedi_config_yaml} specifies more Jedi objects than expected.") raise Exception(f"FATAL ERROR: {jedi_config_yaml} specifies more Jedi objects than expected.")