Resolving objectives issue introduced with introduction of pass search (#1585)

## Resolving objectives issue introduced with introduction of pass search

Olive allows multiple passes under the "passes" key in the config, where each
entry can dictate its own evaluator config, i.e. the evaluation config to use
for that specific pass. With pass search, however, this becomes an issue because
the passes within a group can dictate conflicting objectives and goals. This
change circumvents the issue by collecting all the objectives across the group
and merging them by name, with the last one in the list winning on a name clash.
It also handles the case where not all expected objectives are generated as part
of the post-evaluation signal.
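
A minimal sketch of the intended behavior, assuming hypothetical objective
names, goals, and signal values; it only illustrates the last-one-wins merge
across a pass group and the sentinel substituted for objectives missing from a
signal:

```python
import sys

# Hypothetical pass group: two pass run configs whose evaluators declare
# overlapping objectives. Names, priorities, goals, and values are illustrative.
objectives_a = {"accuracy": {"higher_is_better": True, "priority": 1, "goal": 0.90}}
objectives_b = {
    "accuracy": {"higher_is_better": True, "priority": 1, "goal": 0.95},
    "latency": {"higher_is_better": False, "priority": 2, "goal": None},
}

# Merge by name across the group; on a name clash the last config wins,
# so "accuracy" keeps the goal of 0.95 from objectives_b.
merged = {}
for objectives in (objectives_a, objectives_b):
    merged.update(objectives)

# A post-evaluation signal may omit some expected objectives. Substitute the
# worst possible value for each missing one so comparisons stay well-formed.
signal = {"accuracy": 0.97}  # "latency" was not reported for this search point
values = [
    signal[name]
    if name in signal
    else ((-sys.maxsize - 1) if spec["higher_is_better"] else sys.maxsize)
    for name, spec in merged.items()
]
print(values)  # [0.97, 9223372036854775807]
```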

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
shaahji authored Feb 1, 2025
1 parent c22f68b commit f1fe7bf
Showing 10 changed files with 162 additions and 175 deletions.
17 changes: 12 additions & 5 deletions olive/engine/engine.py
@@ -391,18 +391,25 @@ def _get_search_space_objectives(
input_model_config: ModelConfig,
input_model_id: str,
accelerator_spec: "AcceleratorSpec",
) -> Dict[str, List[Dict[str, Any]]]:
objectives_by_pass_name: Dict[str, List[Dict[str, Any]]] = {}
) -> Dict[str, Dict[str, Dict[str, Any]]]:
# NOTE: Olive config doesn't easily lend itself to enforcing one evaluator across
# multiple pass run configs since each can have its own. That freedom creates some
# unexpected, bad scenarios for search. If two or more pass run configs in the same
# pass group dictate different objectives (and thus different goals), there is no
# way to resolve them. To keep things simple for the time being, the objectives
# across all pass run configs within a pass group are merged by name (so the last
# one in the group wins).
objectives_by_pass_name: Dict[str, Dict[str, Dict[str, Any]]] = {}
objectives_by_evaluator_name: Dict[str, Dict[str, Any]] = {}
for pass_name, passes_configs in self.input_passes_configs.items():
objectives_by_pass_name[pass_name] = passes_objectives = []
objectives_by_pass_name[pass_name] = passes_objectives = {}
for pass_config in passes_configs:
evaluator_config = pass_config.evaluator or self.evaluator_config
if evaluator_config.name not in objectives_by_evaluator_name:
objectives_by_evaluator_name[evaluator_config.name] = self.resolve_objectives(
input_model_config, input_model_id, evaluator_config.metrics, accelerator_spec
)
passes_objectives.append(objectives_by_evaluator_name[evaluator_config.name])
passes_objectives.update(objectives_by_evaluator_name[evaluator_config.name])

accelerator_objectives: Dict[str, Any] = {}
for objectives in objectives_by_evaluator_name.values():
@@ -523,7 +530,7 @@ def resolve_objectives(
"goal": goals.get(metric_key),
"priority": sub_type.priority,
}
return dict(sorted(objective_dict.items(), key=lambda x: x[1]["priority"]))
return OrderedDict(sorted(objective_dict.items(), key=lambda x: x[1]["priority"]))

def resolve_goals(
self,
25 changes: 16 additions & 9 deletions olive/search/samplers/optuna_sampler.py
@@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import sys
from abc import abstractmethod
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
@@ -39,15 +40,15 @@ def __init__(
self,
search_space: SearchSpace,
config: Optional[Union[Dict[str, Any], ConfigBase]] = None,
objectives: Dict[str, Dict[str, Any]] = None,
):
super().__init__(search_space, config)
super().__init__(search_space, config, objectives)

# Initialize the searcher
self._sampler = self._create_sampler()
# TODO(olivedev): There is no absolute direction to set.
# directions = ["maximize" if hib else "minimize" for hib in self._higher_is_betters]
# self._study = optuna.create_study(directions=directions, sampler=self._sampler)
self._study = optuna.create_study(sampler=self._sampler)
directions = ["maximize" if self._higher_is_betters[name] else "minimize" for name in self._objectives]
self._study = optuna.create_study(directions=directions, sampler=self._sampler)

self._num_samples_suggested = 0
self._search_point_index_to_trail_id = {}

@@ -133,12 +134,18 @@ def _get_search_point_values(
else:
raise ValueError(f"Unsupported parameter type: {type(param)}")

def record_feedback_signal(
self, search_point_index: int, objectives: Dict[str, dict], signal: "MetricResult", should_prune: bool = False
):
def record_feedback_signal(self, search_point_index: int, signal: "MetricResult", should_prune: bool = False):
trial_id = self._search_point_index_to_trail_id[search_point_index]
if should_prune:
self._study.tell(trial_id, state=optuna.trial.TrialState.PRUNED)
else:
values = [signal[objective].value for objective in objectives]
values = []
for name in self._objectives:
if name in signal:
values.append(signal[name].value)
elif self._higher_is_betters[name]:
values.append(-sys.maxsize - 1)
else:
values.append(sys.maxsize)

self._study.tell(trial_id, values)
3 changes: 2 additions & 1 deletion olive/search/samplers/random_sampler.py
@@ -29,8 +29,9 @@ def __init__(
self,
search_space: SearchSpace,
config: Optional[Union[Dict[str, Any], ConfigBase]] = None,
objectives: Dict[str, Dict[str, Any]] = None,
):
super().__init__(search_space, config)
super().__init__(search_space, config, objectives)

self._rng = Random(self.config.seed)
self._search_points = [None] * len(self._search_space)
14 changes: 11 additions & 3 deletions olive/search/samplers/search_sampler.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from abc import abstractmethod
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Type, Union

from olive.common.auto_config import AutoConfigClass
@@ -33,10 +34,19 @@ def __init__(
self,
search_space: SearchSpace,
config: Optional[Union[Dict[str, Any], ConfigBase]] = None,
objectives: Dict[str, Dict[str, Any]] = None,
):
super().__init__(config)

self._search_space = search_space
self._config = config

# Order the objectives based on priority, and then by name
objectives = objectives or {}
self._objectives = OrderedDict(sorted(objectives.items(), key=lambda entry: (entry[1]["priority"], entry[0])))
self._higher_is_betters = {
name: objective.get("higher_is_better") or False for name, objective in self._objectives.items()
}

@property
@abstractmethod
@@ -63,7 +73,5 @@ def suggest(self) -> "SearchPoint":
"""Suggest a new configuration to try."""
return None

def record_feedback_signal(
self, search_point_index: int, objectives: Dict[str, dict], signal: "MetricResult", should_prune: bool = False
):
def record_feedback_signal(self, search_point_index: int, signal: "MetricResult", should_prune: bool = False):
"""Report the result of a configuration."""
3 changes: 2 additions & 1 deletion olive/search/samplers/sequential_sampler.py
@@ -25,8 +25,9 @@ def __init__(
self,
search_space: SearchSpace,
config: Optional[Union[Dict[str, Any], ConfigBase]] = None,
objectives: Dict[str, Dict[str, Any]] = None,
):
super().__init__(search_space, config)
super().__init__(search_space, config, objectives)

self._index = 0

94 changes: 48 additions & 46 deletions olive/search/search_results.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import sys
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

import numpy as np
@@ -12,16 +13,28 @@


class SearchResults:
def __init__(self):
self._results: Tuple[MetricResult, List[str], Dict[str, Any]] = []
def __init__(self, objectives: Dict[str, Dict[str, Any]] = None):
# Order the objectives based on priority, and then by name
objectives = objectives or {}
self._objectives = OrderedDict(sorted(objectives.items(), key=lambda entry: (entry[1]["priority"], entry[0])))

self._goals = {}
self._multipliers = {}
self._higher_is_betters = {}
for name, objective in self._objectives.items():
if objective.get("goal") is not None:
self._goals[name] = objective["goal"]

self._higher_is_betters[name] = objective.get("higher_is_better") or False
self._multipliers[name] = 1 if self._higher_is_betters[name] else -1

self._results: Tuple[MetricResult, List[str]] = []
self._sorted_indices: List[int] = []

def record_feedback_signal(
self, search_point_index: int, objectives: Dict[str, dict], result: "MetricResult", model_ids: List[str]
):
def record_feedback_signal(self, search_point_index: int, result: "MetricResult", model_ids: List[str]):
"""Record the evaluation result of a search point."""
self._results += [None] * ((search_point_index + 1) - len(self._results))
self._results[search_point_index] = (result, model_ids, objectives)
self._results[search_point_index] = (result, model_ids)

def meets_goals(self, search_point_index: int) -> bool:
"""Check if the result satisfies the constraints."""
@@ -31,16 +44,15 @@ def meets_goals(self, search_point_index: int) -> bool:
if not self._results[search_point_index]:
return False

result, _, objectives = self._results[search_point_index]
goals = {name: obj["goal"] for name, obj in objectives.items() if obj.get("goal") is not None}
if not goals:
if not self._goals:
return True # if goals are not set, always return True

# multiplier for each objective and goals
multipliers = {
name: 1 if objective.get("higher_is_better", False) else -1 for name, objective in objectives.items()
}
return all((multipliers[obj] * result[obj].value) >= (multipliers[obj] * goal) for obj, goal in goals.items())
result, _ = self._results[search_point_index]
return all(
(self._multipliers[name] * result[name].value) >= (self._multipliers[name] * goal)
for name, goal in self._goals.items()
if name in result
)

def sort(self, apply_goals: bool = False):
indices, results = self._get_results_list(apply_goals)
@@ -67,58 +79,48 @@ def get_next_best_result(self, start_index: int) -> Tuple[int, int, List[str]]:
if next_best_index >= len(self._sorted_indices):
return None, None, None

_, model_ids, _ = self._results[self._sorted_indices[next_best_index]]
_, model_ids = self._results[self._sorted_indices[next_best_index]]
return next_best_index, self._sorted_indices[next_best_index], model_ids

def _get_results_list(self, apply_goals: bool = False) -> Tuple[List[int], List[float]]:
"""Return the results as a tuple of indices and values.
Values are multiplied by the objective multiplier so that higher is better for all objectives.
"""
all_objectives = {}
for spi, entry in enumerate(self._results):
if entry and (not apply_goals or self.meets_goals(spi)):
_, _, objectives = entry
for name in objectives:
if name in all_objectives:
assert all_objectives[name] == objectives[name].get(
"higher_is_better", False
), "Conflicting values for higher_is_better across same named objectives"
else:
all_objectives[name] = objectives[name].get("higher_is_better", False)

indices = []
"""Return the results as a tuple of indices and values."""
values = []
if not all_objectives:
indices = []
if not self._objectives:
# If no objectives, then use the indices of the valid results in no specific order
indices = [spi for spi, entry in enumerate(self._results) if entry]
return indices, values

# NOTE: The values array needs to be packed, but a simple loop through each entry could
# possibly create a zagged array if the number of objectives are different.
# possibly create a jagged array if the number of actual objectives in the signal
# are different from the expected ones. To circumvent the issue, we use min/max
# depending on the higher_is_better values for the missing expected objectives
# to deprioritize that objective while sorting.

for spi, entry in enumerate(self._results):
if entry and (not apply_goals or self.meets_goals(spi)):
result, _, objectives = entry
if objectives:
indices.append(spi)
v = []
for name, hib in all_objectives.items():
if name in objectives:
v.append((1 if hib else -1) * result[name].value)
else:
v.append(-sys.maxsize - 1 if hib else sys.maxsize)
values.append(v)
result, _ = entry

v = []
for name in self._objectives:
if name in result:
# Values are scaled for comparison such that higher is better for all objectives.
v.append(self._multipliers[name] * result[name].value)
else:
v.append(-sys.maxsize - 1 if self._higher_is_betters[name] else sys.maxsize)

values.append(v)
indices.append(spi)

return indices, values

def to_json(self):
"""Return a json representation of the search results."""
return {"results": self._results}
return {"objectives": self._objectives, "results": self._results}

@classmethod
def from_json(cls, json_dict):
"""Create a SearchResults object from a json representation."""
search_results = cls()
search_results = cls(json_dict["objectives"])
search_results._results = json_dict["results"]
return search_results