Skip to content

Commit

Permalink
feat: support some metrics being 'None' without stopping training (fo…
Browse files Browse the repository at this point in the history
…undation-model-stack#169)

Some metrics may not be available at the time of rule evaluation.
Add some more unit tests for the same conditions.

Signed-off-by: Harikrishnan Balagopal <[email protected]>
Signed-off-by: Mehant Kammakomati <[email protected]>
  • Loading branch information
HarikrishnanBalagopal authored Jul 5, 2024
1 parent b655e1a commit 06e8cbc
Show file tree
Hide file tree
Showing 10 changed files with 401 additions and 57 deletions.
3 changes: 3 additions & 0 deletions tests/data/trainercontroller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
TRAINER_CONFIG_TEST_INVALID_METRIC_YAML = os.path.join(
_DATA_DIR, "loss_invalid_metric.yaml"
)
TRAINER_CONFIG_TEST_UNAVAILABLE_METRIC_YAML = os.path.join(
_DATA_DIR, "loss_unavailable_metric.yaml"
)
TRAINER_CONFIG_TEST_CUSTOM_METRIC_YAML = os.path.join(
_DATA_DIR, "loss_custom_metric.yaml"
)
Expand Down
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_unavailable_metric.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
controller-metrics:
- name: loss
class: Loss
controllers:
- name: loss-controller-unavailable-metric
triggers:
- on_step_end
rule: loss < 1.0
operations:
- hfcontrols.should_training_stop
13 changes: 13 additions & 0 deletions tests/trainercontroller/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
10 changes: 5 additions & 5 deletions tests/trainercontroller/custom_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
# https://spdx.dev/learn/handling-license-info/

# Standard
from dataclasses import dataclass
from typing import Any

# Third Party
from transformers import TrainerState
import pytest

# Local
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler
Expand All @@ -31,22 +29,24 @@ class CustomMetric(MetricHandler):
"""Implements a custom metric for testing"""

def __init__(self, **kwargs):
"""Initializes the metric handler, by registering the event list and arguments with base handler.
"""Initializes the metric handler,
by registering the event list and arguments with base handler.
Args:
kwargs: List of arguments (key, value)-pairs
"""
super().__init__(events=["on_log"], **kwargs)

def validate(self) -> bool:
"""Validate the training arguments (e.g logging_steps) are compatible with the computation of this metric.
"""Validate the training arguments (e.g logging_steps)
are compatible with the computation of this metric.
Returns:
bool
"""
return True

def compute(self, state: TrainerState = None, **kwargs) -> Any:
def compute(self, _: TrainerState = None, **__) -> Any:
"""Just returns True (for testing purposes only).
Args:
Expand Down
10 changes: 3 additions & 7 deletions tests/trainercontroller/custom_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@
# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
from dataclasses import dataclass
from typing import Any

# Third Party
from transformers import TrainerControl, TrainerState
import pytest
from transformers import TrainerControl

# Local
from tuning.trainercontroller.operations import Operation
Expand All @@ -30,14 +26,14 @@
class CustomOperation(Operation):
"""Implements a custom operation for testing"""

def __init__(self, **kwargs):
def __init__(self, **_):
"""Initializes the custom operation class.
Args:
kwargs: List of arguments (key, value)-pairs
"""
super().__init__()

def should_perform_action_xyz(self, control: TrainerControl, **kwargs):
def should_perform_action_xyz(self, control: TrainerControl, **_):
"""This method performs a set training stop flag action.
Args:
Expand Down
10 changes: 3 additions & 7 deletions tests/trainercontroller/custom_operation_invalid_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@
# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
from dataclasses import dataclass
from typing import Any

# Third Party
from transformers import TrainerControl, TrainerState
import pytest
from transformers import TrainerControl

# Local
from tuning.trainercontroller.operations import Operation
Expand All @@ -30,14 +26,14 @@
class CustomOperationInvalidAction(Operation):
"""Implements a custom operation for testing"""

def __init__(self, **kwargs):
def __init__(self, **_):
"""Initializes the custom operation class.
Args:
kwargs: List of arguments (key, value)-pairs
"""
super().__init__()

def should_(self, control: TrainerControl, **kwargs):
def should_(self, control: TrainerControl, **_):
"""This method defines an action within an invalid name.
Args:
Expand Down
52 changes: 34 additions & 18 deletions tests/trainercontroller/test_tuning_trainercontroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class InputData:

def _setup_data() -> InputData:
"""
Sets up the test data for the test cases. This includes the logs, arguments for training and state
Sets up the test data for the test cases.
This includes the logs, arguments for training and state
of the training.
Returns:
Expand Down Expand Up @@ -85,7 +86,7 @@ def test_loss_on_threshold():
tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
assert control.should_training_stop == True
assert control.should_training_stop is True


def test_loss_on_threshold_with_trainer_state():
Expand Down Expand Up @@ -117,7 +118,7 @@ def test_exposed_metrics():
tc_callback.on_evaluate(
args=test_data.args, state=test_data.state, control=control, metrics=metrics
)
assert control.should_training_stop == True
assert control.should_training_stop is True


def test_incorrect_source_event_exposed_metrics():
Expand All @@ -143,7 +144,7 @@ def test_incorrect_source_event_exposed_metrics():
str(exception_handler.value).strip("'")
== "Specified source event [on_incorrect_event] is invalid for EvalMetrics"
)
assert control.should_training_stop == True
assert control.should_training_stop is True


def test_custom_metric_handler():
Expand All @@ -160,7 +161,7 @@ def test_custom_metric_handler():
tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
assert control.should_training_stop == True
assert control.should_training_stop is True


def test_custom_operation_handler():
Expand All @@ -177,7 +178,7 @@ def test_custom_operation_handler():
tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
assert control.should_training_stop == True
assert control.should_training_stop is True


def test_custom_operation_invalid_action_handler():
Expand All @@ -197,9 +198,9 @@ def test_custom_operation_invalid_action_handler():
)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
assert (
str(exception_handler.value).strip("'")
== "Invalid operation customoperation.should_ for control loss-controller-custom-operation-invalid-action"
assert str(exception_handler.value).strip("'") == (
"Invalid operation customoperation.should_ for control"
+ " loss-controller-custom-operation-invalid-action"
)


Expand Down Expand Up @@ -282,9 +283,9 @@ def test_invalid_trigger():
)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
assert (
str(exception_handler.value).strip("'")
== "Controller loss-controller-invalid-trigger has an invalid event (log_it_all_incorrect_trigger_name)"
assert str(exception_handler.value).strip("'") == (
"Controller loss-controller-invalid-trigger has"
+ " an invalid event (log_it_all_incorrect_trigger_name)"
)


Expand All @@ -304,9 +305,9 @@ def test_invalid_operation():
)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
assert (
str(exception_handler.value).strip("'")
== "Invalid operation missingop.should_training_stop for control loss-controller-invalid-operation"
assert str(exception_handler.value).strip("'") == (
"Invalid operation missingop.should_training_stop"
+ " for control loss-controller-invalid-operation"
)


Expand All @@ -326,9 +327,9 @@ def test_invalid_operation_action():
)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
assert (
str(exception_handler.value).strip("'")
== "Invalid operation hfcontrols.missingaction for control loss-controller-invalid-operation-action"
assert str(exception_handler.value).strip("'") == (
"Invalid operation hfcontrols.missingaction"
+ " for control loss-controller-invalid-operation-action"
)


Expand All @@ -352,3 +353,18 @@ def test_invalid_metric():
str(exception_handler.value).strip("'")
== "Undefined metric handler MissingMetricClass"
)


def test_unavailable_metric():
"""Tests the invalid metric scenario in the controller. Uses:
`examples/trainer-controller-configs/loss_invalid_metric.yaml`
"""
test_data = _setup_data()
tc_callback = tc.TrainerControllerCallback(
td.TRAINER_CONFIG_TEST_UNAVAILABLE_METRIC_YAML
)
control = TrainerControl(should_training_stop=False)
# Trigger on_init_end to perform registration of handlers to events
tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control)
# Trigger rule and test the condition
tc_callback.on_step_end(args=test_data.args, state=test_data.state, control=control)
Loading

0 comments on commit 06e8cbc

Please sign in to comment.