From e9c5749e8445a3115535664d4893604b27adbf49 Mon Sep 17 00:00:00 2001 From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Date: Fri, 7 Feb 2025 18:12:59 -0500 Subject: [PATCH] Fix filter job (#3211) Fixes # . ### Description Currently, when adding multiple filters, the Job API fails. This PR fixes the issue. A filter can be added to a set of tasks (e.g. ["train", "validate"]). All task sets must be mutually disjoint, meaning that two different task sets must not share any task. For example, if you add a filter X to task set ["train", "validate"], you cannot add another filter Y to task set ["train", "eval"], because these two task sets share the element "train". Of course, if each task set contains only a single task, then it's okay to add any filters. You can add any number of filters to the same task set. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. 
--------- Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com> --- nvflare/job_config/base_app_config.py | 45 +++++++++++++++++---- nvflare/job_config/fed_job_config.py | 56 +++++++++++++-------------- 2 files changed, 65 insertions(+), 36 deletions(-) diff --git a/nvflare/job_config/base_app_config.py b/nvflare/job_config/base_app_config.py index 36138032a3..ded1daba15 100644 --- a/nvflare/job_config/base_app_config.py +++ b/nvflare/job_config/base_app_config.py @@ -28,8 +28,8 @@ class BaseAppConfig(ABC): def __init__(self) -> None: super().__init__() - self.task_data_filters: [(List[str], Filter)] = [] - self.task_result_filters: [(List[str], Filter)] = [] + self.task_data_filters = [] # list of tuples: (task_set, list of filters) + self.task_result_filters = [] # list of tuples: (task_set, list of filters) self.components: Dict[str, object] = {} self.ext_scripts = [] self.ext_dirs = [] @@ -66,14 +66,43 @@ def add_ext_dir(self, ext_dir: str): self.ext_dirs.append(ext_dir) - def _add_task_filter(self, tasks, filter, filters): + @staticmethod + def _add_task_filter(tasks, filter, taskset_filters: list): + """Add a filter for a set of tasks. + + Args: + tasks: the tasks that the filter will be added to. + filter: the filter to be added. + taskset_filters: this is a list of tuples. Each tuple contains a taskset and a list of filters + already added to the taskset. + + Returns: None + + We first check whether the "tasks" already matches an entry's taskset in taskset_filters. + If so, then add the filter to the entry. + + Otherwise, we then check whether the "tasks" overlaps with any entry's taskset. If so, this is not allowed + and an exception will be raised. + + If the "tasks" doesn't exist nor conflicts with any entry in taskset_filters, we add a new entry to + taskset_filters. 
+ """ if not isinstance(filter, Filter): raise RuntimeError(f"filter must be type of Filter, but got {filter.__class__}") - for task in tasks: - for fd in filters: - if task in fd.tasks: - raise RuntimeError(f"Task {task} already defined in the task filters.") - filters.append((tasks, filter)) + + # check whether "tasks" already exist + tasks = set(tasks) + for task_set, filter_list in taskset_filters: + if tasks == task_set: + # found it + filter_list.append(filter) + return + elif tasks.intersection(task_set): + # the tasks intersect with this task_set - not allowed + raise RuntimeError(f"cannot add filters for '{tasks}' since it overlaps task set '{task_set}'") + + # no conflicting task_set + taskset_filters.append((tasks, [filter])) def add_file_source(self, src_path: str, dest_dir=None): self.file_sources.append((src_path, dest_dir)) diff --git a/nvflare/job_config/fed_job_config.py b/nvflare/job_config/fed_job_config.py index c92fe890e5..ba58d399ab 100644 --- a/nvflare/job_config/fed_job_config.py +++ b/nvflare/job_config/fed_job_config.py @@ -316,34 +316,34 @@ def _get_base_app(self, custom_dir, app, app_config): "args": self._get_args(component, custom_dir), } ) - app_config["task_data_filters"] = [] - for tasks, filter in app.task_data_filters: - app_config["task_data_filters"].append( - { - "tasks": tasks, - "filters": [ - { - # self._get_filters(task_filter.filter, custom_dir) - "path": self._get_class_path(filter, custom_dir), - "args": self._get_args(filter, custom_dir), - } - ], - } - ) - app_config["task_result_filters"] = [] - for tasks, filter in app.task_result_filters: - app_config["task_result_filters"].append( - { - "tasks": tasks, - "filters": [ - { - # self._get_filters(result_filer.filter, custom_dir) - "path": self._get_class_path(filter, custom_dir), - "args": self._get_args(filter, custom_dir), - } - ], - } - ) + + app_config["task_data_filters"] = self._process_filters(app.task_data_filters, custom_dir) + 
app_config["task_result_filters"] = self._process_filters(app.task_result_filters, custom_dir) + + def _process_filters(self, taskset_filters: list, custom_dir): + """Process taskset_filters into app filter configuration + + Args: + taskset_filters: the list of tuples that contain taskset/filters association. + custom_dir: custom dir of the app. + + Returns: app filter configuration that is a list of dicts, each dict represents a taskset/filters + association. + + """ + app_config_filters = [] + for task_set, filter_list in taskset_filters: + filters = [] + for f in filter_list: + filters.append( + { + "path": self._get_class_path(f, custom_dir), + "args": self._get_args(f, custom_dir), + } + ) + + app_config_filters.append({"tasks": list(task_set), "filters": filters}) + return app_config_filters def _get_args(self, component, custom_dir): args = {}