Refactoring adapters
luozhouyang committed Dec 23, 2020
1 parent 9c02e11 commit 755531b
Showing 6 changed files with 74 additions and 55 deletions.
26 changes: 19 additions & 7 deletions README.md
@@ -97,7 +97,7 @@ model.summary()

## Load other pretrained models

If you want to load pretraine models using other implementationds, whose config and trainable weights are a little different with previous, you can subclass `AbstractAdapter` to adapte these models:
If you want to load models pretrained by other implementations, whose config and trainable weights are slightly different from the defaults, you can subclass `AbstractAdapter` to adapt these models:

```python
from transformers_keras.adapters import AbstractAdapter
@@ -106,18 +106,30 @@ from transformers_keras import Bert, Albert
# load custom bert models
class MyBertAdapter(AbstractAdapter):

def adapte(self, pretrained_model_dir, **kwargs):
# you can refer to `transformers_keras.adapters.bert_adapter`
pass
def adapte_config(self, config_file, **kwargs):
# adapt the model config here
# you can refer to `transformers_keras.adapters.bert_adapter`
pass

def adapte_weights(self, model, config, ckpt, **kwargs):
# adapt the model weights here
# you can refer to `transformers_keras.adapters.bert_adapter`
pass

bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())

# or, load custom albert models
class MyAlbertAdapter(AbstractAdapter):

def adapte(self, pretrained_model_dir, **kwargs):
# you can refer to `transformers_keras.adapters.albert_adapter`
pass
def adapte_config(self, config_file, **kwargs):
# adapt the model config here
# you can refer to `transformers_keras.adapters.albert_adapter`
pass

def adapte_weights(self, model, config, ckpt, **kwargs):
# adapt the model weights here
# you can refer to `transformers_keras.adapters.albert_adapter`
pass

albert = Albert.from_pretrained('/path/to/your/albert/model', adapter=MyAlbertAdapter())
```
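For reference, a fuller sketch of such an adapter is shown below. It assumes a BERT-style `config.json` and a TensorFlow checkpoint inside the pretrained model directory; the config keys and the single weight-name mapping entry are only illustrative (the mapping entry follows the built-in `BertAdapter`), and the dict returned by `adapte_config` must match the keyword arguments of your model's constructor.

```python
import json

import tensorflow as tf

from transformers_keras import Bert
from transformers_keras.adapters import AbstractAdapter


class MyBertAdapter(AbstractAdapter):
    """A minimal sketch of a custom adapter; adjust names to your checkpoint."""

    def adapte_config(self, config_file, **kwargs):
        with open(config_file, mode='rt', encoding='utf8') as fin:
            config = json.load(fin)
        # Return the keyword arguments used to build the Keras model.
        # The keys below are illustrative and must match Bert's constructor.
        return {
            'vocab_size': config['vocab_size'],
            'hidden_size': config['hidden_size'],
            'num_layers': config['num_hidden_layers'],
        }

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # Map each trainable weight of the Keras model to a checkpoint variable.
        # Only one example entry is shown; add one entry per trainable weight.
        mapping = {
            'bert/embedding/weight:0': 'bert/embeddings/word_embeddings',
        }
        zipped = []
        for w in model.trainable_weights:
            ckpt_name = mapping.get(w.name)
            if ckpt_name is None:
                continue
            zipped.append((w, tf.train.load_variable(ckpt, ckpt_name)))
        tf.keras.backend.batch_set_value(zipped)


bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())
```

The built-in `BertAdapter` and `AlbertAdapter` follow the same two-step pattern, using the `zip_weights` helper to pair model weights with checkpoint values.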
25 changes: 13 additions & 12 deletions transformers_keras/adapters/abstract_adapter.py
@@ -7,8 +7,6 @@

def zip_weights(model, ckpt, variables_mapping, verbose=True):
weights, values, names = [], [], []
# for k, v in variables_mapping.items():
# print('{} \t -> {}'.format(k, v))
for w in model.trainable_weights:
names.append(w.name)
weights.append(w)
@@ -31,26 +29,29 @@ def zip_weights(model, ckpt, variables_mapping, verbose=True):
return mapped_values


def parse_pretrained_model_files(pretrain_model_dir):
def parse_pretrained_model_files(pretrained_model_dir):
config_file, ckpt, vocab = None, None, None
pretrain_model_dir = os.path.abspath(pretrain_model_dir)
if not os.path.exists(pretrain_model_dir):
logging.info('pretrain model dir: {} is not exists.'.format(pretrain_model_dir))
pretrained_model_dir = os.path.abspath(pretrained_model_dir)
if not os.path.exists(pretrained_model_dir):
logging.info('pretrained model dir: {} does not exist.'.format(pretrained_model_dir))
return config_file, ckpt, vocab
for f in os.listdir(pretrain_model_dir):
for f in os.listdir(pretrained_model_dir):
if str(f).endswith('config.json'):
config_file = os.path.join(pretrain_model_dir, f)
config_file = os.path.join(pretrained_model_dir, f)
if 'vocab' in str(f):
vocab = os.path.join(pretrain_model_dir, f)
vocab = os.path.join(pretrained_model_dir, f)
if 'ckpt' in str(f):
n = '.'.join(str(f).split('.')[:-1])
ckpt = os.path.join(pretrain_model_dir, n)
ckpt = os.path.join(pretrained_model_dir, n)
return config_file, ckpt, vocab



class AbstractAdapter(abc.ABC):

def adapte(self, pretrained_model_dir, **kwargs):
@abc.abstractmethod
def adapte_config(self, config_file, **kwargs):
raise NotImplementedError()

@abc.abstractmethod
def adapte_weights(self, model, config, ckpt, **kwargs):
raise NotImplementedError()
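For context, the two module-level helpers are typically combined as in the sketch below, which mirrors what the refactored `BertAdapter.adapte_weights` does. `load_weights_from_checkpoint` is a hypothetical helper written only for illustration and is not part of the library.

```python
import tensorflow as tf

from transformers_keras.adapters.abstract_adapter import (
    parse_pretrained_model_files,
    zip_weights,
)


def load_weights_from_checkpoint(model, pretrained_model_dir, variables_mapping):
    """Copy checkpoint values into an already-built Keras model.

    `model` must have been called once (e.g. on dummy inputs) so that its
    weights exist; `variables_mapping` maps Keras weight names such as
    'bert/embedding/weight:0' to checkpoint variable names.
    """
    # Locate config.json, the checkpoint prefix and the vocab file.
    config_file, ckpt, vocab_file = parse_pretrained_model_files(pretrained_model_dir)
    # Pair each mapped model weight with the value stored in the checkpoint...
    zipped_weights = zip_weights(model, ckpt, variables_mapping, verbose=True)
    # ...and copy all the values into the model in one call.
    tf.keras.backend.batch_set_value(zipped_weights)
```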
27 changes: 15 additions & 12 deletions transformers_keras/adapters/albert_adapter.py
@@ -4,20 +4,12 @@

import tensorflow as tf

from .abstract_adapter import AbstractAdapter, parse_pretrained_model_files
from .abstract_adapter import AbstractAdapter, zip_weights


class AlbertAdapter(AbstractAdapter):

def adapte(self, pretrained_model_dir, **kwargs):
config_file, ckpt, vocab_file = parse_pretrained_model_files(pretrained_model_dir)
model_config = self._adapte_config(config_file)
name_mapping = self._adapte_variables(
num_groups=model_config['num_groups'],
num_layers_each_group=model_config['num_layers_each_group'])
return model_config, name_mapping, ckpt, vocab_file

def _adapte_config(self, config_file):
def adapte_config(self, config_file, **kwargs):
with open(config_file, mode='rt', encoding='utf8') as fin:
config = json.load(fin)

@@ -39,7 +31,19 @@ def _adapte_config(self, config_file):
}
return model_config

def _adapte_variables(self, num_groups, num_layers_each_group):
def adapte_weights(self, model, config, ckpt, **kwargs):
# mapping weight names
weights_mapping = self._mapping_weight_names(config['num_groups'], config['num_layers_each_group'])
# zip weights and their values
zipped_weights = zip_weights(
model,
ckpt,
weights_mapping,
verbose=kwargs.get('verbose', True))
# set values to weights
tf.keras.backend.batch_set_value(zipped_weights)

def _mapping_weight_names(self, num_groups, num_layers_each_group):
mapping = {}

# embedding
@@ -98,4 +102,3 @@ def _adapte_variables(self, num_groups, num_layers_each_group):
mapping[k] = v

return mapping

35 changes: 19 additions & 16 deletions transformers_keras/adapters/bert_adapter.py
@@ -1,24 +1,15 @@
import os
import json
import logging
import os

import tensorflow as tf

from .abstract_adapter import AbstractAdapter
from .abstract_adapter import parse_pretrained_model_files
from .abstract_adapter import AbstractAdapter, zip_weights


class BertAdapter(AbstractAdapter):
"""Adapte pretrained models to newly build BERT model."""

def adapte(self, pretrained_model_dir, **kwargs):
config_file, ckpt, vocab_file = parse_pretrained_model_files(pretrained_model_dir)
model_config = self.adapte_config(config_file)
variables_mapping = self.adapte_variables(model_config['num_layers'])
return model_config, variables_mapping, ckpt, vocab_file


def adapte_config(self, config_file):
def adapte_config(self, config_file, **kwargs):
with open(config_file, mode='rt', encoding='utf8') as fin:
config = json.load(fin)

@@ -37,10 +28,22 @@ def adapte_config(self, config_file):
}
return model_config

def adapte_variables(self, num_layers=12):
def adapte_weights(self, model, config, ckpt, **kwargs):
# mapping weight names
weights_mapping = self._mapping_weight_names(config['num_layers'])
# zip weight names and values
zipped_weights = zip_weights(
model,
ckpt,
weights_mapping,
verbose=kwargs.get('verbose', True))
# set values to weights
tf.keras.backend.batch_set_value(zipped_weights)

def _mapping_weight_names(self, num_layers=12):
mapping = {}
# embedding

# embedding
mapping.update({
'bert/embedding/weight:0': 'bert/embeddings/word_embeddings',
'bert/embedding/token_type_embedding/embeddings:0': 'bert/embeddings/token_type_embeddings',
@@ -49,7 +52,7 @@ def adapte_variables(self, num_layers=12):
'bert/embedding/layer_norm/beta:0': 'bert/embeddings/LayerNorm/beta',
})

# encoder
# encoder
model_prefix = 'bert/encoder/layer_{}'
for i in range(num_layers):
encoder_prefix = 'bert/encoder/layer_{}/'.format(i)
8 changes: 4 additions & 4 deletions transformers_keras/modeling_albert.py
@@ -4,7 +4,7 @@

import tensorflow as tf

from transformers_keras.adapters.abstract_adapter import zip_weights
from transformers_keras.adapters import parse_pretrained_model_files
from transformers_keras.adapters.albert_adapter import AlbertAdapter

from .layers import MultiHeadAttention
@@ -341,13 +341,13 @@ def dummy_inputs(self):

@classmethod
def from_pretrained(cls, pretrained_model_dir, adapter=None, verbose=True, **kwargs):
config_file, ckpt, vocab_file = parse_pretrained_model_files(pretrained_model_dir)
if not adapter:
adapter = AlbertAdapter()
model_config, name_mapping, ckpt, vocab_file = adapter.adapte(pretrained_model_dir)
model_config = adapter.adapte_config(config_file, **kwargs)
model = cls(**model_config)
model(model.dummy_inputs())
weights_values = zip_weights(model, ckpt, name_mapping, verbose=verbose)
tf.keras.backend.batch_set_value(weights_values)
adapter.adapte_weights(model, model_config, ckpt, **kwargs)
return model


8 changes: 4 additions & 4 deletions transformers_keras/modeling_bert.py
@@ -4,7 +4,7 @@

import tensorflow as tf

from transformers_keras.adapters.abstract_adapter import zip_weights
from transformers_keras.adapters import parse_pretrained_model_files
from transformers_keras.adapters.bert_adapter import BertAdapter

from .layers import MultiHeadAttention
@@ -238,13 +238,13 @@ def dummy_inputs(self):

@classmethod
def from_pretrained(cls, pretrained_model_dir, adapter=None, verbose=True, **kwargs):
config_file, ckpt, vocab_file = parse_pretrained_model_files(pretrained_model_dir)
if not adapter:
adapter = BertAdapter()
model_config, name_mapping, ckpt, vocab_file = adapter.adapte(pretrained_model_dir)
model_config = adapter.adapte_config(config_file, **kwargs)
model = cls(**model_config)
model(model.dummy_inputs())
weights_values = zip_weights(model, ckpt, name_mapping, verbose=verbose)
tf.keras.backend.batch_set_value(weights_values)
adapter.adapte_weights(model, model_config, ckpt, **kwargs)
return model


