Skip to content

Commit

Permalink
Fix filter job (#3211)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

Currently, adding more than one filter causes the Job API to fail. This
PR fixes the issue.

A filter can be added to a set of tasks (e.g. ["train", "validate"]).
Distinct task sets must be disjoint, meaning that no two different task
sets may have a task in common.

For example, if you add a filter X to task set ["train", "validate"],
you cannot add another filter Y to task set ["train", "eval"], because
these two task sets share the element "train". Of course, if you make
each task set contain a single task, then it's okay to add any
filters.

You can add any number of filters to the same task set.


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.

---------

Co-authored-by: Chester Chen <[email protected]>
  • Loading branch information
yanchengnv and chesterxgchen authored Feb 7, 2025
1 parent 50e3400 commit e9c5749
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 36 deletions.
45 changes: 37 additions & 8 deletions nvflare/job_config/base_app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class BaseAppConfig(ABC):
def __init__(self) -> None:
super().__init__()

self.task_data_filters: [(List[str], Filter)] = []
self.task_result_filters: [(List[str], Filter)] = []
self.task_data_filters = [] # list of tuples: (task_set, list of filters)
self.task_result_filters = [] # list of tuples: (task_set, list of filters)
self.components: Dict[str, object] = {}
self.ext_scripts = []
self.ext_dirs = []
Expand Down Expand Up @@ -66,14 +66,43 @@ def add_ext_dir(self, ext_dir: str):

self.ext_dirs.append(ext_dir)

# NOTE(review): two signatures appear back to back in this listing, apparently the earlier
# instance-method form followed by the static form that supersedes it - confirm against the file.
def _add_task_filter(self, tasks, filter, filters):
@staticmethod
def _add_task_filter(tasks, filter, taskset_filters: list):
"""Add a filter for a set of tasks.
Args:
tasks: the tasks that the filter will be added to.
filter: the filter to be added.
taskset_filters: this is a list of tuples. Each tuple contains a taskset and a list of filters
already added to the taskset.
Returns: None
We first check whether the "tasks" already matches an entry's taskset in taskset_filters.
If so, then add the filter to the entry.
Otherwise, we then check whether the "tasks" overlaps with any entry's taskset. If so, this is not allowed
and an exception will be raised.
If the "tasks" doesn't exist nor conflicts with any entry in taskset_filters, we add a new entry to
taskset_filters.
"""
# Reject anything that is not a Filter instance before the bookkeeping list is touched.
if not isinstance(filter, Filter):
raise RuntimeError(f"filter must be type of Filter, but got {filter.__class__}")
# NOTE(review): the loop below reads `fd.tasks`, yet the entries appended right after it are
# plain (tasks, filter) tuples, which carry no `tasks` attribute. This looks like the earlier
# logic that broke as soon as a second filter was added - confirm.
for task in tasks:
for fd in filters:
if task in fd.tasks:
raise RuntimeError(f"Task {task} already defined in the task filters.")
filters.append((tasks, filter))

# check whether "tasks" already exist
# Converting to a set makes the equality test below independent of order and duplicates.
# NOTE(review): a plain str passed as "tasks" would be split into single characters here -
# presumably callers always pass a list of task names; confirm.
tasks = set(tasks)
for task_set, filter_list in taskset_filters:
if tasks == task_set:
# found it
filter_list.append(filter)
return
elif tasks.intersection(task_set):
# the tasks intersect with this task_set - not allowed
raise RuntimeError(f"cannot add filters for '{tasks}' since it overlaps task set '{task_set}'")

# no conflicting task_set
# The new entry stores the set itself together with a one-element list holding this filter.
taskset_filters.append((tasks, [filter]))

def add_file_source(self, src_path: str, dest_dir=None):
self.file_sources.append((src_path, dest_dir))
56 changes: 28 additions & 28 deletions nvflare/job_config/fed_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,34 +316,34 @@ def _get_base_app(self, custom_dir, app, app_config):
"args": self._get_args(component, custom_dir),
}
)
app_config["task_data_filters"] = []
for tasks, filter in app.task_data_filters:
app_config["task_data_filters"].append(
{
"tasks": tasks,
"filters": [
{
# self._get_filters(task_filter.filter, custom_dir)
"path": self._get_class_path(filter, custom_dir),
"args": self._get_args(filter, custom_dir),
}
],
}
)
app_config["task_result_filters"] = []
for tasks, filter in app.task_result_filters:
app_config["task_result_filters"].append(
{
"tasks": tasks,
"filters": [
{
# self._get_filters(result_filer.filter, custom_dir)
"path": self._get_class_path(filter, custom_dir),
"args": self._get_args(filter, custom_dir),
}
],
}
)

app_config["task_data_filters"] = self._process_filters(app.task_data_filters, custom_dir)
app_config["task_result_filters"] = self._process_filters(app.task_result_filters, custom_dir)

def _process_filters(self, taskset_filters: list, custom_dir):
    """Process taskset_filters into app filter configuration.

    Args:
        taskset_filters: list of ``(task_set, filter_list)`` tuples, where
            ``task_set`` is an iterable of tasks and ``filter_list`` holds
            the filters added to that task set.
        custom_dir: custom dir of the app; passed unchanged to
            ``_get_class_path`` and ``_get_args`` for every filter.

    Returns:
        A list of dicts, one per entry of ``taskset_filters`` and in the same
        order. Each dict has a ``"tasks"`` key (the tasks as a sorted list)
        and a ``"filters"`` key (one ``{"path": ..., "args": ...}`` dict per
        filter, in the order the filters appear in ``filter_list``).
    """
    app_config_filters = []
    for task_set, filter_list in taskset_filters:
        filters = [
            {
                "path": self._get_class_path(f, custom_dir),
                "args": self._get_args(f, custom_dir),
            }
            for f in filter_list
        ]

        # Iteration order of a set of strings can differ from one interpreter
        # run to the next, so emit the tasks sorted to keep the generated
        # configuration reproducible. key=str avoids a TypeError should the
        # set ever hold values that are not mutually comparable.
        app_config_filters.append({"tasks": sorted(task_set, key=str), "filters": filters})
    return app_config_filters

def _get_args(self, component, custom_dir):
args = {}
Expand Down

0 comments on commit e9c5749

Please sign in to comment.