diff --git a/lib/ramble/ramble/context.py b/lib/ramble/ramble/context.py new file mode 100644 index 000000000..b1e2a47a0 --- /dev/null +++ b/lib/ramble/ramble/context.py @@ -0,0 +1,31 @@ +# Copyright 2022-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 or the MIT license +# , at your +# option. This file may not be copied, modified, or distributed +# except according to those terms. + + +class Context(object): + """Class to represent a context + + This class contains variable definitions to store any individual context + (such as application, workload, or experiment) and logic to merge in + additional contexts by order of precedence.""" + + def __init__(self): + """Constructor for a Context + + Create a Context object, which holds context attributes. + """ + self.env_variables = None + self.variables = None + self.internals = None + self.templates = None + self.chained_experiments = None + self.modifiers = None + self.context_names = None + self.exclude = None + self.zips = None + self.matrices = None diff --git a/lib/ramble/ramble/experiment_set.py b/lib/ramble/ramble/experiment_set.py index 05be402e8..12c00c730 100644 --- a/lib/ramble/ramble/experiment_set.py +++ b/lib/ramble/ramble/experiment_set.py @@ -22,6 +22,7 @@ import ramble.error import ramble.renderer import ramble.util.matrices +import ramble.context class ExperimentSet(object): @@ -49,38 +50,13 @@ def __init__(self, workspace): self.chained_experiments = {} self.chained_order = [] self._workspace = workspace - - self._env_variables = {} - self._variables = {} - self._internals = {} - self._templates = {} - self._chained_experiments = {} - self._modifiers = {} - self._context_names = {} + self._context = {} for context in self._contexts: - self._context_names[context] = None - self._env_variables[context] = None - self._variables[context] = None - self._internals[context] = None - self._templates[context] = None - self._chained_experiments[context] = None - self._modifiers[context] = None - - self._variables[self._contexts.base] = {} - self._variables[self._contexts.required] = {} - - self._exclude = { - self._contexts.experiment: None - } - - self._zips = { - self._contexts.experiment: None - } + self._context[context] = ramble.context.Context() - self._matrices = { - self._contexts.experiment: None - } + self._context[self._contexts.base].variables = {} + self._context[self._contexts.required].variables = {} self.read_config_vars(workspace) @@ -138,11 +114,11 @@ def get_config_env_vars(self, workspace): def set_base_var(self, var, val): """Set a base variable definition""" - self._variables[self._contexts.base][var] = val + self._context[self._contexts.base].variables[var] = val def set_required_var(self, var, val): """Set a required variable definition""" - self._variables[self._contexts.required][var] = val + self._context[self._contexts.required].variables[var] = val def _set_context(self, context, name, variables, env_variables, internals, template=None, chained_experiments=None, modifiers=None): @@ -152,13 +128,13 @@ def _set_context(self, context, name, variables, env_variables, internals, f'Context {context} is not a valid context.' ) - self._context_names[context] = name - self._variables[context] = variables - self._env_variables[context] = env_variables - self._internals[context] = internals - self._templates[context] = template - self._chained_experiments[context] = chained_experiments - self._modifiers[context] = modifiers + self._context[context].context_names = name + self._context[context].variables = variables + self._context[context].env_variables = env_variables + self._context[context].internals = internals + self._context[context].templates = template + self._context[context].chained_experiments = chained_experiments + self._context[context].modifiers = modifiers def set_application_context(self, application_name, application_variables, @@ -231,16 +207,16 @@ def set_experiment_context(self, experiment_name_template, experiment_chained_experiments, experiment_modifiers) - self._exclude[self._contexts.experiment] = experiment_exclude - self._zips[self._contexts.experiment] = experiment_zips - self._matrices[self._contexts.experiment] = experiment_matrices + self._context[self._contexts.experiment].exclude = experiment_exclude + self._context[self._contexts.experiment].zips = experiment_zips + self._context[self._contexts.experiment].matrices = experiment_matrices self._ingest_experiments() @property def application_namespace(self): """Property to return application namespace (application name)""" - if self._context_names[self._contexts.application]: - return self._context_names[self._contexts.application] + if self._context[self._contexts.application].context_names: + return self._context[self._contexts.application].context_names return None @property @@ -250,7 +226,7 @@ def workload_namespace(self): Workload namespaces are of the form: application_name.workload_name """ app_ns = self.application_namespace - wl_ns = self._context_names[self._contexts.workload] + wl_ns = self._context[self._contexts.workload].context_names if app_ns and wl_ns: return f'{app_ns}.{wl_ns}' @@ -264,7 +240,7 @@ def experiment_namespace(self): Experiment namespaces are of the form: application_name.workload_name.experiment_name """ wl_ns = self.workload_namespace - exp_ns = self._context_names[self._contexts.experiment] + exp_ns = self._context[self._contexts.experiment].context_names if wl_ns and exp_ns: return f'{wl_ns}.{exp_ns}' @@ -351,40 +327,40 @@ def _ingest_experiments(self): namespace.executables] for context in self._contexts: - if context in self._variables and self._variables[context]: - context_variables.update(self._variables[context]) - if context in self._env_variables and self._env_variables[context]: - ordered_env_variables.append(self._env_variables[context]) - if self._internals[context]: + if context in self._context and self._context[context].variables: + context_variables.update(self._context[context].variables) + if context in self._context and self._context[context].env_variables: + ordered_env_variables.append(self._context[context].env_variables) + if self._context[context].internals: for internal_section in internal_sections: - if internal_section in self._internals[context]: - if isinstance(self._internals[context][internal_section], dict): + if internal_section in self._context[context].internals: + if isinstance(self._context[context].internals[internal_section], dict): if internal_section not in merged_internals: merged_internals[internal_section] = {} - section_dict = self._internals[context][internal_section] + section_dict = self._context[context].internals[internal_section] for key, val in section_dict.items(): merged_internals[internal_section][key] = val - elif isinstance(self._internals[context][internal_section], list): + elif isinstance(self._context[context].internals[internal_section], list): if internal_section not in merged_internals: merged_internals[internal_section] = [] merged_internals[internal_section].extend( - self._internals[context][internal_section]) + self._context[context].internals[internal_section]) else: merged_internals[internal_section] = \ - self._internals[context][internal_section] - if self._chained_experiments[context]: - for chained_exp in self._chained_experiments[context]: + self._context[context].internals[internal_section] + if self._context[context].chained_experiments: + for chained_exp in self._context[context].chained_experiments: merged_chained_experiments.append(chained_exp.copy()) - if self._modifiers[context]: - for modifier in self._modifiers[context]: + if self._context[context].modifiers: + for modifier in self._context[context].modifiers: merged_mods.append(modifier.copy()) - if self._templates[context] is not None: - is_template = self._templates[context] + if self._context[context].templates is not None: + is_template = self._context[context].templates for context in self._contexts: var_name = f'{context.name}_name' - if self._context_names[context] not in context_variables: - context_variables[var_name] = self._context_names[context] + if self._context[context].context_names not in context_variables: + context_variables[var_name] = self._context[context].context_names # Set namespaces context_variables['application_namespace'] = self.application_namespace @@ -416,16 +392,16 @@ def _ingest_experiments(self): render_group = ramble.renderer.RenderGroup('experiment', 'create') render_group.variables = context_variables - render_group.zips = self._zips[self._contexts.experiment] - render_group.matrices = self._matrices[self._contexts.experiment] + render_group.zips = self._context[self._contexts.experiment].zips + render_group.matrices = self._context[self._contexts.experiment].matrices excluded_experiments = set() - if self._exclude[self._contexts.experiment]: + if self._context[self._contexts.experiment].exclude: exclude_group = ramble.renderer.RenderGroup('experiment', 'exclude') exclude_group.copy_contents(render_group) perform_explicit_exclude = \ exclude_group.from_dict(experiment_template_name, - self._exclude[self._contexts.experiment]) + self._context[self._contexts.experiment].exclude) if perform_explicit_exclude: for exclude_exp_vars in renderer.render_objects(exclude_group): @@ -436,9 +412,9 @@ def _ingest_experiments(self): excluded_experiments.add(exclude_exp_name) exclude_where = [] - if self._exclude[self._contexts.experiment]: - if namespace.where in self._exclude[self._contexts.experiment]: - exclude_where = self._exclude[self._contexts.experiment][namespace.where] + if self._context[self._contexts.experiment].exclude: + if namespace.where in self._context[self._contexts.experiment].exclude: + exclude_where = self._context[self._contexts.experiment].exclude[namespace.where] rendered_experiments = set() for experiment_vars in \ diff --git a/lib/ramble/ramble/test/experiment_set.py b/lib/ramble/ramble/test/experiment_set.py index 54e1e46db..65b8f453d 100644 --- a/lib/ramble/ramble/test/experiment_set.py +++ b/lib/ramble/ramble/test/experiment_set.py @@ -652,7 +652,7 @@ def test_processes_per_node_correct_defaults(mutable_mock_workspace_path): with ramble.workspace.read('test') as ws: exp_set = ramble.experiment_set.ExperimentSet(ws) # Remove workspace vars, which default to a `processes_per_node = -1` definition. - exp_set._variables[exp_set._contexts.workspace] = {} + exp_set._context[exp_set._contexts.workspace].variables = {} app_name = 'basic' app_vars = { @@ -761,7 +761,7 @@ def test_reserved_keywords_error_in_experiment(mutable_mock_workspace_path, var, with ramble.workspace.read('test') as ws: exp_set = ramble.experiment_set.ExperimentSet(ws) # Remove workspace vars, which default to a `processes_per_node = -1` definition. - exp_set._variables[exp_set._contexts.base] = {} + exp_set._context[exp_set._contexts.base].variables = {} app_name = 'basic' app_vars = { @@ -807,8 +807,8 @@ def test_missing_required_keyword_errors(mutable_mock_workspace_path, var, capsy with ramble.workspace.read('test') as ws: exp_set = ramble.experiment_set.ExperimentSet(ws) for context in exp_set._contexts: - if exp_set._variables[context] and var in exp_set._variables[context]: - del exp_set._variables[context][var] + if exp_set._context[context].variables and var in exp_set._context[context].variables: + del exp_set._context[context].variables[var] app_name = 'basic' app_vars = {