Add a hacky fix for TF 2.13 and 2.14 weights.h5 loading
We have a bug where a functional model's `.weights.h5` file will read and
write the wrong variable paths on TF 2.13 and 2.14. We can work around this
for those versions (while thankfully needing none of this for Keras 3).
mattdangerw committed Dec 7, 2023
1 parent 969b8a3 commit 0fec85b
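
For context on the diff below: the workaround is an instance of temporarily monkey-patching a class attribute around a single call and restoring it afterwards. A minimal, generic sketch of that pattern (the `patched_class_attr` helper is hypothetical, not part of this commit) could look like:

import contextlib

@contextlib.contextmanager
def patched_class_attr(cls, name, value):
    # Temporarily replace a class attribute (e.g. a property object),
    # restoring the original on exit even if the patched call raises.
    original = getattr(cls, name)
    setattr(cls, name, value)
    try:
        yield
    finally:
        setattr(cls, name, original)

Wrapping the `load_weights()` call in such a context manager would also restore the property if loading raised partway through, a guarantee the straight save-and-restore in the commit does not make.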
Showing 1 changed file with 20 additions and 1 deletion.
keras_nlp/utils/preset_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import datetime
+import inspect
 import json
 import os
@@ -159,6 +160,21 @@ def save_to_preset(
         metadata_file.write(json.dumps(metadata, indent=4))
 
 
+def legacy_load_weights(layer, weights_path):
+    # Hacky fix for TensorFlow 2.13 and 2.14 when loading a `.weights.h5`
+    # file. We find the `Functional` class and temporarily remove its
+    # `_layer_checkpoint_dependencies` property, which on older versions of
+    # TensorFlow completely broke the variable paths for functional models.
+    functional_cls = None
+    for cls in inspect.getmro(layer.__class__):
+        if cls.__name__ == "Functional":
+            functional_cls = cls
+    saved_property = functional_cls._layer_checkpoint_dependencies
+    functional_cls._layer_checkpoint_dependencies = None
+    layer.load_weights(weights_path)
+    functional_cls._layer_checkpoint_dependencies = saved_property
+
+
 def load_from_preset(
     preset,
     load_weights=True,
@@ -186,7 +202,10 @@ def load_from_preset(
     load_weights = load_weights and config["weights"]
     if load_weights:
         weights_path = get_file(preset, config["weights"])
-        layer.load_weights(weights_path)
+        if hasattr(layer, "_layer_checkpoint_dependencies"):
+            legacy_load_weights(layer, weights_path)
+        else:
+            layer.load_weights(weights_path)
 
     return layer

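As a usage sketch (the preset path is hypothetical, and the remaining parameters are assumed to keep their defaults), the dispatch added above is transparent to callers: on TF 2.13 and 2.14 the layer exposes `_layer_checkpoint_dependencies`, so `hasattr` routes loading through `legacy_load_weights`; on Keras 3, where the attribute is absent, the normal `load_weights()` path runs.

from keras_nlp.utils.preset_utils import load_from_preset

# Hypothetical local preset directory; the legacy weights-loading path
# is selected automatically on TF 2.13/2.14.
model = load_from_preset("path/to/preset_dir", load_weights=True)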
