Skip to content

Commit

Permalink
Use subclass checking check_preset_class (keras-team#1344)
Browse files Browse the repository at this point in the history
Not currently needed for anything, just to keep in sync with KerasCV.
  • Loading branch information
mattdangerw committed Dec 7, 2023
1 parent 4eea4f1 commit bcbce12
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
4 changes: 3 additions & 1 deletion keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def check_preset_class(
cls = keras.saving.get_registered_object(config["registered_name"])
if not isinstance(classes, (tuple, list)):
classes = (classes,)
if cls not in classes:
# Allow subclasses for testing a base class, e.g.
# `check_preset_class(preset, Backbone)`
if not any(issubclass(cls, x) for x in classes):
raise ValueError(
f"Unexpected class in preset `'{preset}'`. "
"When calling `from_preset()` on a class object, the preset class "
Expand Down
23 changes: 17 additions & 6 deletions keras_nlp/utils/preset_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
import pytest
from absl.testing import parameterized

from keras_nlp.models import AlbertClassifier
from keras_nlp.models import BertClassifier
from keras_nlp.models import RobertaClassifier
from keras_nlp.models.albert.albert_classifier import AlbertClassifier
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.bert.bert_classifier import BertClassifier
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.task import Task
from keras_nlp.tests.test_case import TestCase
from keras_nlp.utils import preset_utils
from keras_nlp.utils.preset_utils import check_preset_class
from keras_nlp.utils.preset_utils import load_from_preset
from keras_nlp.utils.preset_utils import save_to_preset


class PresetUtilsTest(TestCase):
Expand All @@ -36,7 +40,7 @@ class PresetUtilsTest(TestCase):
def test_preset_saving(self, cls, preset_name, tokenizer_type):
save_dir = self.get_temp_dir()
model = cls.from_preset(preset_name, num_classes=2)
preset_utils.save_to_preset(model, save_dir)
save_to_preset(model, save_dir)

if tokenizer_type == "bytepair":
vocab_filename = "assets/tokenizer/vocabulary.json"
Expand Down Expand Up @@ -72,7 +76,14 @@ def test_preset_saving(self, cls, preset_name, tokenizer_type):
self.assertEqual(config["weights"], "model.weights.h5")

# Try loading the model from preset directory
restored_model = preset_utils.load_from_preset(save_dir)
self.assertEqual(cls, check_preset_class(save_dir, cls))
self.assertEqual(cls, check_preset_class(save_dir, Task))
with self.assertRaises(ValueError):
# Preset is a subclass of Task, not Backbone.
check_preset_class(save_dir, Backbone)

# Try loading the model from preset directory
restored_model = load_from_preset(save_dir)

train_data = (
["the quick brown fox.", "the slow brown fox."], # Features.
Expand Down

0 comments on commit bcbce12

Please sign in to comment.