diff --git a/docs/source/component/component.md b/docs/source/component/component.md
index fa50ecd32..49a18662a 100644
--- a/docs/source/component/component.md
+++ b/docs/source/component/component.md
@@ -9,6 +9,7 @@
 | Gate | Gating | Weighted sum of multiple inputs | [Cross Decoupling Network](../models/cdn.html#id2) |
 | PeriodicEmbedding | Periodic activation function | Embedding for numeric features | [Case 5](backbone.md#dlrm-embedding) |
 | AutoDisEmbedding | Automatic discretization | Embedding for numeric features | [dlrm_on_criteo_with_autodis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_autodis.config) |
+| NaryDisEmbedding | N-ary encoding | Embedding for numeric features | [dlrm_on_criteo_with_narydis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_narydis.config) |
 | TextCNN | Text convolution | Extracts features from text sequences | [text_cnn_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/text_cnn_on_movielens.config) |

 **Note**: the first input of the Gate component is the weight vector; the remaining inputs are packed into a list, and the length of the weight vector must equal the length of that list.

@@ -47,9 +48,10 @@

 ## 5. Multi-task learning components

-| Class name | Function                      | Description                 | Example                    |
-| ---------- | ----------------------------- | --------------------------- | -------------------------- |
-| MMoE       | Multi-gate Mixture of Experts | Component of the MMoE model | [Case 8](backbone.md#mmoe) |
+| Class name | Function                      | Description                 | Example                       |
+| ---------- | ----------------------------- | --------------------------- | ----------------------------- |
+| MMoE       | Multi-gate Mixture of Experts | Component of the MMoE model | [Case 8](backbone.md#mmoe)    |
+| AITMTower  | One tower of the AITM model   | Component of the AITM model | [AITM](../models/aitm.md#id2) |

 ## 6. Auxiliary loss components

@@ -108,6 +110,20 @@
 | output_tensor_list | bool | false | whether to also output the list of embeddings |
 | output_3d_tensor | bool | false | whether to also output a 3d tensor; has no effect when `output_tensor_list=true` |

+- NaryDisEmbedding
+
+| Parameter | Type | Default | Description |
+| ------------------ | ------ | ------- | --------------------------------------------------- |
+| embedding_dim | uint32 | | embedding dimension |
+| carries | list | | the list of radixes (bases) in which the numeric feature is encoded |
+| multiplier | float | 1.0 | float features are multiplied by `multiplier` and rounded before radix encoding |
+| intra_ary_pooling | string | sum | how the embeddings of the digits within one radix are pooled into the final embedding; options: sum, mean |
+| num_replicas | uint32 | 1 | number of embedding representations output per feature |
+| output_tensor_list | bool | false | whether to also output the list of embeddings |
+| output_3d_tensor | bool | false | whether to also output a 3d tensor; has no effect when `output_tensor_list=true` |
+
+Note: this component depends on a custom TensorFlow OP and may not be usable on some versions of TF.
+
 - TextCNN

 | Parameter | Type | Default | Description |
@@ -314,6 +330,14 @@
 | num_expert | uint32 | 0 | number of experts |
 | expert_mlp | MLP | optional | MLP parameters of each expert |

+- AITMTower
+
+| Parameter | Type | Default | Description |
+| ------------- | ------ | ------- | ------------------------------------------------- |
+| project_dim | uint32 | optional | dimension of the attention Query, Key and Value |
+| stop_gradient | bool | True | whether to stop gradients flowing to the dependent inputs |
+| transfer_mlp | MLP | | MLP parameters of the transfer module |
+
 ## 6. Components for computing auxiliary losses

 - AuxiliaryLoss
diff --git a/docs/source/models/aitm.md b/docs/source/models/aitm.md
index a15ea0489..6f4c57d7b 100644
--- a/docs/source/models/aitm.md
+++ b/docs/source/models/aitm.md
@@ -31,7 +31,7 @@ model_config {
   }
   backbone {
     blocks {
-      name: "mlp"
+      name: "share_bottom"
       inputs {
         feature_group_name: "all"
       }
       keras_layer {
         class_name: 'MLP'
         mlp {
           hidden_units: [512, 256]
         }
       }
     }
+    blocks {
+      name: "ctr_tower"
+      inputs {
+        block_name: "share_bottom"
+      }
+      keras_layer {
+        class_name: 'MLP'
+        mlp {
+          hidden_units: 128
+        }
+      }
+    }
+    blocks {
+      name: "cvr_tower"
+      inputs {
+        block_name: "share_bottom"
+      }
+      keras_layer {
+        class_name: 'MLP'
+        mlp {
+          hidden_units: 128
+        }
+      }
+    }
+    blocks {
+      name: "cvr_aitm"
+      inputs {
+        block_name: "cvr_tower"
+      }
+      inputs {
+        block_name: "ctr_tower"
+      }
+      merge_inputs_into_list: true
+      keras_layer {
+        class_name: "AITMTower"
+        aitm {
+          transfer_mlp {
+            hidden_units: 128
+          }
+        }
+      }
+    }
+    output_blocks: ["ctr_tower", "cvr_aitm"]
   }
   model_params {
     task_towers {
@@ -52,9 +95,8 @@ model_config {
       auc {}
     }
     dnn {
-      hidden_units: [256, 128]
+      hidden_units: 64
     }
-    use_ait_module: true
     weight: 1.0
   }
   task_towers {
@@ -70,11 +112,8 @@ model_config {
      auc {}
    }
    dnn {
-      hidden_units: [256, 128]
+      hidden_units: 64
    }
-    relation_tower_names: ["ctr"]
-    use_ait_module: true
-    ait_project_dim: 128
    weight: 1.0
  }
  l2_regularization: 1e-6
@@ -95,15 +134,15 @@ model_config {
 - name/inputs: each `block` has a unique name and one or more inputs and outputs
 - keras_layer: loads the custom or built-in keras layer specified by `class_name` and executes its logic; [reference](../component/backbone.md#keraslayer)
   - mlp: parameters of the MLP component; see the [reference](../component/component.md#id1)
+  - cvr_aitm: an AITMTower component; the order of this block's inputs is significant: the first input must be the current tower's own input, and the subsequent inputs are the predecessor towers it depends on
+- output_blocks: the list of tensors output by the backbone; their order must match the task towers configured in `model_params` below
-- model_params: parameters related to AITM
+- model_params: parameters related to multi-task modeling
   - task_towers: configure one task tower per task
     - tower_name
     - dnn: parameter configuration of the deep part
       - hidden_units: the number of channels (i.e. neurons) in each dnn layer
-    - use_ait_module: if true, use the `AITM` model; otherwise, use the [DBMTL](dbmtl.md) model
-    - ait_project_dim: the dimension of each tower's representation vector; usually set to the size of the last hidden layer
    - binary classification task by default, i.e. num_class defaults to 1, weight defaults to 1.0, loss_type defaults to CLASSIFICATION, and metrics_set is auc
    - loss_type: ORDER_CALIBRATE_LOSS is an auxiliary loss that calibrates predictions using the dependency between targets; see the original paper for details
  - Note: label_fields must align one-to-one with task_towers.
diff --git a/easy_rec/python/layers/backbone.py b/easy_rec/python/layers/backbone.py
index f3dab6391..e77ea1da5 100644
--- a/easy_rec/python/layers/backbone.py
+++ b/easy_rec/python/layers/backbone.py
@@ -266,8 +266,8 @@ def block_input(self, config, block_outputs, training=None, **kwargs):
     try:
       output = merge_inputs(inputs, config.input_concat_axis, config.name)
     except ValueError as e:
-      logging.error('merge inputs of block %s failed: %s', config.name,
-                    e.message)
+      msg = getattr(e, 'message', str(e))
+      logging.error('merge inputs of block %s failed: %s', config.name, msg)
       raise e

     if config.HasField('extra_input_fn'):
@@ -317,7 +317,9 @@ def call(self, is_training, **kwargs):
         inputs, _, weights = self._feature_group_inputs[feature_group]
         block_outputs[block] = input_fn([inputs, weights], is_training)
       else:
-        inputs = self.block_input(config, block_outputs, is_training, **kwargs)
+        with tf.name_scope(block + '_input'):
+          inputs = self.block_input(config, block_outputs, is_training,
+                                    **kwargs)
         output = self.call_layer(inputs, config, block, is_training, **kwargs)
         block_outputs[block] = output
@@ -340,7 +342,8 @@ def call(self, is_training, **kwargs):
     try:
       output = merge_inputs(outputs, msg='backbone')
     except ValueError as e:
-      logging.error("merge backbone's output failed: %s", e.message)
+      msg =
getattr(e, 'message', str(e)) + logging.error("merge backbone's output failed: %s", msg) raise e return output @@ -402,8 +405,8 @@ def call_keras_layer(self, inputs, name, training, **kwargs): try: output = layer(inputs, training=training, **kwargs) except Exception as e: - logging.error('call keras layer %s (%s) failed: %s' % - (name, cls, e.message)) + msg = getattr(e, 'message', str(e)) + logging.error('call keras layer %s (%s) failed: %s' % (name, cls, msg)) raise e else: try: diff --git a/easy_rec/python/layers/keras/__init__.py b/easy_rec/python/layers/keras/__init__.py index ad0bf8528..c4427e5d3 100644 --- a/easy_rec/python/layers/keras/__init__.py +++ b/easy_rec/python/layers/keras/__init__.py @@ -22,6 +22,7 @@ from .mask_net import MaskBlock from .mask_net import MaskNet from .multi_head_attention import MultiHeadAttention +from .multi_task import AITMTower from .multi_task import MMoE from .numerical_embedding import AutoDisEmbedding from .numerical_embedding import NaryDisEmbedding diff --git a/easy_rec/python/layers/keras/activation.py b/easy_rec/python/layers/keras/activation.py index 532d5ca8f..fa6218e64 100644 --- a/easy_rec/python/layers/keras/activation.py +++ b/easy_rec/python/layers/keras/activation.py @@ -99,14 +99,14 @@ def call(self, inputs, mask=None): return tf.nn.softmax(inputs, axis=self.axis) -def activation_layer(activation): +def activation_layer(activation, name=None): if activation in ('dice', 'Dice'): - act_layer = Dice() + act_layer = Dice(name=name) elif isinstance(activation, (str, unicode)): act_fn = easy_rec.python.utils.activation.get_activation(activation) - act_layer = Activation(act_fn) + act_layer = Activation(act_fn, name=name) elif issubclass(activation, Layer): - act_layer = activation() + act_layer = activation(name=name) else: raise ValueError( 'Invalid activation,found %s.You should use a str or a Activation Layer Class.' 
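Reviewer note: for context on the `activation_layer` change above, here is a minimal usage sketch (not part of the diff; it assumes `easy_rec` is importable and TF eager execution, and the layer names are illustrative). Threading `name` through the factory is what lets MLP pin stable scopes like `layer_0/act` below.

```python
import tensorflow as tf

from easy_rec.python.layers.keras.activation import activation_layer

# The factory resolves a string (or a Layer subclass) to a keras layer;
# after this change the caller can also pin the returned layer's name.
relu_act = activation_layer('relu', name='layer_0/act')
dice_act = activation_layer('dice', name='layer_1/act')  # 'dice' maps to Dice

x = tf.zeros([2, 4])
y = relu_act(x)  # shape [2, 4], element-wise ReLU
```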
diff --git a/easy_rec/python/layers/keras/blocks.py b/easy_rec/python/layers/keras/blocks.py index 5c411922c..c9e722a67 100644 --- a/easy_rec/python/layers/keras/blocks.py +++ b/easy_rec/python/layers/keras/blocks.py @@ -32,7 +32,7 @@ class MLP(Layer): def __init__(self, params, name='mlp', reuse=None, **kwargs): super(MLP, self).__init__(name=name, **kwargs) - self.layer_name = name + self.layer_name = name # for add to output params.check_required('hidden_units') use_bn = params.get_or_default('use_bn', True) use_final_bn = params.get_or_default('use_final_bn', True) @@ -79,14 +79,14 @@ def add_rich_layer(self, use_bn_after_activation, name, l2_reg=None): - act_layer = activation_layer(activation) + act_layer = activation_layer(activation, name='%s/act' % name) if use_bn and not use_bn_after_activation: dense = Dense( units=num_units, use_bias=use_bias, kernel_initializer=initializer, kernel_regularizer=l2_reg, - name=name) + name='%s/dense' % name) self._sub_layers.append(dense) bn = tf.keras.layers.BatchNormalization( name='%s/bn' % name, trainable=True) @@ -98,7 +98,7 @@ def add_rich_layer(self, use_bias=use_bias, kernel_initializer=initializer, kernel_regularizer=l2_reg, - name=name) + name='%s/dense' % name) self._sub_layers.append(dense) self._sub_layers.append(act_layer) if use_bn and use_bn_after_activation: @@ -117,7 +117,7 @@ def call(self, x, training=None, **kwargs): cls = layer.__class__.__name__ if cls in ('Dropout', 'BatchNormalization', 'Dice'): x = layer(x, training=training) - if cls in ('BatchNormalization', 'Dice'): + if cls in ('BatchNormalization', 'Dice') and training: add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS) else: x = layer(x) @@ -183,8 +183,14 @@ class Gate(Layer): def __init__(self, params, name='gate', reuse=None, **kwargs): super(Gate, self).__init__(name=name, **kwargs) self.weight_index = params.get_or_default('weight_index', 0) + if params.has_field('mlp'): + mlp_cfg = Parameter.make_from_pb(params.mlp) + mlp_cfg.l2_regularizer = params.l2_regularizer + self.top_mlp = MLP(mlp_cfg, name='top_mlp') + else: + self.top_mlp = None - def call(self, inputs, **kwargs): + def call(self, inputs, training=None, **kwargs): assert len( inputs ) > 1, 'input of Gate layer must be a list containing at least 2 elements' @@ -198,6 +204,8 @@ def call(self, inputs, **kwargs): else: output += weights[:, j, None] * x j += 1 + if self.top_mlp is not None: + output = self.top_mlp(output, training=training) return output @@ -248,7 +256,7 @@ def call(self, inputs, training=None, **kwargs): pooled_outputs.append(pooled) net = self.concat_layer(pooled_outputs) if self.mlp is not None: - output = self.mlp(net) + output = self.mlp(net, training=training) else: output = net return output diff --git a/easy_rec/python/layers/keras/custom_ops.py b/easy_rec/python/layers/keras/custom_ops.py index f0b04b2ab..c215ee332 100644 --- a/easy_rec/python/layers/keras/custom_ops.py +++ b/easy_rec/python/layers/keras/custom_ops.py @@ -58,10 +58,11 @@ def call(self, inputs, training=None, **kwargs): mask_emb = tf.get_variable( 'mask', (embedding_dim,), dtype=tf.float32, trainable=True) seq_len = tf.to_int32(seq_len) - aug_seq, aug_len = self.seq_augment(seq_input, seq_len, mask_emb, - self.seq_aug_params.crop_rate, - self.seq_aug_params.reorder_rate, - self.seq_aug_params.mask_rate) + with ops.device('/CPU:0'): + aug_seq, aug_len = self.seq_augment(seq_input, seq_len, mask_emb, + self.seq_aug_params.crop_rate, + self.seq_aug_params.reorder_rate, + self.seq_aug_params.mask_rate) 
return aug_seq, aug_len @@ -75,11 +76,13 @@ def __init__(self, params, name='text_normalize', reuse=None, **kwargs): def call(self, inputs, training=None, **kwargs): inputs = inputs if type(inputs) in (tuple, list) else [inputs] - result = [ - self.txt_normalizer( - txt, parameter=self.norm_parameter, remove_space=self.remove_space) - for txt in inputs - ] + with ops.device('/CPU:0'): + result = [ + self.txt_normalizer( + txt, + parameter=self.norm_parameter, + remove_space=self.remove_space) for txt in inputs + ] if len(result) == 1: return result[0] return result diff --git a/easy_rec/python/layers/keras/din.py b/easy_rec/python/layers/keras/din.py index c30ecae90..082677e0b 100644 --- a/easy_rec/python/layers/keras/din.py +++ b/easy_rec/python/layers/keras/din.py @@ -17,6 +17,12 @@ def __init__(self, params, name='din', reuse=None, **kwargs): self.reuse = reuse self.l2_reg = params.l2_regularizer self.config = params.get_pb_config() + self.config.attention_dnn.use_final_bn = False + self.config.attention_dnn.use_final_bias = True + self.config.attention_dnn.final_activation = 'linear' + mlp_params = Parameter.make_from_pb(self.config.attention_dnn) + mlp_params.l2_regularizer = self.l2_reg + self.din_layer = MLP(mlp_params, 'din_attention', reuse=self.reuse) def call(self, inputs, training=None, **kwargs): keys, seq_len, query = inputs @@ -36,14 +42,7 @@ def call(self, inputs, training=None, **kwargs): queries = tf.tile(tf.expand_dims(query, 1), [1, max_seq_len, 1]) din_all = tf.concat([queries, keys, queries - keys, queries * keys], axis=-1) - - self.config.attention_dnn.use_final_bn = False - self.config.attention_dnn.use_final_bias = True - self.config.attention_dnn.final_activation = 'linear' - params = Parameter.make_from_pb(self.config.attention_dnn) - params.l2_regularizer = self.l2_reg - din_layer = MLP(params, name=self.name + '/din_attention', reuse=self.reuse) - output = din_layer(din_all, training) # [B, L, 1] + output = self.din_layer(din_all, training) # [B, L, 1] scores = tf.transpose(output, [0, 2, 1]) # [B, 1, L] seq_mask = tf.sequence_mask(seq_len, max_seq_len, dtype=tf.bool) diff --git a/easy_rec/python/layers/keras/mask_net.py b/easy_rec/python/layers/keras/mask_net.py index 67176851c..bf687154e 100644 --- a/easy_rec/python/layers/keras/mask_net.py +++ b/easy_rec/python/layers/keras/mask_net.py @@ -149,7 +149,7 @@ def call(self, inputs, training=None, **kwargs): ] all_mask_outputs = tf.concat(mask_outputs, axis=1) if self.mlp is not None: - output = self.mlp(all_mask_outputs) + output = self.mlp(all_mask_outputs, training=training) else: output = all_mask_outputs return output @@ -160,7 +160,7 @@ def call(self, inputs, training=None, **kwargs): net = mask_layer((net, inputs)) if self.mlp is not None: - output = self.mlp(net) + output = self.mlp(net, training=training) else: output = net return output diff --git a/easy_rec/python/layers/keras/multi_task.py b/easy_rec/python/layers/keras/multi_task.py index d092fd4d8..dbb26ee86 100644 --- a/easy_rec/python/layers/keras/multi_task.py +++ b/easy_rec/python/layers/keras/multi_task.py @@ -3,24 +3,19 @@ import logging import tensorflow as tf +from tensorflow.python.keras.layers import Dense +from tensorflow.python.keras.layers import Layer +from easy_rec.python.layers.keras.attention import Attention from easy_rec.python.layers.keras.blocks import MLP +from easy_rec.python.layers.utils import Parameter +from easy_rec.python.protos import seq_encoder_pb2 if tf.__version__ >= '2.0': tf = tf.compat.v1 -def gate_fn(inputs, 
units, name, l2_reg, reuse): - weights = tf.layers.dense( - inputs, - units, - kernel_regularizer=l2_reg, - name='%s/dense' % name, - reuse=reuse) - return tf.nn.softmax(weights, axis=1) - - -class MMoE(tf.keras.layers.Layer): +class MMoE(Layer): """Multi-gate Mixture-of-Experts model.""" def __init__(self, params, name='MMoE', reuse=None, **kwargs): @@ -30,7 +25,8 @@ def __init__(self, params, name='MMoE', reuse=None, **kwargs): self._num_expert = params.num_expert self._num_task = params.num_task if params.has_field('expert_mlp'): - expert_params = params.expert_mlp + expert_params = Parameter.make_from_pb(params.expert_mlp) + expert_params.l2_regularizer = params.l2_regularizer self._has_experts = True self._experts = [ MLP(expert_params, 'expert_%d' % i, reuse=reuse) @@ -38,29 +34,92 @@ def __init__(self, params, name='MMoE', reuse=None, **kwargs): ] else: self._has_experts = False - self._experts = [lambda x: x[i] for i in range(self._num_expert)] - self._l2_reg = params.l2_regularizer - def __call__(self, inputs, **kwargs): + self._gates = [] + for task_id in range(self._num_task): + dense = Dense( + self._num_expert, + activation='softmax', + name='gate_%d' % task_id, + kernel_regularizer=params.l2_regularizer) + self._gates.append(dense) + + def call(self, inputs, training=None, **kwargs): if self._num_expert == 0: logging.warning('num_expert of MMoE layer `%s` is 0' % self.name) return inputs + if self._has_experts: + expert_fea_list = [ + expert(inputs, training=training) for expert in self._experts + ] + else: + expert_fea_list = inputs + experts_fea = tf.stack(expert_fea_list, axis=1) + # 不使用内置MLP作为expert时,gate的input使用最后一个额外的输入 + gate_input = inputs if self._has_experts else inputs[self._num_expert] + task_input_list = [] + for task_id in range(self._num_task): + gate = self._gates[task_id](gate_input) + gate = tf.expand_dims(gate, -1) + task_input = tf.multiply(experts_fea, gate) + task_input = tf.reduce_sum(task_input, axis=1) + task_input_list.append(task_input) + return task_input_list - with tf.name_scope(self.name): - expert_fea_list = [expert(inputs) for expert in self._experts] - experts_fea = tf.stack(expert_fea_list, axis=1) - gate_input = inputs if self._has_experts else inputs[self._num_expert] - task_input_list = [] - for task_id in range(self._num_task): - gate = gate_fn( - gate_input, - self._num_expert, - name='gate_%d' % task_id, - l2_reg=self._l2_reg, - reuse=self._reuse) - gate = tf.expand_dims(gate, -1) - task_input = tf.multiply(experts_fea, gate) - task_input = tf.reduce_sum(task_input, axis=1) - task_input_list.append(task_input) - return task_input_list +class AITMTower(Layer): + """Adaptive Information Transfer Multi-task (AITM) Tower.""" + + def __init__(self, params, name='AITMTower', reuse=None, **kwargs): + super(AITMTower, self).__init__(name=name, **kwargs) + self.project_dim = params.get_or_default('project_dim', None) + self.stop_gradient = params.get_or_default('stop_gradient', True) + self.transfer = None + if params.has_field('transfer_mlp'): + mlp_cfg = Parameter.make_from_pb(params.transfer_mlp) + mlp_cfg.l2_regularizer = params.l2_regularizer + self.transfer = MLP(mlp_cfg, name='transfer') + self.queries = [] + self.keys = [] + self.values = [] + self.attention = None + + def build(self, input_shape): + if not isinstance(input_shape, (tuple, list)): + super(AITMTower, self).build(input_shape) + return + dim = self.project_dim if self.project_dim else int(input_shape[0][-1]) + for i in range(len(input_shape)): + 
self.queries.append(Dense(dim, name='query_%d' % i)) + self.keys.append(Dense(dim, name='key_%d' % i)) + self.values.append(Dense(dim, name='value_%d' % i)) + attn_cfg = seq_encoder_pb2.Attention() + attn_cfg.scale_by_dim = True + attn_params = Parameter.make_from_pb(attn_cfg) + self.attention = Attention(attn_params) + super(AITMTower, self).build(input_shape) + + def call(self, inputs, training=None, **kwargs): + if not isinstance(inputs, (tuple, list)): + return inputs + + queries = [] + keys = [] + values = [] + for i, tower in enumerate(inputs): + if i == 0: # current tower + queries.append(self.queries[i](tower)) + keys.append(self.keys[i](tower)) + values.append(self.values[i](tower)) + else: + dep = tf.stop_gradient(tower) if self.stop_gradient else tower + if self.transfer is not None: + dep = self.transfer(dep, training=training) + queries.append(self.queries[i](dep)) + keys.append(self.keys[i](dep)) + values.append(self.values[i](dep)) + query = tf.stack(queries, axis=1) + key = tf.stack(keys, axis=1) + value = tf.stack(values, axis=1) + attn = self.attention([query, value, key]) + return attn[:, 0, :] diff --git a/easy_rec/python/layers/keras/numerical_embedding.py b/easy_rec/python/layers/keras/numerical_embedding.py index 1b37bcf83..65cc77d52 100644 --- a/easy_rec/python/layers/keras/numerical_embedding.py +++ b/easy_rec/python/layers/keras/numerical_embedding.py @@ -5,10 +5,12 @@ import os import tensorflow as tf +from tensorflow.python.framework import ops from tensorflow.python.keras.layers import Layer from easy_rec.python.compat.array_ops import repeat from easy_rec.python.utils.activation import get_activation +from easy_rec.python.utils.tf_utils import get_ps_num_from_tf_config curr_dir, _ = os.path.split(__file__) parent_dir = os.path.dirname(curr_dir) @@ -36,11 +38,8 @@ (custom_op_path, str(ex))) custom_ops = None -if tf.__version__ >= '2.0': - tf = tf.compat.v1 - -class NLinear(object): +class NLinear(Layer): """N linear layers for N token (feature) embeddings. To understand this module, let's revise `tf.layers.dense`. When `tf.layers.dense` is @@ -74,8 +73,8 @@ def __init__(self, d_in, d_out, bias=True, - scope='nd_linear', - reuse=None): + name='nd_linear', + **kwargs): """Init with input shapes. 
    Args:
@@ -83,22 +82,21 @@ def __init__(self,
       d_in: the input dimension
       d_out: the output dimension
       bias: indicates if the underlying linear layers have biases
-      scope: variable scope name
-      reuse: whether to reuse variables
+      name: layer name
     """
-    with tf.variable_scope(scope, reuse=reuse):
-      self.weight = tf.get_variable(
-          'weights', [1, n_tokens, d_in, d_out], dtype=tf.float32)
-      if bias:
-        initializer = tf.constant_initializer(0.0)
-        self.bias = tf.get_variable(
-            'bias', [1, n_tokens, d_out],
-            dtype=tf.float32,
-            initializer=initializer)
-      else:
-        self.bias = None
+    super(NLinear, self).__init__(name=name, **kwargs)
+    self.weight = self.add_weight(
+        'weights', [1, n_tokens, d_in, d_out], dtype=tf.float32)
+    if bias:
+      initializer = tf.constant_initializer(0.0)
+      self.bias = self.add_weight(
+          'bias', [1, n_tokens, d_out],
+          dtype=tf.float32,
+          initializer=initializer)
+    else:
+      self.bias = None

-  def __call__(self, x, *args, **kwargs):
+  def call(self, x, **kwargs):
     if x.shape.ndims != 3:
       raise ValueError(
           'The input must have three dimensions (batch_size, n_tokens, d_embedding)'
       )
@@ -150,41 +148,46 @@ def __init__(self, params, name='periodic_embedding', reuse=None, **kwargs):
     self.output_tensor_list = params.get_or_default('output_tensor_list',
                                                     False)
     self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)

-  def call(self, inputs, **kwargs):
-    if inputs.shape.ndims != 2:
-      raise ValueError('inputs of PeriodicEmbedding must have 2 dimensions.')
-
-    num_features = int(inputs.shape[-1])
+  def build(self, input_shape):
+    if input_shape.ndims != 2:
+      raise ValueError('inputs of PeriodicEmbedding must have 2 dimensions.')
+    self.num_features = int(input_shape[-1])
+    num_ps = get_ps_num_from_tf_config()
+    partitioner = None
+    if num_ps > 0:
+      partitioner = tf.fixed_size_partitioner(num_shards=num_ps)
     emb_dim = self.embedding_dim // 2
-    with tf.variable_scope(self.name, reuse=self.reuse):
-      c = tf.get_variable(
-          'coefficients',
-          shape=[1, num_features, emb_dim],
-          initializer=self.initializer)
-
-      features = inputs[..., None]  # [B, N, 1]
-      v = 2 * math.pi * c * features  # [B, N, E]
-      emb = tf.concat([tf.sin(v), tf.cos(v)], axis=-1)  # [B, N, 2E]
-
-      dim = self.embedding_dim
-      if self.add_linear_layer:
-        linear = NLinear(
-            num_features,
-            dim,
-            dim,
-            scope='%s_nd_linear' % self.name,
-            reuse=self.reuse)
-        emb = linear(emb)
-        act = get_activation(self.linear_activation)
-        if callable(act):
-          emb = act(emb)
-      output = tf.reshape(emb, [-1, num_features * dim])
+    self.coef = self.add_weight(
+        'coefficients',
+        shape=[1, self.num_features, emb_dim],
+        partitioner=partitioner,
+        initializer=self.initializer)
+    if self.add_linear_layer:
+      self.linear = NLinear(
+          self.num_features,
+          self.embedding_dim,
+          self.embedding_dim,
+          name='nd_linear')
+    super(PeriodicEmbedding, self).build(input_shape)

-      if self.output_tensor_list:
-        return output, tf.unstack(emb, axis=1)
-      if self.output_3d_tensor:
-        return output, emb
-      return output
+  def call(self, inputs, **kwargs):
+    features = inputs[..., None]  # [B, N, 1]
+    v = 2 * math.pi * self.coef * features  # [B, N, E]
+    emb = tf.concat([tf.sin(v), tf.cos(v)], axis=-1)  # [B, N, 2E]
+
+    dim = self.embedding_dim
+    if self.add_linear_layer:
+      emb = self.linear(emb)
+      act = get_activation(self.linear_activation)
+      if callable(act):
+        emb = act(emb)
+    output = tf.reshape(emb, [-1, self.num_features * dim])
+
+    if self.output_tensor_list:
+      return output, tf.unstack(emb, axis=1)
+    if self.output_3d_tensor:
+      return output, emb
+    return output
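Reviewer note: the forward pass that `build`/`call` now split up is easy to sanity-check in isolation. A standalone sketch of the periodic-embedding math (random stand-ins for the learned weights; the variable names here are illustrative, not part of the diff):

```python
import math

import tensorflow as tf

# PeriodicEmbedding maps each scalar feature x to
# [sin(2*pi*c*x), cos(2*pi*c*x)] with learned coefficients c.
B, N, E = 4, 3, 8                    # batch, num features, embedding_dim // 2
inputs = tf.random.uniform([B, N])   # raw numeric features
coef = tf.random.normal([1, N, E])   # stand-in for the 'coefficients' weight

v = 2 * math.pi * coef * inputs[..., None]         # [B, N, E]
emb = tf.concat([tf.sin(v), tf.cos(v)], axis=-1)   # [B, N, 2E]
output = tf.reshape(emb, [-1, N * 2 * E])          # [B, N*2E], the flat output
```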
class AutoDisEmbedding(Layer): @@ -204,42 +207,51 @@ def __init__(self, params, name='auto_dis_embedding', reuse=None, **kwargs): self.output_tensor_list = params.get_or_default('output_tensor_list', False) self.output_3d_tensor = params.get_or_default('output_3d_tensor', False) - def call(self, inputs, **kwargs): - if inputs.shape.ndims != 2: + def build(self, input_shape): + if input_shape.ndims != 2: raise ValueError('inputs of AutoDisEmbedding must have 2 dimensions.') + self.num_features = int(input_shape[-1]) + num_ps = get_ps_num_from_tf_config() + partitioner = None + if num_ps > 0: + partitioner = tf.fixed_size_partitioner(num_shards=num_ps) + self.meta_emb = self.add_weight( + 'meta_embedding', + shape=[self.num_features, self.num_bins, self.emb_dim], + partitioner=partitioner) + self.proj_w = self.add_weight( + 'project_w', + shape=[1, self.num_features, self.num_bins], + partitioner=partitioner) + self.proj_mat = self.add_weight( + 'project_mat', + shape=[self.num_features, self.num_bins, self.num_bins], + partitioner=partitioner) + super(AutoDisEmbedding, self).build(input_shape) - num_features = int(inputs.shape[-1]) - with tf.variable_scope(self.name, reuse=self.reuse): - meta_emb = tf.get_variable( - 'meta_embedding', shape=[num_features, self.num_bins, self.emb_dim]) - w = tf.get_variable('project_w', shape=[1, num_features, self.num_bins]) - mat = tf.get_variable( - 'project_mat', shape=[num_features, self.num_bins, self.num_bins]) - - x = tf.expand_dims(inputs, axis=-1) # [B, N, 1] - hidden = tf.nn.leaky_relu(w * x) # [B, N, num_bin] - # 低版本的tf(1.12) matmul 不支持广播,所以改成 einsum - # y = tf.matmul(mat, hidden[..., None]) # [B, N, num_bin, 1] - # y = tf.squeeze(y, axis=3) # [B, N, num_bin] - y = tf.einsum('nik,bnk->bni', mat, hidden) # [B, N, num_bin] - - # keep_prob(float): if dropout_flag is True, keep_prob rate to keep connect - alpha = self.keep_prob - x_bar = y + alpha * hidden # [B, N, num_bin] - x_hat = tf.nn.softmax(x_bar / self.temperature) # [B, N, num_bin] - - # emb = tf.matmul(x_hat[:, :, None, :], meta_emb) # [B, N, 1, D] - # emb = tf.squeeze(emb, axis=2) # [B, N, D] - emb = tf.einsum('bnk,nkd->bnd', x_hat, meta_emb) - - output = tf.reshape(emb, [-1, self.emb_dim * num_features]) # [B, N*D] - - if self.output_tensor_list: - return output, tf.unstack(emb, axis=1) - - if self.output_3d_tensor: - return output, emb - return output + def call(self, inputs, **kwargs): + x = tf.expand_dims(inputs, axis=-1) # [B, N, 1] + hidden = tf.nn.leaky_relu(self.proj_w * x) # [B, N, num_bin] + # 低版本的tf(1.12) matmul 不支持广播,所以改成 einsum + # y = tf.matmul(mat, hidden[..., None]) # [B, N, num_bin, 1] + # y = tf.squeeze(y, axis=3) # [B, N, num_bin] + y = tf.einsum('nik,bnk->bni', self.proj_mat, hidden) # [B, N, num_bin] + + # keep_prob(float): if dropout_flag is True, keep_prob rate to keep connect + alpha = self.keep_prob + x_bar = y + alpha * hidden # [B, N, num_bin] + x_hat = tf.nn.softmax(x_bar / self.temperature) # [B, N, num_bin] + + # emb = tf.matmul(x_hat[:, :, None, :], meta_emb) # [B, N, 1, D] + # emb = tf.squeeze(emb, axis=2) # [B, N, D] + emb = tf.einsum('bnk,nkd->bnd', x_hat, self.meta_emb) + output = tf.reshape(emb, [-1, self.emb_dim * self.num_features]) # [B, N*D] + + if self.output_tensor_list: + return output, tf.unstack(emb, axis=1) + if self.output_3d_tensor: + return output, emb + return output class NaryDisEmbedding(Layer): @@ -255,19 +267,19 @@ def __init__(self, params, name='nary_dis_embedding', reuse=None, **kwargs): params.check_required(['embedding_dim', 
'carries']) self.emb_dim = int(params.embedding_dim) self.carries = params.get_or_default('carries', [2, 9]) + self.num_replicas = params.get_or_default('num_replicas', 1) + assert self.num_replicas >= 1, 'num replicas must be >= 1' self.lengths = list(map(self.max_length, self.carries)) self.vocab_size = int(sum(self.lengths)) self.multiplier = params.get_or_default('multiplier', 1.0) self.intra_ary_pooling = params.get_or_default('intra_ary_pooling', 'sum') - self.inter_ary_pooling = params.get_or_default('inter_ary_pooling', - 'concat') self.output_3d_tensor = params.get_or_default('output_3d_tensor', False) + self.output_tensor_list = params.get_or_default('output_tensor_list', False) logging.info( - '{} carries: {}, lengths: {}, vocab_size: {}, intra_ary: {}, inter_ary: {}, multiplier: {}' + '{} carries: {}, lengths: {}, vocab_size: {}, intra_ary: {}, replicas: {}, multiplier: {}' .format(self.name, ','.join(map(str, self.carries)), ','.join(map(str, self.lengths)), self.vocab_size, - self.intra_ary_pooling, self.inter_ary_pooling, - self.multiplier)) + self.intra_ary_pooling, self.num_replicas, self.multiplier)) @staticmethod def max_length(carry): @@ -280,12 +292,13 @@ def build(self, input_shape): self.num_features = int(input_shape[-1]) logging.info('%s has %d input features', self.name, self.num_features) vocab_size = self.num_features * self.vocab_size + emb_dim = self.emb_dim * self.num_replicas + num_ps = get_ps_num_from_tf_config() + partitioner = None + if num_ps > 0: + partitioner = tf.fixed_size_partitioner(num_shards=num_ps) self.embedding_table = self.add_weight( - 'embed_table', - shape=[vocab_size, self.emb_dim], - initializer='he_uniform', - dtype=tf.float32, - trainable=True) + 'embed_table', shape=[vocab_size, emb_dim], partitioner=partitioner) super(NaryDisEmbedding, self).build(input_shape) def call(self, inputs, **kwargs): @@ -295,11 +308,12 @@ def call(self, inputs, **kwargs): inputs *= self.multiplier inputs = tf.to_int32(inputs) offset, emb_indices, emb_splits = 0, [], [] - for carry, length in zip(self.carries, self.lengths): - values, splits = self.nary_carry(inputs, carry=carry, offset=offset) - offset += length - emb_indices.append(values) - emb_splits.append(splits) + with ops.device('/CPU:0'): + for carry, length in zip(self.carries, self.lengths): + values, splits = self.nary_carry(inputs, carry=carry, offset=offset) + offset += length + emb_indices.append(values) + emb_splits.append(splits) indices = tf.concat(emb_indices, axis=0) splits = tf.concat(emb_splits, axis=0) # embedding shape: [B*N*C, D] @@ -307,28 +321,56 @@ def call(self, inputs, **kwargs): total_length = tf.size(splits) if self.intra_ary_pooling == 'sum': - segment_ids = repeat(tf.range(total_length), repeats=splits) + if tf.__version__ >= '2.0': + segment_ids = tf.repeat(tf.range(total_length), repeats=splits) + else: + segment_ids = repeat(tf.range(total_length), repeats=splits) embedding = tf.math.segment_sum(embedding, segment_ids) elif self.intra_ary_pooling == 'mean': - segment_ids = repeat(tf.range(total_length), repeats=splits) + if tf.__version__ >= '2.0': + segment_ids = tf.repeat(tf.range(total_length), repeats=splits) + else: + segment_ids = repeat(tf.range(total_length), repeats=splits) embedding = tf.math.segment_mean(embedding, segment_ids) else: raise ValueError('Unsupported intra ary pooling method %s' % self.intra_ary_pooling) + # B: batch size + # N: num features + # C: num carries + # D: embedding dimension + # R: num replicas + # shape of embedding: [B*N*C, R*D] + N = 
self.num_features + C = len(self.carries) + D = self.emb_dim + if self.num_replicas == 1: + embedding = tf.reshape(embedding, [C, -1, D]) # [C, B*N, D] + embedding = tf.transpose(embedding, perm=[1, 0, 2]) # [B*N, C, D] + embedding = tf.reshape(embedding, [-1, C * D]) # [B*N, C*D] + output = tf.reshape(embedding, [-1, N * C * D]) # [B, N*C*D] + if self.output_tensor_list: + return output, tf.split(embedding, N) # [B, C*D] * N + if self.output_3d_tensor: + embedding = tf.reshape(embedding, [-1, N, C * D]) # [B, N, C*D] + return output, embedding + return output - if self.inter_ary_pooling == 'concat': - embeddings = tf.split(embedding, len(self.carries)) - embedding = tf.concat(embeddings, axis=-1) # [B*N, C*D] - else: - raise ValueError('Unsupported inter ary pooling method %s' % - self.inter_ary_pooling) - if self.output_3d_tensor: - embedding = tf.reshape( - embedding, [-1, self.num_features, - len(self.carries) * self.emb_dim]) # [B, N, C*D] - else: - embedding = tf.reshape( - embedding, [-1, self.num_features * len(self.carries) * self.emb_dim - ]) # [B, N*C*D] - print('NaryDisEmbedding:', embedding) - return embedding + # self.num_replicas > 1: + replicas = tf.split(embedding, self.num_replicas, axis=1) + outputs = [] + outputs2 = [] + for replica in replicas: + # shape of replica: [B*N*C, D] + embedding = tf.reshape(replica, [C, -1, D]) # [C, B*N, D] + embedding = tf.transpose(embedding, perm=[1, 0, 2]) # [B*N, C, D] + embedding = tf.reshape(embedding, [-1, C * D]) # [B*N, C*D] + output = tf.reshape(embedding, [-1, N * C * D]) # [B, N*C*D] + outputs.append(output) + if self.output_tensor_list: + embedding = tf.split(embedding, N) # [B, C*D] * N + outputs2.append(embedding) + elif self.output_3d_tensor: + embedding = tf.reshape(embedding, [-1, N, C * D]) # [B, N, C*D] + outputs2.append(embedding) + return outputs + outputs2 diff --git a/easy_rec/python/layers/keras/ppnet.py b/easy_rec/python/layers/keras/ppnet.py index 71f5902d1..431034924 100644 --- a/easy_rec/python/layers/keras/ppnet.py +++ b/easy_rec/python/layers/keras/ppnet.py @@ -32,20 +32,18 @@ def __init__(self, dense = tf.keras.layers.Dense( units=hidden_dim, use_bias=not do_batch_norm, - kernel_initializer=initializer, - name=name) + kernel_initializer=initializer) self._sub_layers.append(dense) if do_batch_norm: - bn = tf.keras.layers.BatchNormalization( - name='%s/bn' % name, trainable=True) + bn = tf.keras.layers.BatchNormalization(trainable=True) self._sub_layers.append(bn) act_layer = activation_layer(activation) self._sub_layers.append(act_layer) if 0.0 < dropout_rate < 1.0: - dropout = tf.keras.layers.Dropout(dropout_rate, name='%s/dropout' % name) + dropout = tf.keras.layers.Dropout(dropout_rate) self._sub_layers.append(dropout) elif dropout_rate >= 1.0: raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate) @@ -55,7 +53,7 @@ def __init__(self, activation='sigmoid', use_bias=not do_batch_norm, kernel_initializer=initializer, - name=name) + name='weight') self._sub_layers.append(dense) self._sub_layers.append(lambda x: x * 2) @@ -65,7 +63,7 @@ def call(self, x, training=None, **kwargs): cls = layer.__class__.__name__ if cls in ('Dropout', 'BatchNormalization', 'Dice'): x = layer(x, training=training) - if cls in ('BatchNormalization', 'Dice'): + if cls in ('BatchNormalization', 'Dice') and training: add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS) else: x = layer(x) @@ -143,14 +141,14 @@ def add_rich_layer(self, use_bn_after_activation, name, l2_reg=None): - act_layer = 
activation_layer(activation) + act_layer = activation_layer(activation, name='%s/act' % name) if use_bn and not use_bn_after_activation: dense = tf.keras.layers.Dense( units=num_units, use_bias=use_bias, kernel_initializer=initializer, kernel_regularizer=l2_reg, - name=name) + name='%s/dense' % name) self._sub_layers.append(dense) bn = tf.keras.layers.BatchNormalization( name='%s/bn' % name, trainable=True) @@ -162,7 +160,7 @@ def add_rich_layer(self, use_bias=use_bias, kernel_initializer=initializer, kernel_regularizer=l2_reg, - name=name) + name='%s/dense' % name) self._sub_layers.append(dense) self._sub_layers.append(act_layer) if use_bn and use_bn_after_activation: @@ -189,7 +187,7 @@ def call(self, inputs, training=None, **kwargs): x *= gate elif cls in ('Dropout', 'BatchNormalization', 'Dice'): x = layer(x, training=training) - if cls in ('BatchNormalization', 'Dice'): + if cls in ('BatchNormalization', 'Dice') and training: add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS) else: x = layer(x) diff --git a/easy_rec/python/model/easy_rec_estimator.py b/easy_rec/python/model/easy_rec_estimator.py index c40260b18..95385936c 100644 --- a/easy_rec/python/model/easy_rec_estimator.py +++ b/easy_rec/python/model/easy_rec_estimator.py @@ -153,6 +153,7 @@ def saver_cls(self): return tmp_saver_cls def _train_model_fn(self, features, labels, run_config): + tf.keras.backend.set_learning_phase(1) model = self._model_cls( self.model_config, self.feature_configs, @@ -471,6 +472,7 @@ def _train_model_fn(self, features, labels, run_config): training_hooks=hooks) def _eval_model_fn(self, features, labels, run_config): + tf.keras.backend.set_learning_phase(0) start = time.time() model = self._model_cls( self.model_config, @@ -510,6 +512,7 @@ def _eval_model_fn(self, features, labels, run_config): eval_metric_ops=metric_dict) def _distribute_eval_model_fn(self, features, labels, run_config): + tf.keras.backend.set_learning_phase(0) start = time.time() model = self._model_cls( self.model_config, @@ -562,6 +565,7 @@ def _distribute_eval_model_fn(self, features, labels, run_config): scaffold=scaffold) def _export_model_fn(self, features, labels, run_config, params): + tf.keras.backend.set_learning_phase(0) model = self._model_cls( self.model_config, self.feature_configs, diff --git a/easy_rec/python/model/multi_task_model.py b/easy_rec/python/model/multi_task_model.py index d792d3131..aa102104c 100644 --- a/easy_rec/python/model/multi_task_model.py +++ b/easy_rec/python/model/multi_task_model.py @@ -4,14 +4,10 @@ from collections import OrderedDict import tensorflow as tf -from tensorflow.python.keras.layers import Dense from easy_rec.python.builders import loss_builder from easy_rec.python.layers.dnn import DNN -from easy_rec.python.layers.keras.attention import Attention -from easy_rec.python.layers.utils import Parameter from easy_rec.python.model.rank_model import RankModel -from easy_rec.python.protos import seq_encoder_pb2 from easy_rec.python.protos import tower_pb2 from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel from easy_rec.python.protos.loss_pb2 import LossType @@ -59,65 +55,45 @@ def build_predict_graph(self): tower_features = {} for i, task_tower_cfg in enumerate(config.task_towers): tower_name = task_tower_cfg.tower_name - if task_tower_cfg.HasField('dnn'): - tower_dnn = DNN( - task_tower_cfg.dnn, - self._l2_reg, - name=tower_name, - is_training=self._is_training) - tower_output = tower_dnn(task_input_list[i]) - else: - tower_output = task_input_list[i] - 
tower_features[tower_name] = tower_output + with tf.name_scope(tower_name): + if task_tower_cfg.HasField('dnn'): + tower_dnn = DNN( + task_tower_cfg.dnn, + self._l2_reg, + name=tower_name, + is_training=self._is_training) + tower_output = tower_dnn(task_input_list[i]) + else: + tower_output = task_input_list[i] + tower_features[tower_name] = tower_output tower_outputs = {} relation_features = {} # bayes network for task_tower_cfg in config.task_towers: tower_name = task_tower_cfg.tower_name - if task_tower_cfg.HasField('relation_dnn'): - relation_dnn = DNN( - task_tower_cfg.relation_dnn, - self._l2_reg, - name=tower_name + '/relation_dnn', - is_training=self._is_training) - tower_inputs = [tower_features[tower_name]] - for relation_tower_name in task_tower_cfg.relation_tower_names: - tower_inputs.append(relation_features[relation_tower_name]) - relation_input = tf.concat( - tower_inputs, axis=-1, name=tower_name + '/relation_input') - relation_fea = relation_dnn(relation_input) - relation_features[tower_name] = relation_fea - elif task_tower_cfg.use_ait_module: - tower_inputs = [tower_features[tower_name]] - for relation_tower_name in task_tower_cfg.relation_tower_names: - tower_inputs.append(relation_features[relation_tower_name]) - if len(tower_inputs) == 1: - relation_fea = tower_inputs[0] + with tf.name_scope(tower_name): + if task_tower_cfg.HasField('relation_dnn'): + relation_dnn = DNN( + task_tower_cfg.relation_dnn, + self._l2_reg, + name=tower_name + '/relation_dnn', + is_training=self._is_training) + tower_inputs = [tower_features[tower_name]] + for relation_tower_name in task_tower_cfg.relation_tower_names: + tower_inputs.append(relation_features[relation_tower_name]) + relation_input = tf.concat( + tower_inputs, axis=-1, name=tower_name + '/relation_input') + relation_fea = relation_dnn(relation_input) relation_features[tower_name] = relation_fea else: - if task_tower_cfg.HasField('ait_project_dim'): - dim = task_tower_cfg.ait_project_dim - else: - dim = int(tower_inputs[0].shape[-1]) - queries = tf.stack([Dense(dim)(x) for x in tower_inputs], axis=1) - keys = tf.stack([Dense(dim)(x) for x in tower_inputs], axis=1) - values = tf.stack([Dense(dim)(x) for x in tower_inputs], axis=1) - attn_cfg = seq_encoder_pb2.Attention() - attn_cfg.use_scale = True - params = Parameter.make_from_pb(attn_cfg) - attention_layer = Attention(params, name='AITM_%s' % tower_name) - result = attention_layer([queries, values, keys]) - relation_fea = result[:, 0, :] - relation_features[tower_name] = relation_fea - else: - relation_fea = tower_features[tower_name] + relation_fea = tower_features[tower_name] - output_logits = tf.layers.dense( - relation_fea, - task_tower_cfg.num_class, - kernel_regularizer=self._l2_reg, - name=tower_name + '/output') + output_logits = tf.layers.dense( + relation_fea, + task_tower_cfg.num_class, + kernel_regularizer=self._l2_reg, + name=tower_name + '/output') tower_outputs[tower_name] = output_logits self._add_to_prediction_dict(tower_outputs) diff --git a/easy_rec/python/model/rank_model.py b/easy_rec/python/model/rank_model.py index fc8e5214c..a144b999a 100644 --- a/easy_rec/python/model/rank_model.py +++ b/easy_rec/python/model/rank_model.py @@ -218,58 +218,59 @@ def _build_loss_impl(self, def build_loss_graph(self): loss_dict = {} - if len(self._losses) == 0: - loss_dict = self._build_loss_impl( - self._loss_type, - label_name=self._label_name, - loss_weight=self._sample_weight, - num_class=self._num_class) - else: - strategy = 
self._base_model_config.loss_weight_strategy - loss_weight = [1.0] - if strategy == self._base_model_config.Random and len(self._losses) > 1: - weights = tf.random_normal([len(self._losses)]) - loss_weight = tf.nn.softmax(weights) - for i, loss in enumerate(self._losses): - loss_param = loss.WhichOneof('loss_param') - if loss_param is not None: - loss_param = getattr(loss, loss_param) - loss_ops = self._build_loss_impl( - loss.loss_type, + with tf.name_scope('loss'): + if len(self._losses) == 0: + loss_dict = self._build_loss_impl( + self._loss_type, label_name=self._label_name, loss_weight=self._sample_weight, - num_class=self._num_class, - loss_name=loss.loss_name, - loss_param=loss_param) - for loss_name, loss_value in loss_ops.items(): - if strategy == self._base_model_config.Fixed: - loss_dict[loss_name] = loss_value * loss.weight - elif strategy == self._base_model_config.Uncertainty: - if loss.learn_loss_weight: - uncertainty = tf.Variable( - 0, name='%s_loss_weight' % loss_name, dtype=tf.float32) - tf.summary.scalar('loss/%s_uncertainty' % loss_name, uncertainty) - if loss.loss_type in {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}: - loss_dict[loss_name] = 0.5 * tf.exp( - -uncertainty) * loss_value + 0.5 * uncertainty + num_class=self._num_class) + else: + strategy = self._base_model_config.loss_weight_strategy + loss_weight = [1.0] + if strategy == self._base_model_config.Random and len(self._losses) > 1: + weights = tf.random_normal([len(self._losses)]) + loss_weight = tf.nn.softmax(weights) + for i, loss in enumerate(self._losses): + loss_param = loss.WhichOneof('loss_param') + if loss_param is not None: + loss_param = getattr(loss, loss_param) + loss_ops = self._build_loss_impl( + loss.loss_type, + label_name=self._label_name, + loss_weight=self._sample_weight, + num_class=self._num_class, + loss_name=loss.loss_name, + loss_param=loss_param) + for loss_name, loss_value in loss_ops.items(): + if strategy == self._base_model_config.Fixed: + loss_dict[loss_name] = loss_value * loss.weight + elif strategy == self._base_model_config.Uncertainty: + if loss.learn_loss_weight: + uncertainty = tf.Variable( + 0, name='%s_loss_weight' % loss_name, dtype=tf.float32) + tf.summary.scalar('%s_uncertainty' % loss_name, uncertainty) + if loss.loss_type in { + LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS + }: + loss_dict[loss_name] = 0.5 * tf.exp( + -uncertainty) * loss_value + 0.5 * uncertainty + else: + loss_dict[loss_name] = tf.exp( + -uncertainty) * loss_value + 0.5 * uncertainty else: - loss_dict[loss_name] = tf.exp( - -uncertainty) * loss_value + 0.5 * uncertainty + loss_dict[loss_name] = loss_value * loss.weight + elif strategy == self._base_model_config.Random: + loss_dict[loss_name] = loss_value * loss_weight[i] else: - loss_dict[loss_name] = loss_value * loss.weight - elif strategy == self._base_model_config.Random: - loss_dict[loss_name] = loss_value * loss_weight[i] - else: - raise ValueError('Unsupported loss weight strategy: ' + - strategy.Name) - - self._loss_dict.update(loss_dict) - - # build kd loss - kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict, - self._labels, self._feature_dict) - self._loss_dict.update(kd_loss_dict) - + raise ValueError('Unsupported loss weight strategy: ' + + strategy.Name) + self._loss_dict.update(loss_dict) + # build kd loss + kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict, + self._labels, + self._feature_dict) + self._loss_dict.update(kd_loss_dict) return self._loss_dict def _build_metric_impl(self, diff 
--git a/easy_rec/python/ops/1.12/libcustom_ops.so b/easy_rec/python/ops/1.12/libcustom_ops.so index 9014d3c9b..6d094f598 100755 Binary files a/easy_rec/python/ops/1.12/libcustom_ops.so and b/easy_rec/python/ops/1.12/libcustom_ops.so differ diff --git a/easy_rec/python/ops/1.12_pai/libcustom_ops.so b/easy_rec/python/ops/1.12_pai/libcustom_ops.so index 2729628bf..2676f0a24 100755 Binary files a/easy_rec/python/ops/1.12_pai/libcustom_ops.so and b/easy_rec/python/ops/1.12_pai/libcustom_ops.so differ diff --git a/easy_rec/python/ops/1.15/libcustom_ops.so b/easy_rec/python/ops/1.15/libcustom_ops.so index 6fac5578f..5023cfe47 100755 Binary files a/easy_rec/python/ops/1.15/libcustom_ops.so and b/easy_rec/python/ops/1.15/libcustom_ops.so differ diff --git a/easy_rec/python/protos/keras_layer.proto b/easy_rec/python/protos/keras_layer.proto index 8a12207a4..04ece4eb6 100644 --- a/easy_rec/python/protos/keras_layer.proto +++ b/easy_rec/python/protos/keras_layer.proto @@ -34,5 +34,7 @@ message KerasLayer { MultiHeadAttention multi_head_attention = 23; Transformer transformer = 24; TextEncoder text_encoder = 25; + WeightedGate gate = 26; + AITMTower aitm = 27; } } diff --git a/easy_rec/python/protos/layer.proto b/easy_rec/python/protos/layer.proto index a76dfce34..a0438f071 100644 --- a/easy_rec/python/protos/layer.proto +++ b/easy_rec/python/protos/layer.proto @@ -35,8 +35,11 @@ message NaryDisEmbedding { repeated uint32 carries = 2; optional float multiplier = 3 [default = 1.0]; optional string intra_ary_pooling = 4 [default = 'sum']; + // for now, inter_ary_pooling not support yet optional string inter_ary_pooling = 5 [default = 'concat']; optional bool output_3d_tensor = 6 [default = false]; + optional bool output_tensor_list = 7; + optional uint32 num_replicas = 8 [default = 1]; } message SENet { @@ -82,6 +85,13 @@ message MMoELayer { optional uint32 num_expert = 3; } +// used in CDN model +message WeightedGate { + optional uint32 weight_index = 1 [default = 0]; + optional MLP mlp = 2; +} + +// used in PPNet message GateNN { optional uint32 output_dim = 1; optional uint32 hidden_dim = 2; @@ -128,3 +138,9 @@ message MappedDotProduct { optional int32 print_first_n = 6 [default = 0]; optional int32 summarize = 7; } + +message AITMTower { + optional uint32 project_dim = 1; + optional MLP transfer_mlp = 2; + optional bool stop_gradient = 3 [default = true]; +} diff --git a/easy_rec/python/protos/tower.proto b/easy_rec/python/protos/tower.proto index fcaedafcd..aa5622e20 100644 --- a/easy_rec/python/protos/tower.proto +++ b/easy_rec/python/protos/tower.proto @@ -76,12 +76,8 @@ message BayesTaskTower { repeated Loss losses = 15; // whether to use sample weight in this tower required bool use_sample_weight = 16 [default = true]; - // whether to use AIT module - optional bool use_ait_module = 17 [default = false]; - // set this when the dimensions of last layer of towers are not equal - optional uint32 ait_project_dim = 18; // field name for indicating the sample space for this task - optional string task_space_indicator_name = 19; + optional string task_space_indicator_name = 17; // field value for indicating the sample space for this task - optional string task_space_indicator_value = 20; + optional string task_space_indicator_value = 18; }; diff --git a/easy_rec/python/utils/tf_utils.py b/easy_rec/python/utils/tf_utils.py index 59258b9dc..24f47a94a 100644 --- a/easy_rec/python/utils/tf_utils.py +++ b/easy_rec/python/utils/tf_utils.py @@ -1,6 +1,9 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. 
and its affiliates. """Common functions used for odps input.""" +import json +import os + import tensorflow as tf from easy_rec.python.protos.dataset_pb2 import DatasetConfig @@ -9,6 +12,16 @@ tf = tf.compat.v1 +def get_ps_num_from_tf_config(): + tf_config = os.environ.get('TF_CONFIG') + if tf_config: + tf_config_json = json.loads(tf_config) + cluster = tf_config_json.get('cluster', {}) + ps_hosts = cluster.get('ps', []) + return len(ps_hosts) + return 0 + + def get_tf_type(field_type): type_map = { DatasetConfig.INT32: tf.int32, diff --git a/easy_rec/version.py b/easy_rec/version.py index f6722ee1a..7da645311 100644 --- a/easy_rec/version.py +++ b/easy_rec/version.py @@ -1,4 +1,4 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. -__version__ = '0.8.3' +__version__ = '0.8.4' diff --git a/examples/configs/dlrm_on_criteo_with_narydis.config b/examples/configs/dlrm_on_criteo_with_narydis.config index 469734c23..121fd98c2 100644 --- a/examples/configs/dlrm_on_criteo_with_narydis.config +++ b/examples/configs/dlrm_on_criteo_with_narydis.config @@ -534,6 +534,7 @@ model_config: { embedding_dim: 8 carries: [2, 9] multiplier: 1e6 + output_tensor_list: true } } } diff --git a/samples/model_config/aitm_on_taobao.config b/samples/model_config/aitm_on_taobao.config index c67f1d677..9131a41a7 100644 --- a/samples/model_config/aitm_on_taobao.config +++ b/samples/model_config/aitm_on_taobao.config @@ -243,7 +243,7 @@ model_config { } backbone { blocks { - name: "mlp" + name: "share_bottom" inputs { feature_group_name: "all" } @@ -254,6 +254,49 @@ model_config { } } } + blocks { + name: "ctr_tower" + inputs { + block_name: "share_bottom" + } + keras_layer { + class_name: 'MLP' + mlp { + hidden_units: 128 + } + } + } + blocks { + name: "cvr_tower" + inputs { + block_name: "share_bottom" + } + keras_layer { + class_name: 'MLP' + mlp { + hidden_units: 128 + } + } + } + blocks { + name: "cvr_aitm" + inputs { + block_name: "cvr_tower" + } + inputs { + block_name: "ctr_tower" + } + merge_inputs_into_list: true + keras_layer { + class_name: "AITMTower" + aitm { + transfer_mlp { + hidden_units: 128 + } + } + } + } + output_blocks: ["ctr_tower", "cvr_aitm"] } model_params { task_towers { @@ -264,9 +307,8 @@ model_config { auc {} } dnn { - hidden_units: [256, 128] + hidden_units: 64 } - use_ait_module: true weight: 1.0 } task_towers { @@ -282,11 +324,8 @@ model_config { auc {} } dnn { - hidden_units: [256, 128] + hidden_units: 64 } - relation_tower_names: ["ctr"] - use_ait_module: true - ait_project_dim: 128 weight: 1.0 } l2_regularization: 1e-6
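Reviewer note: the new `get_ps_num_from_tf_config` helper only inspects the standard `TF_CONFIG` environment variable, so its behavior is easy to verify locally. A minimal sketch with a made-up cluster spec:

```python
import json
import os

from easy_rec.python.utils.tf_utils import get_ps_num_from_tf_config

# Hypothetical TF_CONFIG with two parameter servers; the helper counts the
# entries under cluster['ps'] and returns 0 when TF_CONFIG is unset.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ['host0:2000'],
        'worker': ['host1:2000'],
        'ps': ['host2:2000', 'host3:2000'],
    },
    'task': {'type': 'worker', 'index': 0},
})
assert get_ps_num_from_tf_config() == 2
```

The returned count is what `PeriodicEmbedding`, `AutoDisEmbedding`, and `NaryDisEmbedding` use above to size `tf.fixed_size_partitioner` for their weights, so embedding tables get sharded across parameter servers only when PS hosts are actually configured.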