add AITM component & upgrade NaryDisEmbedding & fix a few bugs (#487)
yangxudong authored Oct 8, 2024
1 parent 696163f commit f403b69
Showing 25 changed files with 552 additions and 328 deletions.
30 changes: 27 additions & 3 deletions docs/source/component/component.md
@@ -9,6 +9,7 @@
| Gate | Gating | Weighted sum of multiple inputs | [Cross Decoupling Network](../models/cdn.html#id2) |
| PeriodicEmbedding | Periodic activation function | Numerical feature embedding | [Example 5](backbone.md#dlrm-embedding) |
| AutoDisEmbedding | Automatic discretization | Numerical feature embedding | [dlrm_on_criteo_with_autodis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_autodis.config) |
| NaryDisEmbedding | N-ary encoding | Numerical feature embedding | [dlrm_on_criteo_with_narydis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_narydis.config) |
| TextCNN | Text convolution | Extracts features from a text sequence | [text_cnn_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/text_cnn_on_movielens.config) |

**Note**: the first input to the Gate component is the weight vector; the remaining inputs are gathered into a list, and the length of the weight vector must equal the length of that list. A toy sketch follows.
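
For intuition, here is a toy sketch of that weighted sum; the shapes and values are made up for illustration:

```python
import tensorflow as tf

weights = tf.constant([[0.2, 0.8]])  # first input: weight vector, shape [B, n]
x1 = tf.ones([1, 4])                 # the remaining n inputs form a list
x2 = 3.0 * tf.ones([1, 4])
# weighted sum over the list, as the Gate layer computes it
gated = weights[:, 0, None] * x1 + weights[:, 1, None] * x2
print(gated)  # [[2.6 2.6 2.6 2.6]]
```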
@@ -47,9 +48,10 @@

## 5. Multi-Task Learning Components

| Class | Function                    | Description                 | Example                       |
| ----- | --------------------------- | --------------------------- | ----------------------------- |
| MMoE  | Multiple Mixture of Experts | Component of the MMoE model | [Example 8](backbone.md#mmoe) |
| Class     | Function                    | Description                 | Example                       |
| --------- | --------------------------- | --------------------------- | ----------------------------- |
| MMoE      | Multiple Mixture of Experts | Component of the MMoE model | [Example 8](backbone.md#mmoe) |
| AITMTower | A tower of the AITM model   | Component of the AITM model | [AITM](../models/aitm.md#id2) |

## 6. Auxiliary Loss Components

@@ -108,6 +110,20 @@
| output_tensor_list | bool | false | whether to also output the list of embeddings |
| output_3d_tensor | bool | false | whether to also output a 3d tensor; ignored when `output_tensor_list=true` |

- NaryDisEmbedding

| Parameter          | Type   | Default | Description                                                                                         |
| ------------------ | ------ | ------- | --------------------------------------------------------------------------------------------------- |
| embedding_dim      | uint32 |         | embedding dimension                                                                                  |
| carries            | list   |         | list of bases (radixes) in which the numerical feature is encoded                                    |
| multiplier         | float  | 1.0     | float features are multiplied by `multiplier` and rounded to an integer before base encoding         |
| intra_ary_pooling  | string | sum     | how the embeddings of the digits within one base are pooled into the final embedding; options: sum, mean |
| num_replicas       | uint32 | 1       | number of embedding representations output for each feature                                          |
| output_tensor_list | bool   | false   | whether to also output the list of embeddings                                                        |
| output_3d_tensor   | bool   | false   | whether to also output a 3d tensor; ignored when `output_tensor_list=true`                           |

Note: this component depends on a custom TensorFlow op and may not be usable on some TF versions. A minimal sketch of the encoding idea follows.
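
Since the custom op may be unavailable, the NumPy sketch below illustrates the underlying idea only; the function names, the random embedding tables, and the final concatenation across bases are assumptions for illustration, not the actual implementation:

```python
import numpy as np

def nary_digits(value, base):
  """Digits of a non-negative integer in `base`, least significant first."""
  digits = [value % base]
  value //= base
  while value > 0:
    digits.append(value % base)
    value //= base
  return digits

def nary_dis_embedding(x, carries=(2, 9), multiplier=1.0,
                       embedding_dim=8, intra_ary_pooling='sum', seed=0):
  """Encode one float feature in several bases and pool the digit embeddings."""
  rng = np.random.RandomState(seed)
  v = int(round(x * multiplier))  # scale the float feature, then round
  per_base = []
  for base in carries:
    digits = nary_digits(v, base)
    table = rng.randn(base, embedding_dim).astype(np.float32)  # toy embedding table
    embs = table[digits]                                       # [num_digits, dim]
    pooled = embs.sum(0) if intra_ary_pooling == 'sum' else embs.mean(0)
    per_base.append(pooled)
  return np.concatenate(per_base)  # one pooled vector per base, concatenated

print(nary_dis_embedding(3.7, carries=(2, 9), multiplier=10.0).shape)  # (16,)
```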

- TextCNN

| Parameter | Type | Default | Description |
@@ -314,6 +330,14 @@ BERT model structure
| num_expert | uint32 | 0 | number of experts |
| expert_mlp | MLP | optional | MLP parameters of each expert |

- AITMTower

| Parameter     | Type   | Default  | Description                                                      |
| ------------- | ------ | -------- | ---------------------------------------------------------------- |
| project_dim   | uint32 | optional | dimension of the attention Query, Key, and Value                  |
| stop_gradient | bool   | True     | whether to stop gradients from flowing into the dependent inputs  |
| transfer_mlp  | MLP    |          | MLP parameters of the transfer module                             |
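
For intuition, the sketch below shows one way these parameters can drive an attention-based information transfer, loosely following the AITM paper; the class and its details are illustrative assumptions, not the exact EasyRec layer:

```python
import tensorflow as tf

class AITMTowerSketch(tf.keras.layers.Layer):
  """Illustrative only; see the EasyRec source for the real AITMTower."""

  def __init__(self, project_dim=64, stop_gradient=True, **kwargs):
    super(AITMTowerSketch, self).__init__(**kwargs)
    self.stop_grad = stop_gradient
    self.dim = project_dim
    self.query = tf.keras.layers.Dense(project_dim)
    self.key = tf.keras.layers.Dense(project_dim)
    self.value = tf.keras.layers.Dense(project_dim)

  def call(self, inputs):
    current, predecessor = inputs  # [B, D] each; predecessor comes from the
    if self.stop_grad:             # preceding task tower (via transfer_mlp)
      predecessor = tf.stop_gradient(predecessor)
    seq = tf.stack([current, predecessor], axis=1)              # [B, 2, D]
    q, k, v = self.query(seq), self.key(seq), self.value(seq)
    scores = tf.reduce_sum(q * k, axis=-1) / (self.dim ** 0.5)  # [B, 2]
    weights = tf.nn.softmax(scores, axis=1)
    return tf.reduce_sum(weights[..., None] * v, axis=1)        # [B, project_dim]
```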

## 6. Components for Computing Auxiliary Losses

- AuxiliaryLoss
59 changes: 49 additions & 10 deletions docs/source/models/aitm.md
@@ -31,7 +31,7 @@ model_config {
}
backbone {
blocks {
name: "mlp"
name: "share_bottom"
inputs {
feature_group_name: "all"
}
@@ -42,6 +42,49 @@
}
}
}
blocks {
name: "ctr_tower"
inputs {
block_name: "share_bottom"
}
keras_layer {
class_name: 'MLP'
mlp {
hidden_units: 128
}
}
}
blocks {
name: "cvr_tower"
inputs {
block_name: "share_bottom"
}
keras_layer {
class_name: 'MLP'
mlp {
hidden_units: 128
}
}
}
blocks {
name: "cvr_aitm"
inputs {
block_name: "cvr_tower"
}
inputs {
block_name: "ctr_tower"
}
merge_inputs_into_list: true
keras_layer {
class_name: "AITMTower"
aitm {
transfer_mlp {
hidden_units: 128
}
}
}
}
output_blocks: ["ctr_tower", "cvr_aitm"]
}
model_params {
task_towers {
@@ -52,9 +95,8 @@
auc {}
}
dnn {
hidden_units: [256, 128]
hidden_units: 64
}
use_ait_module: true
weight: 1.0
}
task_towers {
Expand All @@ -70,11 +112,8 @@ model_config {
auc {}
}
dnn {
hidden_units: [256, 128]
hidden_units: 64
}
relation_tower_names: ["ctr"]
use_ait_module: true
ait_project_dim: 128
weight: 1.0
}
l2_regularization: 1e-6
@@ -95,15 +134,15 @@ model_config {
- name/inputs: each `block` has a unique name and one or more inputs and outputs
- keras_layer: loads the custom or built-in keras layer specified by `class_name` and runs its logic; [reference](../component/backbone.md#keraslayer)
- mlp: parameters of the MLP model; see the [reference](../component/component.md#id1)
- cvr_aitm: the AITMTower component; the order of its inputs matters: the first input must be the current tower's own input, and the subsequent inputs come from the predecessor blocks it depends on
- output_blocks: the list of the backbone's output tensors; the order must match the task towers configured in `model_params` below

- model_params: AITM-related parameters
- model_params: parameters for multi-task modeling

  - task_towers: configure one task tower per task
    - tower_name
    - dnn: parameter configuration of the deep part
      - hidden_units: number of channels (i.e. neurons) in each dnn layer
    - use_ait_module: if true, use the `AITM` model; otherwise, use the [DBMTL](dbmtl.md) model
    - ait_project_dim: dimension of each tower's representation vector; the size of the last hidden layer is usually a good choice
    - binary classification by default: num_class defaults to 1, weight defaults to 1.0, loss_type defaults to CLASSIFICATION, and metrics_set is auc
    - loss_type: ORDER_CALIBRATE_LOSS adds an auxiliary loss that uses the dependency between tasks to calibrate predictions (see the original paper for details; a hedged sketch follows this list)
    - Note: label_fields must align one-to-one with task_towers.
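
For intuition, a minimal sketch of what such an order-calibration term can look like, assuming the later-stage probability should never exceed the earlier-stage one (assumed form; consult the paper for the exact definition):

```python
import tensorflow as tf

def order_calibrate_loss(p_ctr, p_cvr):
  # penalize predictions where the later-stage probability (cvr)
  # exceeds the earlier-stage probability (ctr)
  return tf.reduce_mean(tf.maximum(p_cvr - p_ctr, 0.0))
```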
15 changes: 9 additions & 6 deletions easy_rec/python/layers/backbone.py
@@ -266,8 +266,8 @@ def block_input(self, config, block_outputs, training=None, **kwargs):
try:
output = merge_inputs(inputs, config.input_concat_axis, config.name)
except ValueError as e:
logging.error('merge inputs of block %s failed: %s', config.name,
e.message)
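        # Exception.message no longer exists in Python 3; fall back to str(e)
        # so the error is logged correctly on both Python 2 and Python 3.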
msg = getattr(e, 'message', str(e))
logging.error('merge inputs of block %s failed: %s', config.name, msg)
raise e

if config.HasField('extra_input_fn'):
@@ -317,7 +317,9 @@ def call(self, is_training, **kwargs):
inputs, _, weights = self._feature_group_inputs[feature_group]
block_outputs[block] = input_fn([inputs, weights], is_training)
else:
inputs = self.block_input(config, block_outputs, is_training, **kwargs)
with tf.name_scope(block + '_input'):
inputs = self.block_input(config, block_outputs, is_training,
**kwargs)
output = self.call_layer(inputs, config, block, is_training, **kwargs)
block_outputs[block] = output

@@ -340,7 +342,8 @@ def call(self, is_training, **kwargs):
try:
output = merge_inputs(outputs, msg='backbone')
except ValueError as e:
logging.error("merge backbone's output failed: %s", e.message)
msg = getattr(e, 'message', str(e))
logging.error("merge backbone's output failed: %s", msg)
raise e
return output

@@ -402,8 +405,8 @@ def call_keras_layer(self, inputs, name, training, **kwargs):
try:
output = layer(inputs, training=training, **kwargs)
except Exception as e:
logging.error('call keras layer %s (%s) failed: %s' %
(name, cls, e.message))
msg = getattr(e, 'message', str(e))
logging.error('call keras layer %s (%s) failed: %s' % (name, cls, msg))
raise e
else:
try:
1 change: 1 addition & 0 deletions easy_rec/python/layers/keras/__init__.py
@@ -22,6 +22,7 @@
from .mask_net import MaskBlock
from .mask_net import MaskNet
from .multi_head_attention import MultiHeadAttention
from .multi_task import AITMTower
from .multi_task import MMoE
from .numerical_embedding import AutoDisEmbedding
from .numerical_embedding import NaryDisEmbedding
8 changes: 4 additions & 4 deletions easy_rec/python/layers/keras/activation.py
@@ -99,14 +99,14 @@ def call(self, inputs, mask=None):
return tf.nn.softmax(inputs, axis=self.axis)


def activation_layer(activation):
def activation_layer(activation, name=None):
if activation in ('dice', 'Dice'):
act_layer = Dice()
act_layer = Dice(name=name)
elif isinstance(activation, (str, unicode)):
act_fn = easy_rec.python.utils.activation.get_activation(activation)
act_layer = Activation(act_fn)
act_layer = Activation(act_fn, name=name)
elif issubclass(activation, Layer):
act_layer = activation()
act_layer = activation(name=name)
else:
raise ValueError(
        'Invalid activation, found %s. You should use a str or an Activation layer class.'
22 changes: 15 additions & 7 deletions easy_rec/python/layers/keras/blocks.py
@@ -32,7 +32,7 @@ class MLP(Layer):

def __init__(self, params, name='mlp', reuse=None, **kwargs):
super(MLP, self).__init__(name=name, **kwargs)
self.layer_name = name
    self.layer_name = name  # for adding to output
params.check_required('hidden_units')
use_bn = params.get_or_default('use_bn', True)
use_final_bn = params.get_or_default('use_final_bn', True)
@@ -79,14 +79,14 @@ def add_rich_layer(self,
use_bn_after_activation,
name,
l2_reg=None):
act_layer = activation_layer(activation)
act_layer = activation_layer(activation, name='%s/act' % name)
if use_bn and not use_bn_after_activation:
dense = Dense(
units=num_units,
use_bias=use_bias,
kernel_initializer=initializer,
kernel_regularizer=l2_reg,
name=name)
name='%s/dense' % name)
self._sub_layers.append(dense)
bn = tf.keras.layers.BatchNormalization(
name='%s/bn' % name, trainable=True)
@@ -98,7 +98,7 @@ def add_rich_layer(self,
use_bias=use_bias,
kernel_initializer=initializer,
kernel_regularizer=l2_reg,
name=name)
name='%s/dense' % name)
self._sub_layers.append(dense)
self._sub_layers.append(act_layer)
if use_bn and use_bn_after_activation:
@@ -117,7 +117,7 @@ def call(self, x, training=None, **kwargs):
cls = layer.__class__.__name__
if cls in ('Dropout', 'BatchNormalization', 'Dice'):
x = layer(x, training=training)
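        # collect the moving-average update ops of BN/Dice only in training
        # mode; there is nothing to update at inference time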
if cls in ('BatchNormalization', 'Dice'):
if cls in ('BatchNormalization', 'Dice') and training:
add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
else:
x = layer(x)
@@ -183,8 +183,14 @@ class Gate(Layer):
def __init__(self, params, name='gate', reuse=None, **kwargs):
super(Gate, self).__init__(name=name, **kwargs)
self.weight_index = params.get_or_default('weight_index', 0)
if params.has_field('mlp'):
mlp_cfg = Parameter.make_from_pb(params.mlp)
mlp_cfg.l2_regularizer = params.l2_regularizer
self.top_mlp = MLP(mlp_cfg, name='top_mlp')
else:
self.top_mlp = None

def call(self, inputs, **kwargs):
def call(self, inputs, training=None, **kwargs):
assert len(
inputs
) > 1, 'input of Gate layer must be a list containing at least 2 elements'
Expand All @@ -198,6 +204,8 @@ def call(self, inputs, **kwargs):
else:
output += weights[:, j, None] * x
j += 1
if self.top_mlp is not None:
output = self.top_mlp(output, training=training)
return output


@@ -248,7 +256,7 @@ def call(self, inputs, training=None, **kwargs):
pooled_outputs.append(pooled)
net = self.concat_layer(pooled_outputs)
if self.mlp is not None:
output = self.mlp(net)
output = self.mlp(net, training=training)
else:
output = net
return output
21 changes: 12 additions & 9 deletions easy_rec/python/layers/keras/custom_ops.py
@@ -58,10 +58,11 @@ def call(self, inputs, training=None, **kwargs):
mask_emb = tf.get_variable(
'mask', (embedding_dim,), dtype=tf.float32, trainable=True)
seq_len = tf.to_int32(seq_len)
aug_seq, aug_len = self.seq_augment(seq_input, seq_len, mask_emb,
self.seq_aug_params.crop_rate,
self.seq_aug_params.reorder_rate,
self.seq_aug_params.mask_rate)
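      # pin the custom augmentation op to CPU (it likely has no GPU kernel)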
with ops.device('/CPU:0'):
aug_seq, aug_len = self.seq_augment(seq_input, seq_len, mask_emb,
self.seq_aug_params.crop_rate,
self.seq_aug_params.reorder_rate,
self.seq_aug_params.mask_rate)
return aug_seq, aug_len


@@ -75,11 +76,13 @@ def __init__(self, params, name='text_normalize', reuse=None, **kwargs):

def call(self, inputs, training=None, **kwargs):
inputs = inputs if type(inputs) in (tuple, list) else [inputs]
result = [
self.txt_normalizer(
txt, parameter=self.norm_parameter, remove_space=self.remove_space)
for txt in inputs
]
with ops.device('/CPU:0'):
result = [
self.txt_normalizer(
txt,
parameter=self.norm_parameter,
remove_space=self.remove_space) for txt in inputs
]
if len(result) == 1:
return result[0]
return result
15 changes: 7 additions & 8 deletions easy_rec/python/layers/keras/din.py
@@ -17,6 +17,12 @@ def __init__(self, params, name='din', reuse=None, **kwargs):
self.reuse = reuse
self.l2_reg = params.l2_regularizer
self.config = params.get_pb_config()
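    # build the attention MLP once in __init__ (instead of inside call) so its
    # variables are created a single time and reused across calls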
self.config.attention_dnn.use_final_bn = False
self.config.attention_dnn.use_final_bias = True
self.config.attention_dnn.final_activation = 'linear'
mlp_params = Parameter.make_from_pb(self.config.attention_dnn)
mlp_params.l2_regularizer = self.l2_reg
self.din_layer = MLP(mlp_params, 'din_attention', reuse=self.reuse)

def call(self, inputs, training=None, **kwargs):
keys, seq_len, query = inputs
@@ -36,14 +42,7 @@ def call(self, inputs, training=None, **kwargs):
queries = tf.tile(tf.expand_dims(query, 1), [1, max_seq_len, 1])
din_all = tf.concat([queries, keys, queries - keys, queries * keys],
axis=-1)

self.config.attention_dnn.use_final_bn = False
self.config.attention_dnn.use_final_bias = True
self.config.attention_dnn.final_activation = 'linear'
params = Parameter.make_from_pb(self.config.attention_dnn)
params.l2_regularizer = self.l2_reg
din_layer = MLP(params, name=self.name + '/din_attention', reuse=self.reuse)
output = din_layer(din_all, training) # [B, L, 1]
output = self.din_layer(din_all, training) # [B, L, 1]
scores = tf.transpose(output, [0, 2, 1]) # [B, 1, L]

seq_mask = tf.sequence_mask(seq_len, max_seq_len, dtype=tf.bool)
4 changes: 2 additions & 2 deletions easy_rec/python/layers/keras/mask_net.py
@@ -149,7 +149,7 @@ def call(self, inputs, training=None, **kwargs):
]
all_mask_outputs = tf.concat(mask_outputs, axis=1)
if self.mlp is not None:
output = self.mlp(all_mask_outputs)
output = self.mlp(all_mask_outputs, training=training)
else:
output = all_mask_outputs
return output
@@ -160,7 +160,7 @@ def call(self, inputs, training=None, **kwargs):
net = mask_layer((net, inputs))

if self.mlp is not None:
output = self.mlp(net)
output = self.mlp(net, training=training)
else:
output = net
return output
