Some routine cleanup while writing some new tools for checkpoint admin
- Remove a broken test in preset_utils that we never run
- Move load_serialized_object to our preset loading class
  (for consistency; see the sketch below)
- Move all admin-related tooling to a dedicated folder in tools/
- Remove some scripts that are no longer used.
mattdangerw committed Dec 19, 2024
1 parent 08a4681 commit c0d86c1
Showing 6 changed files with 16 additions and 137 deletions.
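
The core of the refactor: the module-level load_serialized_object helper becomes a private method on the preset loader, so every load_* entry point deserializes through one shared code path. A minimal sketch of the resulting shape, under an assumed class name (only the two method bodies mirror the diff; the rest is scaffolding for the example):

    import keras

    class PresetLoaderSketch:
        """Illustrative stand-in for the loader class this commit touches."""

        def __init__(self, preset, config):
            self.preset = preset
            self.config = config

        def load_backbone(self, cls, load_weights, **kwargs):
            # Entry points now delegate to the shared private helper
            # instead of a free function in preset_utils.
            return self._load_serialized_object(self.config, **kwargs)

        def _load_serialized_object(self, config, **kwargs):
            # The real method first routes `dtype` through
            # set_dtype_in_config so serialized DTypePolicy and
            # DTypePolicyMap values are handled; that step is elided here.
            config = dict(config)
            config["config"] = {**config["config"], **kwargs}
            return keras.saving.deserialize_keras_object(config)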
32 changes: 16 additions & 16 deletions keras_hub/src/utils/preset_utils.py
@@ -454,16 +454,6 @@ def load_json(preset, config_file=CONFIG_FILE):
     return config
 
 
-def load_serialized_object(config, **kwargs):
-    # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
-    # Ensure that `dtype` is properly configured.
-    dtype = kwargs.pop("dtype", None)
-    config = set_dtype_in_config(config, dtype)
-
-    config["config"] = {**config["config"], **kwargs}
-    return keras.saving.deserialize_keras_object(config)
-
-
 def check_config_class(config):
     """Validate a preset is being loaded on the correct class."""
     registered_name = config["registered_name"]
@@ -631,26 +621,26 @@ def check_backbone_class(self):
         return check_config_class(self.config)
 
     def load_backbone(self, cls, load_weights, **kwargs):
-        backbone = load_serialized_object(self.config, **kwargs)
+        backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
             backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
         return backbone
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
         tokenizer_config = load_json(self.preset, config_file)
-        tokenizer = load_serialized_object(tokenizer_config, **kwargs)
+        tokenizer = self._load_serialized_object(tokenizer_config, **kwargs)
         if hasattr(tokenizer, "load_preset_assets"):
             tokenizer.load_preset_assets(self.preset)
         return tokenizer
 
     def load_audio_converter(self, cls, **kwargs):
         converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE)
-        return load_serialized_object(converter_config, **kwargs)
+        return self._load_serialized_object(converter_config, **kwargs)
 
     def load_image_converter(self, cls, **kwargs):
         converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
-        return load_serialized_object(converter_config, **kwargs)
+        return self._load_serialized_object(converter_config, **kwargs)
 
     def load_task(self, cls, load_weights, load_task_weights, **kwargs):
         # If there is no `task.json` or it's for the wrong class delegate to the
@@ -671,7 +661,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             backbone_config = task_config["config"]["backbone"]["config"]
             backbone_config = {**backbone_config, **backbone_kwargs}
             task_config["config"]["backbone"]["config"] = backbone_config
-        task = load_serialized_object(task_config, **kwargs)
+        task = self._load_serialized_object(task_config, **kwargs)
         if task.preprocessor and hasattr(
             task.preprocessor, "load_preset_assets"
         ):
@@ -699,11 +689,20 @@ def load_preprocessor(
         if not issubclass(check_config_class(preprocessor_json), cls):
             return super().load_preprocessor(cls, **kwargs)
         # We found a `preprocessing.json` with a complete config for our class.
-        preprocessor = load_serialized_object(preprocessor_json, **kwargs)
+        preprocessor = self._load_serialized_object(preprocessor_json, **kwargs)
         if hasattr(preprocessor, "load_preset_assets"):
             preprocessor.load_preset_assets(self.preset)
         return preprocessor
 
+    def _load_serialized_object(self, config, **kwargs):
+        # `dtype` in config might be a serialized `DTypePolicy` or
+        # `DTypePolicyMap`. Ensure that `dtype` is properly configured.
+        dtype = kwargs.pop("dtype", None)
+        config = set_dtype_in_config(config, dtype)
+
+        config["config"] = {**config["config"], **kwargs}
+        return keras.saving.deserialize_keras_object(config)
+
 
 class KerasPresetSaver:
     def __init__(self, preset_dir):
@@ -787,6 +786,7 @@ def _save_metadata(self, layer):
         tasks = list_subclasses(Task)
         tasks = filter(lambda x: x.backbone_cls is type(layer), tasks)
         tasks = [task.__base__.__name__ for task in tasks]
+        tasks = sorted(tasks)
 
         keras_version = keras.version() if hasattr(keras, "version") else None
         metadata = {
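
A note on what _load_serialized_object actually does: kwargs passed to any load_* call are merged over the inner `config` dict before deserialization, so options supplied at load time win over what was serialized (with set_dtype_in_config additionally handling serialized DTypePolicyMap values for `dtype`). A self-contained illustration of that merge, using a plain Keras layer rather than keras-hub code:

    import keras

    layer = keras.layers.Dense(4, dtype="float32")
    config = keras.saving.serialize_keras_object(layer)

    # Merge a load-time kwarg over the serialized inner config; the
    # dtype given here replaces the serialized "float32".
    config["config"] = {**config["config"], "dtype": "bfloat16"}

    restored = keras.saving.deserialize_keras_object(config)
    print(restored.dtype_policy.name)  # bfloat16

Separately, the one-line `tasks = sorted(tasks)` addition in _save_metadata makes the task list written to preset metadata deterministic across runs.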
17 changes: 0 additions & 17 deletions keras_hub/src/utils/preset_utils_test.py
@@ -11,9 +11,7 @@
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
 from keras_hub.src.tests.test_case import TestCase
-from keras_hub.src.utils.keras_utils import has_quantization_support
 from keras_hub.src.utils.preset_utils import CONFIG_FILE
-from keras_hub.src.utils.preset_utils import load_serialized_object
 from keras_hub.src.utils.preset_utils import upload_preset


@@ -88,18 +86,3 @@ def test_upload_with_invalid_json(self):
         # Verify error handling.
         with self.assertRaisesRegex(ValueError, "is an invalid json"):
             upload_preset("kaggle://test/test/test", local_preset_dir)
-
-    @parameterized.named_parameters(
-        ("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False),
-        ("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True),
-    )
-    @pytest.mark.extra_large
-    def test_load_serialized_object(self, preset, dtype, is_quantized):
-        if is_quantized and not has_quantization_support():
-            self.skipTest("This version of Keras doesn't support quantization.")
-
-        model = load_serialized_object(preset, dtype=dtype)
-        if is_quantized:
-            self.assertEqual(model.dtype_policy.name, "map_bfloat16")
-        else:
-            self.assertEqual(model.dtype_policy.name, "bfloat16")
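
The deleted test was doubly dead: it passed a preset name string where load_serialized_object expects a config dict, and its extra_large marker kept it out of normal CI runs. If equivalent coverage were ever wanted, a lighter check through the public from_preset path might look like this sketch (the preset name and test marker are assumptions, not part of this commit):

    import pytest

    from keras_hub.src.models.bert.bert_backbone import BertBackbone
    from keras_hub.src.tests.test_case import TestCase


    class DtypeOverrideTest(TestCase):
        @pytest.mark.large  # downloads a real preset
        def test_dtype_override(self):
            # The dtype kwarg flows through _load_serialized_object.
            backbone = BertBackbone.from_preset(
                "bert_tiny_en_uncased", dtype="bfloat16"
            )
            self.assertEqual(backbone.dtype_policy.name, "bfloat16")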
File renamed without changes.
File renamed without changes.
File renamed without changes.
104 changes: 0 additions & 104 deletions tools/convert_legacy_presets.py

This file was deleted.
