Skip to content

Commit

Permalink
Add a hacky fix for TF 2.13 and 2.14 weights.h5 loading
Browse files Browse the repository at this point in the history
We have a bug where weights.h5 for a functional model will read and
write to the wrong paths in TF 2.13 and 2.14. We can work around this
for these versions (while thankfully needing none of this for Keras 3).
  • Loading branch information
mattdangerw committed Dec 7, 2023
1 parent 969b8a3 commit 188289f
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import datetime
import inspect
import json
import os

Expand Down Expand Up @@ -159,6 +160,21 @@ def save_to_preset(
metadata_file.write(json.dumps(metadata, indent=4))


def legacy_load_weights(layer, weights_path):
# Hacky fix for TF 2.13 and TF 2.14 when restoring a `.weights.h5` file.
# We find the Functional model class, and temporarily remove the
# _layer_checkpoint_dependencies property, which on older version of
# tensorflow complete broke the variable paths for functional models.
functional_cls = None
for cls in inspect.getmro(layer.__class__):
if cls.__name__ == "Functional":
functional_cls = cls
property = functional_cls._layer_checkpoint_dependencies
functional_cls._layer_checkpoint_dependencies = None
layer.load_weights(weights_path)
functional_cls._layer_checkpoint_dependencies = property


def load_from_preset(
preset,
load_weights=True,
Expand Down Expand Up @@ -186,7 +202,10 @@ def load_from_preset(
load_weights = load_weights and config["weights"]
if load_weights:
weights_path = get_file(preset, config["weights"])
layer.load_weights(weights_path)
if hasattr(layer, "_layer_checkpoint_dependencies"):
legacy_load_weights(layer, weights_path)
else:
layer.load_weights(weights_path)

return layer

Expand Down

0 comments on commit 188289f

Please sign in to comment.