diff --git a/README.md b/README.md
index 2ff9bbb..085119c 100644
--- a/README.md
+++ b/README.md
@@ -97,7 +97,7 @@ model.summary()
 
 ## Load other pretrained models
 
-If you want to load pretraine models using other implementationds, whose config and trainable weights are a little different with previous, you can subclass `AbstractAdapter` to adapte these models:
+If you want to load models pretrained by other implementations, whose config and trainable weights differ slightly from those expected by the default adapters, you can subclass `AbstractAdapter` to adapt these models:
 
 ```python
 from transformers_keras.adapters import AbstractAdapter
@@ -106,18 +106,30 @@ from transformers_keras import Bert, Albert
 
 # load custom bert models
 class MyBertAdapter(AbstractAdapter):
-    def adapte(self, pretrained_model_dir, **kwargs):
-        # you can refer to `transformers_keras.adapters.bert_adapter`
-        pass
+    def adapte_config(self, config_file, **kwargs):
+        # adapt the model config here
+        # you can refer to `transformers_keras.adapters.bert_adapter`
+        pass
+
+    def adapte_weights(self, model, config, ckpt, **kwargs):
+        # adapt the model weights here
+        # you can refer to `transformers_keras.adapters.bert_adapter`
+        pass
 
 bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())
 
 # or, load custom albert models
 class MyAlbertAdapter(AbstractAdapter):
-    def adapte(self, pretrained_model_dir, **kwargs):
-        # you can refer to `transformers_keras.adapters.albert_adapter`
-        pass
+    def adapte_config(self, config_file, **kwargs):
+        # adapt the model config here
+        # you can refer to `transformers_keras.adapters.albert_adapter`
+        pass
+
+    def adapte_weights(self, model, config, ckpt, **kwargs):
+        # adapt the model weights here
+        # you can refer to `transformers_keras.adapters.albert_adapter`
+        pass
 
 albert = Albert.from_pretrained('/path/to/your/albert/model', adapter=MyAlbertAdapter())
 ```
diff --git a/transformers_keras/adapters/abstract_adapter.py b/transformers_keras/adapters/abstract_adapter.py
index 2194a21..2812256 100644
--- a/transformers_keras/adapters/abstract_adapter.py
+++ b/transformers_keras/adapters/abstract_adapter.py
@@ -7,8 +7,6 @@ def zip_weights(model, ckpt, variables_mapping, verbose=True):
     weights, values, names = [], [], []
-    # for k, v in variables_mapping.items():
-    #     print('{} \t -> {}'.format(k, v))
     for w in model.trainable_weights:
         names.append(w.name)
         weights.append(w)
@@ -31,26 +29,29 @@
     return mapped_values
 
 
-def parse_pretrained_model_files(pretrain_model_dir):
+def parse_pretrained_model_files(pretrained_model_dir):
     config_file, ckpt, vocab = None, None, None
-    pretrain_model_dir = os.path.abspath(pretrain_model_dir)
-    if not os.path.exists(pretrain_model_dir):
-        logging.info('pretrain model dir: {} is not exists.'.format(pretrain_model_dir))
+    pretrained_model_dir = os.path.abspath(pretrained_model_dir)
+    if not os.path.exists(pretrained_model_dir):
+        logging.info('pretrained model dir: {} does not exist.'.format(pretrained_model_dir))
         return config_file, ckpt, vocab
-    for f in os.listdir(pretrain_model_dir):
+    for f in os.listdir(pretrained_model_dir):
         if str(f).endswith('config.json'):
-            config_file = os.path.join(pretrain_model_dir, f)
+            config_file = os.path.join(pretrained_model_dir, f)
         if 'vocab' in str(f):
-            vocab = os.path.join(pretrain_model_dir, f)
+            vocab = os.path.join(pretrained_model_dir, f)
         if 'ckpt' in str(f):
             n = '.'.join(str(f).split('.')[:-1])
-            ckpt = os.path.join(pretrain_model_dir, n)
+            ckpt = os.path.join(pretrained_model_dir, n)
     return config_file, ckpt, vocab
 
 
-
 class AbstractAdapter(abc.ABC):
 
-    def adapte(self, pretrained_model_dir, **kwargs):
+    @abc.abstractmethod
+    def adapte_config(self, config_file, **kwargs):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def adapte_weights(self, model, config, ckpt, **kwargs):
+        raise NotImplementedError()
diff --git a/transformers_keras/adapters/albert_adapter.py b/transformers_keras/adapters/albert_adapter.py
index c7928ab..5e58398 100644
--- a/transformers_keras/adapters/albert_adapter.py
+++ b/transformers_keras/adapters/albert_adapter.py
@@ -4,20 +4,12 @@
 
 import tensorflow as tf
 
-from .abstract_adapter import AbstractAdapter, parse_pretrained_model_files
+from .abstract_adapter import AbstractAdapter, zip_weights
 
 
 class AlbertAdapter(AbstractAdapter):
-    def adapte(self, pretrained_model_dir, **kwargs):
-        config_file, ckpt, vocab_file = parse_pretrained_model_files(pretrained_model_dir)
-        model_config = self._adapte_config(config_file)
-        name_mapping = self._adapte_variables(
-            num_groups=model_config['num_groups'],
-            num_layers_each_group=model_config['num_layers_each_group'])
-        return model_config, name_mapping, ckpt, vocab_file
-
-    def _adapte_config(self, config_file):
+    def adapte_config(self, config_file, **kwargs):
         with open(config_file, mode='rt', encoding='utf8') as fin:
             config = json.load(fin)
@@ -39,7 +31,19 @@ def _adapte_config(self, config_file):
         }
         return model_config
 
-    def _adapte_variables(self, num_groups, num_layers_each_group):
+    def adapte_weights(self, model, config, ckpt, **kwargs):
+        # map weight names
+        weights_mapping = self._mapping_weight_names(config['num_groups'], config['num_layers_each_group'])
+        # zip weights and their values
+        zipped_weights = zip_weights(
+            model,
+            ckpt,
+            weights_mapping,
+            verbose=kwargs.get('verbose', True))
+        # set values to weights
+        tf.keras.backend.batch_set_value(zipped_weights)
+
+    def _mapping_weight_names(self, num_groups, num_layers_each_group):
         mapping = {}
 
         # embedding
@@ -98,4 +102,3 @@ def _adapte_variables(self, num_groups, num_layers_each_group):
             mapping[k] = v
 
         return mapping
-
diff --git a/transformers_keras/adapters/bert_adapter.py b/transformers_keras/adapters/bert_adapter.py
index 74572c6..2602ec1 100644
--- a/transformers_keras/adapters/bert_adapter.py
+++ b/transformers_keras/adapters/bert_adapter.py
@@ -1,24 +1,15 @@
-import os
 import json
 import logging
+import os
 
 import tensorflow as tf
 
-from .abstract_adapter import AbstractAdapter
-from .abstract_adapter import parse_pretrained_model_files
+from .abstract_adapter import AbstractAdapter, zip_weights
 
 
 class BertAdapter(AbstractAdapter):
-    """Adapte pretrained models to newly build BERT model."""
-
-    def adapte(self, pretrained_model_dir, **kwargs):
-        config_file, ckpt, vocab_file = parse_pretrained_model_files(pretrained_model_dir)
-        model_config = self.adapte_config(config_file)
-        variables_mapping = self.adapte_variables(model_config['num_layers'])
-        return model_config, variables_mapping, ckpt, vocab_file
-
-    def adapte_config(self, config_file):
+    def adapte_config(self, config_file, **kwargs):
         with open(config_file, mode='rt', encoding='utf8') as fin:
             config = json.load(fin)
@@ -37,10 +28,22 @@ def adapte_config(self, config_file):
         }
         return model_config
 
-    def adapte_variables(self, num_layers=12):
+    def adapte_weights(self, model, config, ckpt, **kwargs):
+        # map weight names
+        weights_mapping = self._mapping_weight_names(config['num_layers'])
+        # zip weight names and values
+        zipped_weights = zip_weights(
+            model,
+            ckpt,
+            weights_mapping,
+            verbose=kwargs.get('verbose', True))
+        # set values to weights
+        tf.keras.backend.batch_set_value(zipped_weights)
+
+    def _mapping_weight_names(self, num_layers=12):
         mapping = {}
-
-        # embedding
+
+        # embedding
         mapping.update({
             'bert/embedding/weight:0': 'bert/embeddings/word_embeddings',
             'bert/embedding/token_type_embedding/embeddings:0': 'bert/embeddings/token_type_embeddings',
@@ -49,7 +52,7 @@
             'bert/embedding/layer_norm/beta:0': 'bert/embeddings/LayerNorm/beta',
         })
 
-        # encoder
+        # encoder
         model_prefix = 'bert/encoder/layer_{}'
         for i in range(num_layers):
             encoder_prefix = 'bert/encoder/layer_{}/'.format(i)
diff --git a/transformers_keras/modeling_albert.py b/transformers_keras/modeling_albert.py
index 1525401..248c2d8 100644
--- a/transformers_keras/modeling_albert.py
+++ b/transformers_keras/modeling_albert.py
@@ -4,7 +4,7 @@
 
 import tensorflow as tf
 
-from transformers_keras.adapters.abstract_adapter import zip_weights
+from transformers_keras.adapters import parse_pretrained_model_files
 from transformers_keras.adapters.albert_adapter import AlbertAdapter
 
 from .layers import MultiHeadAttention
@@ -341,13 +341,13 @@ def dummy_inputs(self):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_dir, adapter=None, verbose=True, **kwargs):
+        config_file, ckpt, vocab_file = parse_pretrained_model_files(pretrained_model_dir)
         if not adapter:
             adapter = AlbertAdapter()
-        model_config, name_mapping, ckpt, vocab_file = adapter.adapte(pretrained_model_dir)
+        model_config = adapter.adapte_config(config_file, **kwargs)
         model = cls(**model_config)
         model(model.dummy_inputs())
-        weights_values = zip_weights(model, ckpt, name_mapping, verbose=verbose)
-        tf.keras.backend.batch_set_value(weights_values)
+        adapter.adapte_weights(model, model_config, ckpt, **kwargs)
         return model
diff --git a/transformers_keras/modeling_bert.py b/transformers_keras/modeling_bert.py
index f0649e1..fb1c84d 100644
--- a/transformers_keras/modeling_bert.py
+++ b/transformers_keras/modeling_bert.py
@@ -4,7 +4,7 @@
 
 import tensorflow as tf
 
-from transformers_keras.adapters.abstract_adapter import zip_weights
+from transformers_keras.adapters import parse_pretrained_model_files
 from transformers_keras.adapters.bert_adapter import BertAdapter
 
 from .layers import MultiHeadAttention
@@ -238,13 +238,13 @@ def dummy_inputs(self):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_dir, adapter=None, verbose=True, **kwargs):
+        config_file, ckpt, vocab_file = parse_pretrained_model_files(pretrained_model_dir)
         if not adapter:
             adapter = BertAdapter()
-        model_config, name_mapping, ckpt, vocab_file = adapter.adapte(pretrained_model_dir)
+        model_config = adapter.adapte_config(config_file, **kwargs)
         model = cls(**model_config)
         model(model.dummy_inputs())
-        weights_values = zip_weights(model, ckpt, name_mapping, verbose=verbose)
-        tf.keras.backend.batch_set_value(weights_values)
+        adapter.adapte_weights(model, model_config, ckpt, **kwargs)
         return model
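For reference, here is a minimal sketch of a concrete adapter under the refactored two-method API introduced by this patch. The config keys read from `config_file` and the single entry in `weights_mapping` are illustrative placeholders, not values taken from the patch; a real adapter must map every trainable weight of the model to the matching variable name in the checkpoint (see `transformers_keras.adapters.bert_adapter` for a complete mapping).

```python
import json

import tensorflow as tf

from transformers_keras import Bert
from transformers_keras.adapters import AbstractAdapter
from transformers_keras.adapters.abstract_adapter import zip_weights


class MyBertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # translate the pretrained model's config file into the kwargs of the
        # model constructor; the keys below are assumptions based on a typical
        # BERT config, adjust them to your checkpoint and model signature
        with open(config_file, mode='rt', encoding='utf8') as fin:
            config = json.load(fin)
        return {
            'vocab_size': config['vocab_size'],
            'num_layers': config['num_hidden_layers'],
            'hidden_size': config['hidden_size'],
            'num_attention_heads': config['num_attention_heads'],
        }

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # map each model weight name to its variable name in the checkpoint;
        # only one placeholder entry is shown here
        weights_mapping = {
            'bert/embedding/weight:0': 'bert/embeddings/word_embeddings',
        }
        # zip the model weights with the checkpoint values, then assign them
        zipped_weights = zip_weights(
            model, ckpt, weights_mapping, verbose=kwargs.get('verbose', True))
        tf.keras.backend.batch_set_value(zipped_weights)


bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())
```

Splitting config adaptation from weight adaptation lets `from_pretrained` build the model from the `adapte_config` output first, run the dummy inputs to create the variables, and only then stream the checkpoint values in through `adapte_weights`.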