Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug of text cnn, add TextCNN component and model #462

Merged
merged 2 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/component/backbone.md
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ DBMTL模型需要在`model_params`里为每个子任务的Tower配置`relation_d
```protobuf
model_config: {
model_name: 'MaskNet + PPNet + MMoE'
model_class: 'RankModel'
model_class: "MultiTaskModel"
feature_groups: {
group_name: 'memorize'
feature_names: 'user_id'
Expand All @@ -971,6 +971,7 @@ model_config: {
name: "mask_net"
inputs {
feature_group_name: "general"
input_fn: "lambda x: [x, x]"
}
repeat {
num_repeat: 3
Expand Down Expand Up @@ -1104,6 +1105,7 @@ MovieLens-1M数据集效果:
| Gate | 门控 | 多个输入的加权求和 | [Cross Decoupling Network](../models/cdn.html#id2) |
| PeriodicEmbedding | 周期激活函数 | 数值特征Embedding | [案例5](#dlrm-embedding) |
| AutoDisEmbedding | 自动离散化 | 数值特征Embedding | [dlrm_on_criteo_with_autodis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_autodis.config) |
| TextCNN | 文本卷积 | 提取文本序列的特征 | [text_cnn_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/text_cnn_on_movielens.config) |

**备注**:Gate组件的第一个输入是权重向量,后面的输入拼凑成一个列表,权重向量的长度应等于列表的长度

Expand Down
12 changes: 12 additions & 0 deletions docs/source/component/component.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@
| output_tensor_list | bool | false | 是否同时输出embedding列表 |
| output_3d_tensor | bool | false | 是否同时输出3d tensor, `output_tensor_list=true`时该参数不生效 |

- TextCNN

| 参数 | 类型 | 默认值 | 说明 |
| ------------------- | ------------ | ---- | ---------------- |
| num_filters | list<uint32> | | 卷积核个数列表 |
| filter_sizes | list<uint32> | | 卷积核步长列表 |
| activation | string | relu | 卷积操作的激活函数 |
| pad_sequence_length | uint32 | | 序列补齐或截断的长度 |
| mlp | MLP | | protobuf message |

备注:pad_sequence_length 参数必须要配置,否则模型predict的分数可能不稳定

## 2.特征交叉组件

- Bilinear
Expand Down
4 changes: 4 additions & 0 deletions docs/source/feature/feature.rst
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,16 @@ Sequence类特征格式一般为“XX\|XX\|XX”,如用户行为序列特征
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 128
activation: 'relu'
}
}
}

- num_filters: 卷积核个数列表
- filter_sizes: 卷积核步长列表
- pad_sequence_length: 序列补齐或截断的长度
- activation: 卷积操作的激活函数,默认为relu

TextCNN网络是2014年提出的用来做文本分类的卷积神经网络,由于其结构简单、效果好,在文本分类、推荐等NLP领域应用广泛。
从直观上理解,TextCNN通过一维卷积来获取句子中`N gram`的特征表示。
Expand Down
79 changes: 79 additions & 0 deletions docs/source/models/text_cnn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# TextCNN

### 简介

TextCNN网络是2014年提出的用来做文本分类的卷积神经网络,由于其结构简单、效果好,在文本分类、推荐等NLP领域应用广泛。
从直观上理解,TextCNN通过一维卷积来获取句子中`N gram`的特征表示。

在推荐模型中,可以用TextCNN网络来提取文本类型的特征。

### 配置说明

```protobuf
model_config: {
model_name: 'TextCNN'
model_class: 'RankModel'
feature_groups: {
group_name: 'text_seq'
feature_names: 'title'
wide_deep: DEEP
}
backbone {
blocks {
name: 'text_seq'
inputs {
feature_group_name: 'text_seq'
}
input_layer {
output_seq_and_normal_feature: true
}
}
blocks {
name: 'textcnn'
inputs {
block_name: 'text_seq'
}
keras_layer {
class_name: 'TextCNN'
text_cnn {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
mlp {
hidden_units: [256, 128, 64]
}
}
}
}
}
model_params {
l2_regularization: 1e-6
}
embedding_regularization: 1e-6
}
```

- model_name: 任意自定义字符串,仅有注释作用
- model_class: 'RankModel', 不需要修改, 通过组件化方式搭建的单目标排序模型都叫这个名字
- feature_groups: 配置一组特征。
- backbone: 通过组件化的方式搭建的主干网络,[参考文档](../component/backbone.md)
- blocks: 由多个`组件块`组成的一个有向无环图(DAG),框架负责按照DAG的拓扑排序执行个`组件块`关联的代码逻辑,构建TF Graph的一个子图
- name/inputs: 每个`block`有一个唯一的名字(name),并且有一个或多个输入(inputs)和输出
- keras_layer: 加载由`class_name`指定的自定义或系统内置的keras layer,执行一段代码逻辑;[参考文档](../component/backbone.md#keraslayer)
- TextCNN: 调用TextCNN组件。组件的参数,详见[参考文档](../component/component.md#id2)
- num_filters: 卷积核个数列表
- filter_sizes: 卷积核步长列表
- pad_sequence_length: 序列补齐或截断的长度
- activation: 卷积操作的激活函数,默认为relu
- concat_blocks: DAG的输出节点由`concat_blocks`配置项定义,如果不配置`concat_blocks`,框架会自动拼接DAG的所有叶子节点并输出。
- model_params:
- l2_regularization: (可选) 对DNN参数的regularization, 减少overfit
- embedding_regularization: 对embedding部分加regularization, 减少overfit

### 示例Config

[text_cnn_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/samples/model_config/text_cnn_on_movielens.config)

### 参考论文

[Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882)
9 changes: 5 additions & 4 deletions easy_rec/python/layers/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from easy_rec.python.feature_column.feature_group import FeatureGroup
from easy_rec.python.layers import sequence_feature_layer
from easy_rec.python.layers import variational_dropout_layer
from easy_rec.python.layers.common_layers import text_cnn
from easy_rec.python.layers.keras import TextCNN
from easy_rec.python.layers.utils import Parameter
from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
from easy_rec.python.utils import conditional
from easy_rec.python.utils import shape_utils
Expand Down Expand Up @@ -285,9 +286,9 @@ def single_call_input_layer(self,
seq_features.append(seq_feature)
cols_to_output_tensors[column] = seq_feature
elif sequence_combiner.WhichOneof('combiner') == 'text_cnn':
num_filters = sequence_combiner.text_cnn.num_filters
filter_sizes = sequence_combiner.text_cnn.filter_sizes
cnn_feature = text_cnn(seq_feature, filter_sizes, num_filters)
params = Parameter.make_from_pb(sequence_combiner.text_cnn)
text_cnn_layer = TextCNN(params, name=column.name + '_text_cnn')
cnn_feature = text_cnn_layer((seq_feature, seq_len))
seq_features.append(cnn_feature)
cols_to_output_tensors[column] = cnn_feature
else:
Expand Down
1 change: 1 addition & 0 deletions easy_rec/python/layers/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .blocks import MLP
from .blocks import Gate
from .blocks import Highway
from .blocks import TextCNN
from .bst import BST
from .data_augment import SeqAugment
from .din import DIN
Expand Down
55 changes: 55 additions & 0 deletions easy_rec/python/layers/keras/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import tensorflow as tf

from easy_rec.python.layers.keras.activation import activation_layer
from easy_rec.python.layers.utils import Parameter
from easy_rec.python.utils.shape_utils import pad_or_truncate_sequence
from easy_rec.python.utils.tf_utils import add_elements_to_collection

if tf.__version__ >= '2.0':
Expand Down Expand Up @@ -161,3 +163,56 @@ def call(self, inputs, **kwargs):
output += weights[:, j, None] * x
j += 1
return output


class TextCNN(tf.keras.layers.Layer):
"""Text CNN Model.

References
- [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882)
"""

def __init__(self, params, name='text_cnn', reuse=None, **kwargs):
super(TextCNN, self).__init__(name, **kwargs)
self.config = params.get_pb_config()
self.pad_seq_length = self.config.pad_sequence_length
if self.pad_seq_length <= 0:
logging.warning(
'run text cnn with pad_sequence_length <= 0, the predict of model may be unstable'
)
self.conv_layers = []
self.pool_layer = tf.keras.layers.GlobalMaxPool1D()
self.concat_layer = tf.keras.layers.Concatenate(axis=-1)
for size, filters in zip(self.config.filter_sizes, self.config.num_filters):
conv = tf.keras.layers.Conv1D(
filters=int(filters),
kernel_size=int(size),
activation=self.config.activation)
self.conv_layers.append(conv)
if self.config.HasField('mlp'):
p = Parameter.make_from_pb(self.config.mlp)
p.l2_regularizer = params.l2_regularizer
self.mlp = MLP(p, name='mlp', reuse=reuse)
else:
self.mlp = None

def call(self, inputs, training=None, **kwargs):
"""Input shape: 3D tensor with shape: `(batch_size, steps, input_dim)."""
assert isinstance(inputs, (list, tuple))
assert len(inputs) >= 2
seq_emb, seq_len = inputs[:2]

if self.pad_seq_length > 0:
seq_emb, seq_len = pad_or_truncate_sequence(seq_emb, seq_len,
self.pad_seq_length)
pooled_outputs = []
for layer in self.conv_layers:
conv = layer(seq_emb)
pooled = self.pool_layer(conv)
pooled_outputs.append(pooled)
net = self.concat_layer(pooled_outputs)
if self.mlp is not None:
output = self.mlp(net)
else:
output = net
return output
8 changes: 2 additions & 6 deletions easy_rec/python/protos/feature_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package protos;

import "easy_rec/python/protos/hyperparams.proto";
import "easy_rec/python/protos/dnn.proto";
import "easy_rec/python/protos/layer.proto";
enum WideOrDeep {
DEEP = 0;
WIDE = 1;
Expand All @@ -15,16 +16,11 @@ message AttentionCombiner {
message MultiHeadAttentionCombiner {
}

message TextCnnCombiner {
repeated uint32 filter_sizes = 1;
repeated uint32 num_filters = 2;
}

message SequenceCombiner {
oneof combiner {
AttentionCombiner attention = 1;
MultiHeadAttentionCombiner multi_head_attention = 2;
TextCnnCombiner text_cnn = 3;
TextCNN text_cnn = 3;
}
}

Expand Down
1 change: 1 addition & 0 deletions easy_rec/python/protos/keras_layer.proto
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ message KerasLayer {
MMoELayer mmoe = 14;
SequenceAugment seq_aug = 15;
PPNet ppnet = 16;
TextCNN text_cnn = 17;
}
}
8 changes: 8 additions & 0 deletions easy_rec/python/protos/layer.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,11 @@ message PPNet {
required string mode = 3 [default = 'eager'];
optional bool full_gate_input = 4 [default = true];
}

message TextCNN {
repeated uint32 filter_sizes = 1;
repeated uint32 num_filters = 2;
required uint32 pad_sequence_length = 3;
optional string activation = 4 [default = 'relu'];
optional MLP mlp = 5;
}
22 changes: 22 additions & 0 deletions easy_rec/python/utils/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,25 @@ def keep(seq_embed, seq_length):

return tf.cond(max_seq_len > limited_len, lambda: truncate(seq_emb, seq_len),
lambda: keep(seq_emb, seq_len))


def pad_or_truncate_sequence(seq_emb, seq_len, fixed_len):
padding_length = fixed_len - tf.shape(seq_emb)[1]

def padding():
paddings = tf.stack([[0, 0], [0, padding_length], [0, 0]])
padded = tf.pad(seq_emb, paddings)
return padded, seq_len

def truncate():
sliced = tf.slice(seq_emb, [0, 0, 0], [-1, fixed_len, -1])
length = tf.where(seq_len < fixed_len, seq_len,
tf.ones_like(seq_len) *
fixed_len) if seq_len is not None else None
return sliced, length

def keep():
return seq_emb, seq_len

return tf.cond(padding_length > 0, padding,
lambda: tf.cond(padding_length < 0, truncate, keep))
1 change: 1 addition & 0 deletions examples/configs/contrastive_learning_on_movielens.config
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
}
}
}
Expand Down
1 change: 1 addition & 0 deletions examples/configs/dcn_backbone_on_movielens.config
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
}
}
}
Expand Down
1 change: 1 addition & 0 deletions examples/configs/dcn_on_movielens.config
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
}
}
}
Expand Down
1 change: 1 addition & 0 deletions examples/configs/deepfm_on_movielens.config
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [8, 4, 4]
pad_sequence_length: 14
}
}
}
Expand Down
1 change: 1 addition & 0 deletions examples/configs/fibinet_on_movielens.config
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
}
}
}
Expand Down
1 change: 1 addition & 0 deletions examples/configs/masknet_on_movielens.config
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions examples/configs/mlp_on_movielens.config
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
}
}
}
Expand All @@ -158,6 +159,7 @@ model_config: {
feature_names: 'gender'
feature_names: 'year'
feature_names: 'genres'
feature_names: 'title'
wide_deep: DEEP
}
backbone {
Expand Down
1 change: 1 addition & 0 deletions examples/configs/multi_tower_on_movielens.config
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
}
}
}
Expand Down
1 change: 1 addition & 0 deletions examples/configs/wide_and_deep_on_movielens.config
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ feature_config: {
text_cnn: {
filter_sizes: [2, 3, 4]
num_filters: [16, 8, 8]
pad_sequence_length: 14
}
}
}
Expand Down
Loading
Loading