Skip to content

Commit

Permalink
Fixes for preset loading
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Dec 21, 2024
1 parent f5ca987 commit 3c4f23b
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
16 changes: 10 additions & 6 deletions keras_hub/src/models/efficientnet/efficientnet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,17 @@ def __init__(
):
num_stacks = len(stackwise_kernel_sizes)
if "depth_coefficient" in kwargs:
stackwise_depth_coefficients = [
kwargs.pop("depth_coefficient")
] * num_stacks
depth_coefficient = kwargs.pop("depth_coefficient")
if not isinstance(depth_coefficient, (list, tuple)):
stackwise_depth_coefficients = [depth_coefficient] * num_stacks
else:
stackwise_depth_coefficients = depth_coefficient
if "width_coefficient" in kwargs:
stackwise_width_coefficients = [
kwargs.pop("width_coefficient")
] * num_stacks
width_coefficient = kwargs.pop("width_coefficient")
if not isinstance(width_coefficient, (list, tuple)):
stackwise_width_coefficients = [width_coefficient] * num_stacks
else:
stackwise_width_coefficients = width_coefficient

image_input = keras.layers.Input(shape=input_shape)

Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@
"params": 558837760,
"path": "xlm_roberta",
},
"kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/2",
"kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/3",
},
}
2 changes: 2 additions & 0 deletions keras_hub/src/utils/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def print_msg(message, line_break=True):
logging.info(message)


# Register twice for backwards compat.
@keras.saving.register_keras_serializable(package="keras_hub")
@keras.saving.register_keras_serializable(package="keras_nlp")
def gelu_approximate(x):
return keras.activations.gelu(x, approximate=True)

Expand Down

0 comments on commit 3c4f23b

Please sign in to comment.