Skip to content

Commit

Permalink
Switch all preset to the new Kaggle format
Browse files Browse the repository at this point in the history
These are not uploaded to Kaggle just yet, but will be shortly.
  • Loading branch information
mattdangerw committed Nov 30, 2023
1 parent ab64de7 commit b6a96a6
Show file tree
Hide file tree
Showing 20 changed files with 92 additions and 1,234 deletions.
76 changes: 4 additions & 72 deletions keras_nlp/models/albert/albert_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,7 @@
"path": "albert",
"model_card": "https://github.com/google-research/albert/blob/master/README.md",
},
"config": {
"vocabulary_size": 30000,
"num_layers": 12,
"num_heads": 12,
"num_groups": 1,
"num_inner_repetitions": 1,
"embedding_dim": 128,
"hidden_dim": 768,
"intermediate_dim": 3072,
"dropout": 0.0,
"max_sequence_length": 512,
"num_segments": 2,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/model.h5",
"weights_hash": "b83ccf3418dd84adc569324183176813",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/vocab.spm",
"spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
"kaggle_handle": "gs://keras-nlp-kaggle/albert_base_en_uncased",
},
"albert_large_en_uncased": {
"metadata": {
Expand All @@ -56,24 +39,7 @@
"path": "albert",
"model_card": "https://github.com/google-research/albert/blob/master/README.md",
},
"config": {
"vocabulary_size": 30000,
"num_layers": 24,
"num_heads": 16,
"num_groups": 1,
"num_inner_repetitions": 1,
"embedding_dim": 128,
"hidden_dim": 1024,
"intermediate_dim": 4096,
"dropout": 0,
"max_sequence_length": 512,
"num_segments": 2,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/model.h5",
"weights_hash": "c7754804efb245f06dd6e7ced32e082c",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/vocab.spm",
"spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
"kaggle_handle": "gs://keras-nlp-kaggle/albert_large_en_uncased",
},
"albert_extra_large_en_uncased": {
"metadata": {
Expand All @@ -86,24 +52,7 @@
"path": "albert",
"model_card": "https://github.com/google-research/albert/blob/master/README.md",
},
"config": {
"vocabulary_size": 30000,
"num_layers": 24,
"num_heads": 16,
"num_groups": 1,
"num_inner_repetitions": 1,
"embedding_dim": 128,
"hidden_dim": 2048,
"intermediate_dim": 8192,
"dropout": 0,
"max_sequence_length": 512,
"num_segments": 2,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/model.h5",
"weights_hash": "713209be8aadfa614fd79f18c9aeb16d",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/vocab.spm",
"spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
"kaggle_handle": "gs://keras-nlp-kaggle/albert_extra_large_en_uncased",
},
"albert_extra_extra_large_en_uncased": {
"metadata": {
Expand All @@ -116,23 +65,6 @@
"path": "albert",
"model_card": "https://github.com/google-research/albert/blob/master/README.md",
},
"config": {
"vocabulary_size": 30000,
"num_layers": 12,
"num_heads": 64,
"num_groups": 1,
"num_inner_repetitions": 1,
"embedding_dim": 128,
"hidden_dim": 4096,
"intermediate_dim": 16384,
"dropout": 0,
"max_sequence_length": 512,
"num_segments": 2,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/model.h5",
"weights_hash": "a835177b692fb6a82139f94c66db2f22",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/vocab.spm",
"spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5",
"kaggle_handle": "gs://keras-nlp-kaggle/albert_extra_extra_large_en_uncased",
},
}
32 changes: 3 additions & 29 deletions keras_nlp/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from keras_nlp.backend import keras
from keras_nlp.utils.preset_utils import check_preset_class
from keras_nlp.utils.preset_utils import load_from_preset
Expand Down Expand Up @@ -68,31 +66,6 @@ def from_config(cls, config):
def presets(cls):
return {}

@classmethod
def _legacy_from_preset(
cls,
preset,
load_weights=True,
**kwargs,
):
metadata = cls.presets[preset]
config = metadata["config"]
model = cls.from_config({**config, **kwargs})

if not load_weights:
return model

filename = os.path.basename(metadata["weights_url"])
weights = keras.utils.get_file(
filename,
metadata["weights_url"],
cache_subdir=os.path.join("models", preset),
file_hash=metadata["weights_hash"],
)

model.load_weights(weights)
return model

@classmethod
def from_preset(
cls,
Expand Down Expand Up @@ -121,9 +94,10 @@ def from_preset(
)
```
"""
# TODO: delete me!
# We support short IDs for official presets, e.g. `"bert_base_en"`.
# Map these to a Kaggle Models handle.
if preset in cls.presets:
return cls._legacy_from_preset(preset, **kwargs)
preset = cls.presets[preset]["kaggle_handle"]

check_preset_class(preset, cls)
return load_from_preset(
Expand Down
33 changes: 3 additions & 30 deletions keras_nlp/models/bart/bart_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,7 @@
"path": "bart",
"model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
},
"config": {
"vocabulary_size": 50265,
"num_layers": 6,
"num_heads": 12,
"hidden_dim": 768,
"intermediate_dim": 3072,
"dropout": 0.1,
"max_sequence_length": 1024,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/model.h5",
"weights_hash": "5b59403f0cafafbd89680e0785791163",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/vocab.json",
"vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354",
"merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/merges.txt",
"merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
"kaggle_handle": "gs://keras-nlp-kaggle/bart_base_en",
},
"bart_large_en": {
"metadata": {
Expand All @@ -62,13 +47,7 @@
"dropout": 0.1,
"max_sequence_length": 1024,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/model.h5",
"weights_hash": "6bfe7e591af8c5699ce6f9f18753af9a",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/vocab.json",
"vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
"merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/merges.txt",
"merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
"kaggle_handle": "gs://keras-nlp-kaggle/bart_large_en",
},
"bart_large_en_cnn": {
"metadata": {
Expand All @@ -90,12 +69,6 @@
"dropout": 0.1,
"max_sequence_length": 1024,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en_cnn/v1/model.h5",
"weights_hash": "99782ecd9365956f016096fef9afd62c",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en_cnn/v1/vocab.json",
"vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354",
"merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en_cnn/v1/merges.txt",
"merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
"kaggle_handle": "gs://keras-nlp-kaggle/bart_large_en_cnn",
},
}
Loading

0 comments on commit b6a96a6

Please sign in to comment.