From b6a96a65a88b43a3493a628c1db0e0618044456b Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Wed, 29 Nov 2023 17:52:21 -0800 Subject: [PATCH] Switch all preset to the new Kaggle format These are not uploaded to Kaggle just yet, but will be shortly. --- keras_nlp/models/albert/albert_presets.py | 76 +--- keras_nlp/models/backbone.py | 32 +- keras_nlp/models/bart/bart_presets.py | 33 +- keras_nlp/models/bert/bert_presets.py | 187 +--------- .../models/deberta_v3/deberta_v3_presets.py | 80 +---- .../models/distil_bert/distil_bert_presets.py | 51 +-- keras_nlp/models/f_net/f_net_presets.py | 30 +- keras_nlp/models/gpt2/gpt2_presets.py | 85 +---- keras_nlp/models/opt/opt_presets.py | 68 +--- keras_nlp/models/preprocessor.py | 17 +- keras_nlp/models/roberta/roberta_presets.py | 34 +- keras_nlp/models/t5/t5_backbone.py | 2 + keras_nlp/models/t5/t5_presets.py | 110 +----- keras_nlp/models/task.py | 40 +-- keras_nlp/models/whisper/whisper_presets.py | 337 +----------------- .../models/xlm_roberta/xlm_roberta_presets.py | 30 +- keras_nlp/tokenizers/byte_pair_tokenizer.py | 37 +- .../tokenizers/sentence_piece_tokenizer.py | 30 +- keras_nlp/tokenizers/word_piece_tokenizer.py | 30 +- keras_nlp/utils/preset_utils.py | 17 +- 20 files changed, 92 insertions(+), 1234 deletions(-) diff --git a/keras_nlp/models/albert/albert_presets.py b/keras_nlp/models/albert/albert_presets.py index 34126f52cd..eb163a64bf 100644 --- a/keras_nlp/models/albert/albert_presets.py +++ b/keras_nlp/models/albert/albert_presets.py @@ -26,24 +26,7 @@ "path": "albert", "model_card": "https://github.com/google-research/albert/blob/master/README.md", }, - "config": { - "vocabulary_size": 30000, - "num_layers": 12, - "num_heads": 12, - "num_groups": 1, - "num_inner_repetitions": 1, - "embedding_dim": 128, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.0, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/model.h5", - "weights_hash": "b83ccf3418dd84adc569324183176813", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_base_en_uncased/v1/vocab.spm", - "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5", + "kaggle_handle": "gs://keras-nlp-kaggle/albert_base_en_uncased", }, "albert_large_en_uncased": { "metadata": { @@ -56,24 +39,7 @@ "path": "albert", "model_card": "https://github.com/google-research/albert/blob/master/README.md", }, - "config": { - "vocabulary_size": 30000, - "num_layers": 24, - "num_heads": 16, - "num_groups": 1, - "num_inner_repetitions": 1, - "embedding_dim": 128, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "dropout": 0, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/model.h5", - "weights_hash": "c7754804efb245f06dd6e7ced32e082c", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_large_en_uncased/v1/vocab.spm", - "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5", + "kaggle_handle": "gs://keras-nlp-kaggle/albert_large_en_uncased", }, "albert_extra_large_en_uncased": { "metadata": { @@ -86,24 +52,7 @@ "path": "albert", "model_card": "https://github.com/google-research/albert/blob/master/README.md", }, - "config": { - "vocabulary_size": 30000, - "num_layers": 24, - "num_heads": 16, - "num_groups": 1, - "num_inner_repetitions": 1, - "embedding_dim": 128, - "hidden_dim": 2048, - "intermediate_dim": 8192, - "dropout": 0, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/model.h5", - "weights_hash": "713209be8aadfa614fd79f18c9aeb16d", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_large_en_uncased/v1/vocab.spm", - "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5", + "kaggle_handle": "gs://keras-nlp-kaggle/albert_extra_large_en_uncased", }, "albert_extra_extra_large_en_uncased": { "metadata": { @@ -116,23 +65,6 @@ "path": "albert", "model_card": "https://github.com/google-research/albert/blob/master/README.md", }, - "config": { - "vocabulary_size": 30000, - "num_layers": 12, - "num_heads": 64, - "num_groups": 1, - "num_inner_repetitions": 1, - "embedding_dim": 128, - "hidden_dim": 4096, - "intermediate_dim": 16384, - "dropout": 0, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/model.h5", - "weights_hash": "a835177b692fb6a82139f94c66db2f22", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/albert_extra_extra_large_en_uncased/v1/vocab.spm", - "spm_proto_hash": "73e62ff8e90f951f24c8b907913039a5", + "kaggle_handle": "gs://keras-nlp-kaggle/albert_extra_extra_large_en_uncased", }, } diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 7ddfeb36da..9b8f9a5a96 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - from keras_nlp.backend import keras from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset @@ -68,31 +66,6 @@ def from_config(cls, config): def presets(cls): return {} - @classmethod - def _legacy_from_preset( - cls, - preset, - load_weights=True, - **kwargs, - ): - metadata = cls.presets[preset] - config = metadata["config"] - model = cls.from_config({**config, **kwargs}) - - if not load_weights: - return model - - filename = os.path.basename(metadata["weights_url"]) - weights = keras.utils.get_file( - filename, - metadata["weights_url"], - cache_subdir=os.path.join("models", preset), - file_hash=metadata["weights_hash"], - ) - - model.load_weights(weights) - return model - @classmethod def from_preset( cls, @@ -121,9 +94,10 @@ def from_preset( ) ``` """ - # TODO: delete me! + # We support short IDs for official presets, e.g. `"bert_base_en"`. + # Map these to a Kaggle Models handle. if preset in cls.presets: - return cls._legacy_from_preset(preset, **kwargs) + preset = cls.presets[preset]["kaggle_handle"] check_preset_class(preset, cls) return load_from_preset( diff --git a/keras_nlp/models/bart/bart_presets.py b/keras_nlp/models/bart/bart_presets.py index aa06254c10..d5547b37da 100644 --- a/keras_nlp/models/bart/bart_presets.py +++ b/keras_nlp/models/bart/bart_presets.py @@ -25,22 +25,7 @@ "path": "bart", "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md", }, - "config": { - "vocabulary_size": 50265, - "num_layers": 6, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 1024, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/model.h5", - "weights_hash": "5b59403f0cafafbd89680e0785791163", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/vocab.json", - "vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/bart_base_en", }, "bart_large_en": { "metadata": { @@ -62,13 +47,7 @@ "dropout": 0.1, "max_sequence_length": 1024, }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/model.h5", - "weights_hash": "6bfe7e591af8c5699ce6f9f18753af9a", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/vocab.json", - "vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/bart_large_en", }, "bart_large_en_cnn": { "metadata": { @@ -90,12 +69,6 @@ "dropout": 0.1, "max_sequence_length": 1024, }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en_cnn/v1/model.h5", - "weights_hash": "99782ecd9365956f016096fef9afd62c", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en_cnn/v1/vocab.json", - "vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en_cnn/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/bart_large_en_cnn", }, } diff --git a/keras_nlp/models/bert/bert_presets.py b/keras_nlp/models/bert/bert_presets.py index 7a3bbdce63..6919d2b566 100644 --- a/keras_nlp/models/bert/bert_presets.py +++ b/keras_nlp/models/bert/bert_presets.py @@ -27,23 +27,7 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "vocabulary_size": 30522, - "num_layers": 2, - "num_heads": 2, - "hidden_dim": 128, - "intermediate_dim": 512, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": { - "lowercase": True, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_tiny_en_uncased/v1/model.h5", - "weights_hash": "c2b29fcbf8f814a0812e4ab89ef5c068", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_tiny_en_uncased/v1/vocab.txt", - "vocabulary_hash": "64800d5d8528ce344256daf115d4965e", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_tiny_en_uncased", }, "bert_small_en_uncased": { "metadata": { @@ -56,23 +40,7 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "vocabulary_size": 30522, - "num_layers": 4, - "num_heads": 8, - "hidden_dim": 512, - "intermediate_dim": 2048, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": { - "lowercase": True, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_small_en_uncased/v1/model.h5", - "weights_hash": "08632c9479b034f342ba2c2b7afba5f7", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_small_en_uncased/v1/vocab.txt", - "vocabulary_hash": "64800d5d8528ce344256daf115d4965e", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_small_en_uncased", }, "bert_medium_en_uncased": { "metadata": { @@ -85,23 +53,7 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "vocabulary_size": 30522, - "num_layers": 8, - "num_heads": 8, - "hidden_dim": 512, - "intermediate_dim": 2048, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": { - "lowercase": True, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_medium_en_uncased/v1/model.h5", - "weights_hash": "bb990e1184ec6b6185450c73833cd661", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_medium_en_uncased/v1/vocab.txt", - "vocabulary_hash": "64800d5d8528ce344256daf115d4965e", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_medium_en_uncased", }, "bert_base_en_uncased": { "metadata": { @@ -114,23 +66,7 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "vocabulary_size": 30522, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": { - "lowercase": True, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_en_uncased/v1/model.h5", - "weights_hash": "9b2b2139f221988759ac9cdd17050b31", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_en_uncased/v1/vocab.txt", - "vocabulary_hash": "64800d5d8528ce344256daf115d4965e", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_base_en_uncased", }, "bert_base_en": { "metadata": { @@ -143,23 +79,7 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "vocabulary_size": 28996, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": { - "lowercase": False, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_en/v1/model.h5", - "weights_hash": "f94a6cb012e18f4fb8ec92abb91864e9", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_en/v1/vocab.txt", - "vocabulary_hash": "bb6ca9b42e790e5cd986bbb16444d0e0", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_base_en", }, "bert_base_zh": { "metadata": { @@ -171,23 +91,7 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "vocabulary_size": 21128, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": { - "lowercase": False, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_zh/v1/model.h5", - "weights_hash": "79afa421e386076e62ab42dad555ab0c", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_zh/v1/vocab.txt", - "vocabulary_hash": "3b5b76c4aef48ecf8cb3abaafe960f09", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_base_zh", }, "bert_base_multi": { "metadata": { @@ -199,23 +103,7 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "vocabulary_size": 119547, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": { - "lowercase": False, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_multi/v1/model.h5", - "weights_hash": "b0631cec0a1f2513c6cfd75ba29c33aa", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_base_multi/v1/vocab.txt", - "vocabulary_hash": "d9d865138d17f1958502ed060ecfeeb6", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_base_multi", }, "bert_large_en_uncased": { "metadata": { @@ -228,23 +116,7 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "vocabulary_size": 30522, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": { - "lowercase": True, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_large_en_uncased/v1/model.h5", - "weights_hash": "cc5cacc9565ef400ee4376105f40ddae", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_large_en_uncased/v1/vocab.txt", - "vocabulary_hash": "64800d5d8528ce344256daf115d4965e", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_large_en_uncased", }, "bert_large_en": { "metadata": { @@ -257,23 +129,7 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "vocabulary_size": 28996, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - }, - "preprocessor_config": { - "lowercase": False, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_large_en/v1/model.h5", - "weights_hash": "8b8ab82290bbf4f8db87d4f100648890", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_large_en/v1/vocab.txt", - "vocabulary_hash": "bb6ca9b42e790e5cd986bbb16444d0e0", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_large_en", }, } @@ -288,29 +144,6 @@ "path": "bert", "model_card": "https://github.com/google-research/bert/blob/master/README.md", }, - "config": { - "backbone": { - "class_name": "keras_nlp>BertBackbone", - "config": { - "vocabulary_size": 30522, - "hidden_dim": 128, - "intermediate_dim": 512, - "num_layers": 2, - "num_heads": 2, - "max_sequence_length": 512, - "num_segments": 2, - "dropout": 0.1, - }, - }, - "num_classes": 2, - "dropout": 0.1, - }, - "preprocessor_config": { - "lowercase": True, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/bert_tiny_en_uncased_sst2/v1/model.h5", - "weights_hash": "1f9c2d59f9e229e08f3fbd44239cfb0b", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bert_tiny_en_uncased_sst2/v1/vocab.txt", - "vocabulary_hash": "64800d5d8528ce344256daf115d4965e", + "kaggle_handle": "gs://keras-nlp-kaggle/bert_tiny_en_uncased_sst2", } } diff --git a/keras_nlp/models/deberta_v3/deberta_v3_presets.py b/keras_nlp/models/deberta_v3/deberta_v3_presets.py index f5df6cb599..771d7ad9c5 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_presets.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_presets.py @@ -25,21 +25,7 @@ "path": "deberta_v3", "model_card": "https://huggingface.co/microsoft/deberta-v3-xsmall", }, - "config": { - "vocabulary_size": 128100, - "num_layers": 12, - "num_heads": 6, - "hidden_dim": 384, - "intermediate_dim": 1536, - "dropout": 0.1, - "max_sequence_length": 512, - "bucket_size": 256, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_extra_small_en/v1/model.h5", - "weights_hash": "d8e10327107e5c5e20b45548a5028619", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_extra_small_en/v1/vocab.spm", - "spm_proto_hash": "1613fcbf3b82999c187b09c9db79b568", + "kaggle_handle": "gs://keras-nlp-kaggle/deberta_v3_extra_small_en", }, "deberta_v3_small_en": { "metadata": { @@ -52,21 +38,7 @@ "path": "deberta_v3", "model_card": "https://huggingface.co/microsoft/deberta-v3-small", }, - "config": { - "vocabulary_size": 128100, - "num_layers": 6, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - "bucket_size": 256, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_small_en/v1/model.h5", - "weights_hash": "84118eb7c5a735f2061ecccaf71bb888", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_small_en/v1/vocab.spm", - "spm_proto_hash": "1613fcbf3b82999c187b09c9db79b568", + "kaggle_handle": "gs://keras-nlp-kaggle/deberta_v3_small_en", }, "deberta_v3_base_en": { "metadata": { @@ -79,21 +51,7 @@ "path": "deberta_v3", "model_card": "https://huggingface.co/microsoft/deberta-v3-base", }, - "config": { - "vocabulary_size": 128100, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - "bucket_size": 256, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_base_en/v1/model.h5", - "weights_hash": "cebce044aeed36aec9b94e3b8a255430", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_base_en/v1/vocab.spm", - "spm_proto_hash": "1613fcbf3b82999c187b09c9db79b568", + "kaggle_handle": "gs://keras-nlp-kaggle/deberta_v3_base_en", }, "deberta_v3_large_en": { "metadata": { @@ -106,21 +64,7 @@ "path": "deberta_v3", "model_card": "https://huggingface.co/microsoft/deberta-v3-large", }, - "config": { - "vocabulary_size": 128100, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "dropout": 0.1, - "max_sequence_length": 512, - "bucket_size": 256, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_large_en/v1/model.h5", - "weights_hash": "bce7690f358a9e39304f8c0ebc71a745", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_large_en/v1/vocab.spm", - "spm_proto_hash": "1613fcbf3b82999c187b09c9db79b568", + "kaggle_handle": "gs://keras-nlp-kaggle/deberta_v3_large_en", }, "deberta_v3_base_multi": { "metadata": { @@ -133,20 +77,6 @@ "path": "deberta_v3", "model_card": "https://huggingface.co/microsoft/mdeberta-v3-base", }, - "config": { - "vocabulary_size": 251000, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - "bucket_size": 256, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_base_multi/v1/model.h5", - "weights_hash": "26e5a824b26afd2ee336835bd337bbeb", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/deberta_v3_base_multi/v1/vocab.spm", - "spm_proto_hash": "b4ca07289eac48600b29529119d565e2", + "kaggle_handle": "gs://keras-nlp-kaggle/deberta_v3_base_multi", }, } diff --git a/keras_nlp/models/distil_bert/distil_bert_presets.py b/keras_nlp/models/distil_bert/distil_bert_presets.py index 3f939fb6da..b2a99ef688 100644 --- a/keras_nlp/models/distil_bert/distil_bert_presets.py +++ b/keras_nlp/models/distil_bert/distil_bert_presets.py @@ -26,22 +26,7 @@ "path": "distil_bert", "model_card": "https://huggingface.co/distilbert-base-uncased", }, - "config": { - "vocabulary_size": 30522, - "num_layers": 6, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - }, - "preprocessor_config": { - "lowercase": True, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/distil_bert_base_en_uncased/v1/model.h5", - "weights_hash": "6625a649572e74086d74c46b8d0b0da3", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/distil_bert_base_en_uncased/v1/vocab.txt", - "vocabulary_hash": "64800d5d8528ce344256daf115d4965e", + "kaggle_handle": "gs://keras-nlp-kaggle/distil_bert_base_en_uncased", }, "distil_bert_base_en": { "metadata": { @@ -55,22 +40,7 @@ "path": "distil_bert", "model_card": "https://huggingface.co/distilbert-base-cased", }, - "config": { - "vocabulary_size": 28996, - "num_layers": 6, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - }, - "preprocessor_config": { - "lowercase": False, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/distil_bert_base_en/v1/model.h5", - "weights_hash": "fa36aa6865978efbf85a5c8264e5eb57", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/distil_bert_base_en/v1/vocab.txt", - "vocabulary_hash": "bb6ca9b42e790e5cd986bbb16444d0e0", + "kaggle_handle": "gs://keras-nlp-kaggle/distil_bert_base_en", }, "distil_bert_base_multi": { "metadata": { @@ -82,21 +52,6 @@ "path": "distil_bert", "model_card": "https://huggingface.co/distilbert-base-multilingual-cased", }, - "config": { - "vocabulary_size": 119547, - "num_layers": 6, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - }, - "preprocessor_config": { - "lowercase": False, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/distil_bert_base_multi/v1/model.h5", - "weights_hash": "c0f11095e2a6455bd3b1a6d14800a7fa", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/distil_bert_base_multi/v1/vocab.txt", - "vocabulary_hash": "d9d865138d17f1958502ed060ecfeeb6", + "kaggle_handle": "gs://keras-nlp-kaggle/distil_bert_base_multi", }, } diff --git a/keras_nlp/models/f_net/f_net_presets.py b/keras_nlp/models/f_net/f_net_presets.py index b3df5f8e2c..48cc9827b4 100644 --- a/keras_nlp/models/f_net/f_net_presets.py +++ b/keras_nlp/models/f_net/f_net_presets.py @@ -25,20 +25,7 @@ "path": "f_net", "model_card": "https://github.com/google-research/google-research/blob/master/f_net/README.md", }, - "config": { - "vocabulary_size": 32000, - "num_layers": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 4, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/f_net_base_en/v1/model.h5", - "weights_hash": "35db90842b85a985a0e54c86c00746fe", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/f_net_base_en/v1/vocab.spm", - "spm_proto_hash": "71c5f4610bef1daf116998a113a01f3d", + "kaggle_handle": "gs://keras-nlp-kaggle/f_net_base_en", }, "f_net_large_en": { "metadata": { @@ -51,19 +38,6 @@ "path": "f_net", "model_card": "https://github.com/google-research/google-research/blob/master/f_net/README.md", }, - "config": { - "vocabulary_size": 32000, - "num_layers": 24, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 4, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/f_net_large_en/v1/model.h5", - "weights_hash": "7ae4a3faa67ff054f8cecffb5619f779", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/f_net_large_en/v1/vocab.spm", - "spm_proto_hash": "71c5f4610bef1daf116998a113a01f3d", + "kaggle_handle": "gs://keras-nlp-kaggle/f_net_large_en", }, } diff --git a/keras_nlp/models/gpt2/gpt2_presets.py b/keras_nlp/models/gpt2/gpt2_presets.py index 7101bdb104..e5e546a92a 100644 --- a/keras_nlp/models/gpt2/gpt2_presets.py +++ b/keras_nlp/models/gpt2/gpt2_presets.py @@ -26,22 +26,7 @@ "path": "gpt2", "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md", }, - "config": { - "vocabulary_size": 50257, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 1024, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_base_en/v1/model.h5", - "weights_hash": "f4ea6e1b214516dd7de452461ee6e16e", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_base_en/v1/vocab.json", - "vocabulary_hash": "dffec25a898b1f5e569bec4dffd7e5c0", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_base_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/gpt2_base_en", }, "gpt2_medium_en": { "metadata": { @@ -54,22 +39,7 @@ "path": "gpt2", "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md", }, - "config": { - "vocabulary_size": 50257, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "dropout": 0.1, - "max_sequence_length": 1024, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_medium_en/v1/model.h5", - "weights_hash": "580ff9b79c04fc90e6d6f47e975c5afe", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_medium_en/v1/vocab.json", - "vocabulary_hash": "dffec25a898b1f5e569bec4dffd7e5c0", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_medium_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/gpt2_medium_en", }, "gpt2_large_en": { "metadata": { @@ -82,22 +52,7 @@ "path": "gpt2", "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md", }, - "config": { - "vocabulary_size": 50257, - "num_layers": 36, - "num_heads": 20, - "hidden_dim": 1280, - "intermediate_dim": 5120, - "dropout": 0.1, - "max_sequence_length": 1024, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_large_en/v1/model.h5", - "weights_hash": "67957cb3dfc9e965960dabe068811e1a", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_large_en/v1/vocab.json", - "vocabulary_hash": "dffec25a898b1f5e569bec4dffd7e5c0", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_large_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/gpt2_large_en", }, "gpt2_extra_large_en": { "metadata": { @@ -110,22 +65,7 @@ "path": "gpt2", "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md", }, - "config": { - "vocabulary_size": 50257, - "num_layers": 48, - "num_heads": 25, - "hidden_dim": 1600, - "intermediate_dim": 6400, - "dropout": 0.1, - "max_sequence_length": 1024, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_extra_large_en/v1/model.h5", - "weights_hash": "d093c1ee0d9705d845c0190909aa2917", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_extra_large_en/v1/vocab.json", - "vocabulary_hash": "dffec25a898b1f5e569bec4dffd7e5c0", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_extra_large_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/gpt2_extra_large_en", }, "gpt2_base_en_cnn_dailymail": { "metadata": { @@ -137,21 +77,6 @@ "official_name": "GPT-2", "path": "gpt2", }, - "config": { - "vocabulary_size": 50257, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 1024, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_base_en_news/v1/model.h5", - "weights_hash": "09d86ca6e1b4213886b720a1392f2a70", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_base_en_news/v1/vocab.json", - "vocabulary_hash": "dffec25a898b1f5e569bec4dffd7e5c0", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/gpt2_base_en_news/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/gpt2_base_en_cnn_dailymail", }, } diff --git a/keras_nlp/models/opt/opt_presets.py b/keras_nlp/models/opt/opt_presets.py index 7af2641138..3ca0fd7b32 100644 --- a/keras_nlp/models/opt/opt_presets.py +++ b/keras_nlp/models/opt/opt_presets.py @@ -26,22 +26,7 @@ "path": "opt", "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md", }, - "config": { - "vocabulary_size": 50272, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 2048, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/opt_125m_en/v1/model.h5", - "weights_hash": "63e444998982e48da4a1a3970f4c6203", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/opt_125m_en/v1/vocab.json", - "vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/opt_125m_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/opt_125m_en", }, # We skip the 350m checkpoint because it does not match the structure of # other checkpoints. @@ -56,22 +41,7 @@ "path": "opt", "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md", }, - "config": { - "vocabulary_size": 50272, - "num_layers": 24, - "num_heads": 32, - "hidden_dim": 2048, - "intermediate_dim": 8192, - "dropout": 0.1, - "max_sequence_length": 2048, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/opt_1.3b_en/v1/model.h5", - "weights_hash": "0365ac8483e99a912c9770521909ecce", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/opt_1.3b_en/v1/vocab.json", - "vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/opt_1.3b_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/opt_1.3b_en", }, "opt_2.7b_en": { "metadata": { @@ -84,22 +54,7 @@ "path": "opt", "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md", }, - "config": { - "vocabulary_size": 50272, - "num_layers": 32, - "num_heads": 32, - "hidden_dim": 2560, - "intermediate_dim": 10240, - "dropout": 0.1, - "max_sequence_length": 2048, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/opt_2.7b_en/v1/model.h5", - "weights_hash": "af56da9206a95b9287356955c5bc14e7", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/opt_2.7b_en/v1/vocab.json", - "vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/opt_2.7b_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/opt_2.7b_en", }, "opt_6.7b_en": { "metadata": { @@ -112,21 +67,6 @@ "path": "opt", "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md", }, - "config": { - "vocabulary_size": 50272, - "num_layers": 32, - "num_heads": 32, - "hidden_dim": 4096, - "intermediate_dim": 16384, - "dropout": 0.1, - "max_sequence_length": 2048, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/opt_6.7b_en/v1/model.h5", - "weights_hash": "543120fbe601b70e6ec04cc909781e21", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/opt_6.7b_en/v1/vocab.json", - "vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/opt_6.7b_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/opt_6.7b_en", }, } diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index 5e54b2d7e3..16a65e57c2 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -64,18 +64,6 @@ def tokenizer_cls(cls): def presets(cls): return {} - @classmethod - def _legacy_from_preset( - cls, - preset, - **kwargs, - ): - tokenizer = cls.tokenizer_cls.from_preset(preset) - return cls( - tokenizer=tokenizer, - **kwargs, - ) - @classmethod def from_preset( cls, @@ -95,9 +83,10 @@ def from_preset( ) ``` """ - # TODO: delete me! + # We support short IDs for official presets, e.g. `"bert_base_en"`. + # Map these to a Kaggle Models handle. if preset in cls.presets: - return cls._legacy_from_preset(preset, **kwargs) + preset = cls.presets[preset]["kaggle_handle"] config_file = "tokenizer.json" check_preset_class(preset, cls.tokenizer_cls, config_file=config_file) diff --git a/keras_nlp/models/roberta/roberta_presets.py b/keras_nlp/models/roberta/roberta_presets.py index f098bed5d7..a57f7cf479 100644 --- a/keras_nlp/models/roberta/roberta_presets.py +++ b/keras_nlp/models/roberta/roberta_presets.py @@ -25,22 +25,7 @@ "path": "roberta", "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md", }, - "config": { - "vocabulary_size": 50265, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/roberta_base_en/v1/model.h5", - "weights_hash": "958eede1c7edaa9308e027be18fde7a8", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/roberta_base_en/v1/vocab.json", - "vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/roberta_base_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/roberta_base_en", }, "roberta_large_en": { "metadata": { @@ -53,21 +38,6 @@ "path": "roberta", "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md", }, - "config": { - "vocabulary_size": 50265, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "dropout": 0.1, - "max_sequence_length": 512, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/roberta_large_en/v1/model.h5", - "weights_hash": "1978b864c317a697fe62a894d3664f14", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/roberta_large_en/v1/vocab.json", - "vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/roberta_large_en/v1/merges.txt", - "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e", + "kaggle_handle": "gs://keras-nlp-kaggle/roberta_large_en", }, } diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index 314a5d68df..6e76094d71 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -222,6 +222,7 @@ def __init__( self.activation = keras.activations.get(activation) self.key_value_dim = key_value_dim self.dropout = dropout + self.use_gated_activation = use_gated_activation self.layer_norm_epsilon = layer_norm_epsilon self.tie_embedding_weights = tie_embedding_weights self.token_embedding = token_embedding_layer @@ -238,6 +239,7 @@ def get_config(self): "activation": keras.activations.serialize(self.activation), "key_value_dim": self.key_value_dim, "dropout": self.dropout, + "use_gated_activation": self.use_gated_activation, "layer_norm_epsilon": self.layer_norm_epsilon, "tie_embedding_weights": self.tie_embedding_weights, } diff --git a/keras_nlp/models/t5/t5_presets.py b/keras_nlp/models/t5/t5_presets.py index dd2bea7a4e..d5c502c5ba 100644 --- a/keras_nlp/models/t5/t5_presets.py +++ b/keras_nlp/models/t5/t5_presets.py @@ -25,24 +25,7 @@ "path": "t5", "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md", }, - "config": { - "vocabulary_size": 32128, - "num_layers": 6, - "num_heads": 8, - "hidden_dim": 512, - "intermediate_dim": 2048, - "key_value_dim": 64, - "dropout": 0.1, - "activation": "relu", - "use_gated_activation": False, - "layer_norm_epsilon": 1e-06, - "tie_embedding_weights": True, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_small_multi/v1/model.weights.h5", - "weights_hash": "2e10b5f72405d464ee55026b07e60741", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_small_multi/v1/vocab.spm", - "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3", + "kaggle_handle": "gs://keras-nlp-kaggle/t5_small_multi", }, "t5_base_multi": { "metadata": { @@ -55,23 +38,7 @@ "path": "t5", "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md", }, - "config": { - "vocabulary_size": 32128, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "activation": "relu", - "use_gated_activation": False, - "layer_norm_epsilon": 1e-06, - "tie_embedding_weights": True, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_base_multi/v1/model.weights.h5", - "weights_hash": "bed6ef276cfe83d1323467051211978d", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_base_multi/v1/vocab.spm", - "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3", + "kaggle_handle": "gs://keras-nlp-kaggle/t5_base_multi", }, "t5_large_multi": { "metadata": { @@ -84,23 +51,7 @@ "path": "t5", "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md", }, - "config": { - "vocabulary_size": 32128, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "dropout": 0.1, - "activation": "relu", - "use_gated_activation": False, - "layer_norm_epsilon": 1e-06, - "tie_embedding_weights": True, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_large_multi/v1/model.weights.h5", - "weights_hash": "7854a05c2e6812899bf6f0f104792cda", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_large_multi/v1/vocab.spm", - "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3", + "kaggle_handle": "gs://keras-nlp-kaggle/t5_large_multi", }, "flan_small_multi": { "metadata": { @@ -113,24 +64,7 @@ "path": "t5", "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md", }, - "config": { - "vocabulary_size": 32128, - "num_layers": 8, - "num_heads": 6, - "hidden_dim": 512, - "intermediate_dim": 1024, - "key_value_dim": 64, - "dropout": 0.1, - "activation": "keras_nlp>gelu_approximate", - "use_gated_activation": True, - "layer_norm_epsilon": 1e-06, - "tie_embedding_weights": False, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_small_multi/v1/model.weights.h5", - "weights_hash": "aa0fbaddb1759ef313bbc4f9e4f1e197", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_small_multi/v1/vocab.spm", - "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3", + "kaggle_handle": "gs://keras-nlp-kaggle/flan_small_multi", }, "flan_base_multi": { "metadata": { @@ -143,23 +77,7 @@ "path": "t5", "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md", }, - "config": { - "vocabulary_size": 32128, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 2048, - "dropout": 0.1, - "activation": "keras_nlp>gelu_approximate", - "use_gated_activation": True, - "layer_norm_epsilon": 1e-06, - "tie_embedding_weights": False, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_base_multi/v1/model.weights.h5", - "weights_hash": "84a10bec83fd093931bb2a6264115d31", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_base_multi/v1/vocab.spm", - "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3", + "kaggle_handle": "gs://keras-nlp-kaggle/flan_base_multi", }, "flan_large_multi": { "metadata": { @@ -172,22 +90,6 @@ "path": "t5", "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md", }, - "config": { - "vocabulary_size": 32128, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 2816, - "dropout": 0.1, - "activation": "keras_nlp>gelu_approximate", - "use_gated_activation": True, - "layer_norm_epsilon": 1e-06, - "tie_embedding_weights": False, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_large_multi/v1/model.weights.h5", - "weights_hash": "513f530ce790efa7e261c0ef965f3697", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_large_multi/v1/vocab.spm", - "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3", + "kaggle_handle": "gs://keras-nlp-kaggle/flan_large_multi", }, } diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 88f74b9a0d..97f06d0b1d 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - from rich import console as rich_console from rich import markup from rich import table as rich_table @@ -150,39 +148,6 @@ def preprocessor_cls(cls): def presets(cls): return {} - @classmethod - def _legacy_from_preset( - cls, - preset, - load_weights=True, - **kwargs, - ): - if "preprocessor" not in kwargs: - kwargs["preprocessor"] = cls.preprocessor_cls.from_preset(preset) - - # Check if preset is backbone-only model - if preset in cls.backbone_cls.presets: - backbone = cls.backbone_cls.from_preset(preset, load_weights) - return cls(backbone, **kwargs) - - # Otherwise must be one of class presets - metadata = cls.presets[preset] - config = metadata["config"] - model = cls.from_config({**config, **kwargs}) - - if not load_weights: - return model - - weights = keras.utils.get_file( - "model.h5", - metadata["weights_url"], - cache_subdir=os.path.join("models", preset), - file_hash=metadata["weights_hash"], - ) - - model.load_weights(weights) - return model - @classmethod def from_preset( cls, @@ -209,9 +174,10 @@ def from_preset( ) ``` """ - # TODO: delete me! + # We support short IDs for official presets, e.g. `"bert_base_en"`. + # Map these to a Kaggle Models handle. if preset in cls.presets: - return cls._legacy_from_preset(preset, load_weights, **kwargs) + preset = cls.presets[preset]["kaggle_handle"] preset_cls = check_preset_class(preset, (cls, cls.backbone_cls)) diff --git a/keras_nlp/models/whisper/whisper_presets.py b/keras_nlp/models/whisper/whisper_presets.py index 8ec5a7353d..81c10ce870 100644 --- a/keras_nlp/models/whisper/whisper_presets.py +++ b/keras_nlp/models/whisper/whisper_presets.py @@ -11,123 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -MULTILINGUAL_SPECIAL_TOKENS = { - "<|startoftranscript|>": 50258, - "<|endoftext|>": 50257, - "<|notimestamps|>": 50363, - "<|translate|>": 50359, - "<|transcribe|>": 50358, -} - -ENGLISH_SPECIAL_TOKENS = { - "<|startoftranscript|>": 50257, - "<|endoftext|>": 50256, - "<|notimestamps|>": 50362, - "<|translate|>": 50358, - "<|transcribe|>": 50357, -} - -LANGUAGE_TOKENS = { - "<|af|>": 50327, - "<|am|>": 50334, - "<|ar|>": 50272, - "<|as|>": 50350, - "<|az|>": 50304, - "<|ba|>": 50355, - "<|be|>": 50330, - "<|bg|>": 50292, - "<|bn|>": 50302, - "<|bo|>": 50347, - "<|br|>": 50309, - "<|bs|>": 50315, - "<|ca|>": 50270, - "<|cs|>": 50283, - "<|cy|>": 50297, - "<|da|>": 50285, - "<|de|>": 50261, - "<|el|>": 50281, - "<|en|>": 50259, - "<|es|>": 50262, - "<|et|>": 50307, - "<|eu|>": 50310, - "<|fa|>": 50300, - "<|fi|>": 50277, - "<|fo|>": 50338, - "<|fr|>": 50265, - "<|gl|>": 50319, - "<|gu|>": 50333, - "<|haw|>": 50352, - "<|ha|>": 50354, - "<|he|>": 50279, - "<|hi|>": 50276, - "<|hr|>": 50291, - "<|ht|>": 50339, - "<|hu|>": 50286, - "<|hy|>": 50312, - "<|id|>": 50275, - "<|is|>": 50311, - "<|it|>": 50274, - "<|ja|>": 50266, - "<|jw|>": 50356, - "<|ka|>": 50329, - "<|kk|>": 50316, - "<|km|>": 50323, - "<|kn|>": 50306, - "<|ko|>": 50264, - "<|la|>": 50294, - "<|lb|>": 50345, - "<|ln|>": 50353, - "<|lo|>": 50336, - "<|lt|>": 50293, - "<|lv|>": 50301, - "<|mg|>": 50349, - "<|mi|>": 50295, - "<|mk|>": 50308, - "<|ml|>": 50296, - "<|mn|>": 50314, - "<|mr|>": 50320, - "<|ms|>": 50282, - "<|mt|>": 50343, - "<|my|>": 50346, - "<|ne|>": 50313, - "<|nl|>": 50271, - "<|nn|>": 50342, - "<|no|>": 50288, - "<|oc|>": 50328, - "<|pa|>": 50321, - "<|pl|>": 50269, - "<|ps|>": 50340, - "<|pt|>": 50267, - "<|ro|>": 50284, - "<|ru|>": 50263, - "<|sa|>": 50344, - "<|sd|>": 50332, - "<|si|>": 50322, - "<|sk|>": 50298, - "<|sl|>": 50305, - "<|sn|>": 50324, - "<|so|>": 50326, - "<|sq|>": 50317, - "<|sr|>": 50303, - "<|su|>": 50357, - "<|sv|>": 50273, - "<|sw|>": 50318, - "<|ta|>": 50287, - "<|te|>": 50299, - "<|tg|>": 50331, - "<|th|>": 50289, - "<|tk|>": 50341, - "<|tl|>": 50348, - "<|tr|>": 50268, - "<|tt|>": 50351, - "<|uk|>": 50280, - "<|ur|>": 50290, - "<|uz|>": 50337, - "<|vi|>": 50278, - "<|yi|>": 50335, - "<|yo|>": 50325, - "<|zh|>": 50260, -} # Metadata for loading pretrained model weights. backbone_presets = { @@ -142,27 +25,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51864, - "num_layers": 4, - "num_heads": 6, - "hidden_dim": 384, - "intermediate_dim": 1536, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": ENGLISH_SPECIAL_TOKENS, - "language_tokens": None, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_tiny_en/v1/model.h5", - "weights_hash": "3dc3768ac48ec90b1029fbf52ffbacc7", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_tiny_en/v1/vocab.json", - "vocabulary_hash": "22377f841debacb023848b3468ea3281", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_tiny_en/v1/merges.txt", - "merges_hash": "093ecf3f30371012f2e96fcfb10ea6ab", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_tiny_en", }, "whisper_base_en": { "metadata": { @@ -175,27 +38,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51864, - "num_layers": 6, - "num_heads": 8, - "hidden_dim": 512, - "intermediate_dim": 2048, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": ENGLISH_SPECIAL_TOKENS, - "language_tokens": None, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_en/v1/model.h5", - "weights_hash": "799d3c143993d42f7446bafbc0f46d7d", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_en/v1/vocab.json", - "vocabulary_hash": "22377f841debacb023848b3468ea3281", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_en/v1/merges.txt", - "merges_hash": "093ecf3f30371012f2e96fcfb10ea6ab", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_base_en", }, "whisper_small_en": { "metadata": { @@ -208,27 +51,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51864, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": ENGLISH_SPECIAL_TOKENS, - "language_tokens": None, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_en/v1/model.h5", - "weights_hash": "b75a89225e20019d85ff5f1c362f8a49", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_en/v1/vocab.json", - "vocabulary_hash": "22377f841debacb023848b3468ea3281", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_en/v1/merges.txt", - "merges_hash": "093ecf3f30371012f2e96fcfb10ea6ab", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_small_en", }, "whisper_medium_en": { "metadata": { @@ -241,27 +64,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51864, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": ENGLISH_SPECIAL_TOKENS, - "language_tokens": None, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_medium_en/v1/model.h5", - "weights_hash": "107184882d1cc65926815e4cc50dc5f3", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_medium_en/v1/vocab.json", - "vocabulary_hash": "22377f841debacb023848b3468ea3281", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_medium_en/v1/merges.txt", - "merges_hash": "093ecf3f30371012f2e96fcfb10ea6ab", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_medium_en", }, "whisper_tiny_multi": { "metadata": { @@ -274,27 +77,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51865, - "num_layers": 4, - "num_heads": 6, - "hidden_dim": 384, - "intermediate_dim": 1536, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": MULTILINGUAL_SPECIAL_TOKENS, - "language_tokens": LANGUAGE_TOKENS, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_tiny_multi/v1/model.h5", - "weights_hash": "b1279a81001ad5eb35970d1aea706396", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_tiny_multi/v1/vocab.json", - "vocabulary_hash": "1b87ed3e3ecd9ccfdca74e64cbe81d68", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_tiny_multi/v1/merges.txt", - "merges_hash": "c7f01d4100f6211417988889bf35ccd8", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_tiny_multi", }, "whisper_base_multi": { "metadata": { @@ -307,27 +90,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51865, - "num_layers": 6, - "num_heads": 8, - "hidden_dim": 512, - "intermediate_dim": 2048, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": MULTILINGUAL_SPECIAL_TOKENS, - "language_tokens": LANGUAGE_TOKENS, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_multi/v1/model.h5", - "weights_hash": "5208396e2d5efac43114a4a3d4f583ab", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_multi/v1/vocab.json", - "vocabulary_hash": "1b87ed3e3ecd9ccfdca74e64cbe81d68", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_multi/v1/merges.txt", - "merges_hash": "c7f01d4100f6211417988889bf35ccd8", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_base_multi", }, "whisper_small_multi": { "metadata": { @@ -340,27 +103,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51865, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": MULTILINGUAL_SPECIAL_TOKENS, - "language_tokens": LANGUAGE_TOKENS, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_multi/v1/model.h5", - "weights_hash": "c90c6a895e522056b77b924b6e907ed8", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_multi/v1/vocab.json", - "vocabulary_hash": "1b87ed3e3ecd9ccfdca74e64cbe81d68", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_base_multi/v1/merges.txt", - "merges_hash": "c7f01d4100f6211417988889bf35ccd8", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_small_multi", }, "whisper_medium_multi": { "metadata": { @@ -373,27 +116,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51865, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": MULTILINGUAL_SPECIAL_TOKENS, - "language_tokens": LANGUAGE_TOKENS, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_medium_multi/v1/model.h5", - "weights_hash": "6f993f732fe397e9c5e3a96a9505a3a9", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_medium_multi/v1/vocab.json", - "vocabulary_hash": "1b87ed3e3ecd9ccfdca74e64cbe81d68", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_medium_multi/v1/merges.txt", - "merges_hash": "c7f01d4100f6211417988889bf35ccd8", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_medium_multi", }, "whisper_large_multi": { "metadata": { @@ -406,27 +129,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51865, - "num_layers": 32, - "num_heads": 20, - "hidden_dim": 1280, - "intermediate_dim": 5120, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": MULTILINGUAL_SPECIAL_TOKENS, - "language_tokens": LANGUAGE_TOKENS, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_large_multi/v1/model.h5", - "weights_hash": "ccab1c93c5739007868ae73fe025806d", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_large_multi/v1/vocab.json", - "vocabulary_hash": "1b87ed3e3ecd9ccfdca74e64cbe81d68", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_large_multi/v1/merges.txt", - "merges_hash": "c7f01d4100f6211417988889bf35ccd8", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_large_multi", }, "whisper_large_multi_v2": { "metadata": { @@ -440,26 +143,6 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "config": { - "vocabulary_size": 51865, - "num_layers": 32, - "num_heads": 20, - "hidden_dim": 1280, - "intermediate_dim": 5120, - "num_mels": 80, - "dropout": 0.0, - "max_encoder_sequence_length": 3000, - "max_decoder_sequence_length": 448, - }, - "preprocessor_config": { - "special_tokens": MULTILINGUAL_SPECIAL_TOKENS, - "language_tokens": LANGUAGE_TOKENS, - }, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/whisper_large_multi_v2/v1/model.h5", - "weights_hash": "ca157162ec9c3329a659388528a3af88", - "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/whisper_large_multi_v2/v1/vocab.json", - "vocabulary_hash": "1b87ed3e3ecd9ccfdca74e64cbe81d68", - "merges_url": "https://storage.googleapis.com/keras-nlp/models/whisper_large_multi_v2/v1/merges.txt", - "merges_hash": "c7f01d4100f6211417988889bf35ccd8", + "kaggle_handle": "gs://keras-nlp-kaggle/whisper_large_multi_v2", }, } diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_presets.py b/keras_nlp/models/xlm_roberta/xlm_roberta_presets.py index 350c069f1d..5b7a571e48 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_presets.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_presets.py @@ -25,20 +25,7 @@ "path": "xlm_roberta", "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md", }, - "config": { - "vocabulary_size": 250002, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/xlm_roberta_base_multi/v1/model.h5", - "weights_hash": "2eb6fcda5a42f0a88056213ba3d93906", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/xlm_roberta_base_multi/v1/vocab.spm", - "spm_proto_hash": "bf25eb5120ad92ef5c7d8596b5dc4046", + "kaggle_handle": "gs://keras-nlp-kaggle/xlm_roberta_base_multi", }, "xlm_roberta_large_multi": { "metadata": { @@ -51,19 +38,6 @@ "path": "xlm_roberta", "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md", }, - "config": { - "vocabulary_size": 250002, - "num_layers": 24, - "num_heads": 16, - "hidden_dim": 1024, - "intermediate_dim": 4096, - "dropout": 0.1, - "max_sequence_length": 512, - }, - "preprocessor_config": {}, - "weights_url": "https://storage.googleapis.com/keras-nlp/models/xlm_roberta_large_multi/v1/model.h5", - "weights_hash": "276211827174b71751f2ce3a89da503a", - "spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/xlm_roberta_large_multi/v1/vocab.spm", - "spm_proto_hash": "bf25eb5120ad92ef5c7d8596b5dc4046", + "kaggle_handle": "gs://keras-nlp-kaggle/xlm_roberta_large_multi", }, } diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 8c90af8558..2471df051c 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -28,7 +28,6 @@ import tensorflow as tf from keras_nlp.api_export import keras_nlp_export -from keras_nlp.backend import keras from keras_nlp.tokenizers import tokenizer from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset @@ -642,37 +641,6 @@ def get_config(self): def presets(cls): return {} - @classmethod - def _legacy_from_preset( - cls, - preset, - **kwargs, - ): - metadata = cls.presets[preset] - - vocabulary = keras.utils.get_file( - "vocab.txt", - metadata["vocabulary_url"], - cache_subdir=os.path.join("models", preset), - file_hash=metadata["vocabulary_hash"], - ) - merges = keras.utils.get_file( - "merges.txt", - metadata["merges_url"], - cache_subdir=os.path.join("models", preset), - file_hash=metadata["merges_hash"], - ) - - config = metadata["preprocessor_config"] - config.update( - { - "vocabulary": vocabulary, - "merges": merges, - }, - ) - - return cls.from_config({**config, **kwargs}) - @classmethod def from_preset( cls, @@ -696,9 +664,10 @@ def from_preset( tokenizer.detokenize([5, 6, 7, 8, 9]) ``` """ - # TODO: delete me! + # We support short IDs for official presets, e.g. `"bert_base_en"`. + # Map these to a Kaggle Models handle. if preset in cls.presets: - return cls._legacy_from_preset(preset, **kwargs) + preset = cls.presets[preset]["kaggle_handle"] config_file = "tokenizer.json" check_preset_class(preset, cls, config_file=config_file) diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py index eb6abb8140..ae655aceb6 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py @@ -20,7 +20,6 @@ import tensorflow as tf from keras_nlp.api_export import keras_nlp_export -from keras_nlp.backend import keras from keras_nlp.tokenizers import tokenizer from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset @@ -263,30 +262,6 @@ def detokenize(self, inputs): def presets(cls): return {} - @classmethod - def _legacy_from_preset( - cls, - preset, - **kwargs, - ): - metadata = cls.presets[preset] - - spm_proto = keras.utils.get_file( - "vocab.spm", - metadata["spm_proto_url"], - cache_subdir=os.path.join("models", preset), - file_hash=metadata["spm_proto_hash"], - ) - - config = metadata["preprocessor_config"] - config.update( - { - "proto": spm_proto, - }, - ) - - return cls.from_config({**config, **kwargs}) - @classmethod def from_preset( cls, @@ -310,9 +285,10 @@ def from_preset( tokenizer.detokenize([5, 6, 7, 8, 9]) ``` """ - + # We support short IDs for official presets, e.g. `"bert_base_en"`. + # Map these to a Kaggle Models handle. if preset in cls.presets: - return cls._legacy_from_preset(preset, **kwargs) + preset = cls.presets[preset]["kaggle_handle"] config_file = "tokenizer.json" check_preset_class(preset, cls, config_file=config_file) diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py index 6d1fa8e7f1..4e7b05b230 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer.py @@ -19,7 +19,6 @@ import tensorflow as tf from keras_nlp.api_export import keras_nlp_export -from keras_nlp.backend import keras from keras_nlp.tokenizers import tokenizer from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset @@ -470,30 +469,6 @@ def detokenize(self, inputs): def presets(cls): return {} - @classmethod - def _legacy_from_preset( - cls, - preset, - **kwargs, - ): - metadata = cls.presets[preset] - - vocabulary = keras.utils.get_file( - "vocab.txt", - metadata["vocabulary_url"], - cache_subdir=os.path.join("models", preset), - file_hash=metadata["vocabulary_hash"], - ) - - config = metadata["preprocessor_config"] - config.update( - { - "vocabulary": vocabulary, - }, - ) - - return cls.from_config({**config, **kwargs}) - @classmethod def from_preset( cls, @@ -517,9 +492,10 @@ def from_preset( tokenizer.detokenize([5, 6, 7, 8, 9]) ``` """ - # TODO: delete me! + # We support short IDs for official presets, e.g. `"bert_base_en"`. + # Map these to a Kaggle Models handle. if preset in cls.presets: - return cls._legacy_from_preset(preset, **kwargs) + preset = cls.presets[preset]["kaggle_handle"] config_file = "tokenizer.json" check_preset_class(preset, cls, config_file=config_file) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index 04ca3a39cd..f2234f615d 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -24,6 +24,7 @@ kagglehub = None KAGGLE_PREFIX = "kaggle://" +GS_PREFIX = "gs://" TOKENIZER_ASSET_DIR = "assets/tokenizer" @@ -51,7 +52,21 @@ def get_file(preset, path): f"Received: preset={preset}" ) return kagglehub.model_download(kaggle_handle, path) - return os.path.join(preset, path) + elif preset.startswith(GS_PREFIX): + url = os.path.join(preset, path) + url = url.replace(GS_PREFIX, "https://storage.googleapis.com/") + subdir = preset.replace(GS_PREFIX, "gs_") + subdir = subdir.replace("/", "_").replace("-", "_") + filename = os.path.basename(path) + subdir = os.path.join(subdir, os.path.dirname(path)) + return keras.utils.get_file( + filename, + url, + cache_subdir=os.path.join("models", subdir), + ) + else: + # Assume a local filepath. + return os.path.join(preset, path) def get_tokenizer(layer):