Skip to content

Commit

Permalink
add support for rank distillation
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Sep 9, 2024
1 parent f015215 commit 88adf24
Show file tree
Hide file tree
Showing 15 changed files with 469 additions and 290 deletions.
19 changes: 10 additions & 9 deletions docs/source/component/custom_op.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
- 参考官方示例:[TensorFlow Custom Op](https://github.com/tensorflow/custom-op/)
- 注意:自定义Op的编译依赖tf版本需要与执行时的tf版本保持一致
- 您可能需要为离线训练 与 在线推理服务 编译两个不同依赖环境的动态库
- 在PAI平台上需要依赖 tf 1.12 版本编译
- 在PAI平台上需要依赖 tf 1.12 版本编译(先下载pai-tf的官方镜像)
- 在EAS的 [EasyRec Processor](https://help.aliyun.com/zh/pai/user-guide/easyrec) 中使用自定义Op需要依赖 tf 2.10.1 编译
1.`EasyRec`中使用自定义Op的步骤
1. 下载EasyRec的最新[源代码](https://github.com/alibaba/EasyRec)
Expand Down Expand Up @@ -45,6 +45,8 @@ pai -name easy_rec_ext

## 自定义Op的示例

使用自定义OP求两段输入文本的Term匹配率

```protobuf
feature_config: {
...
Expand Down Expand Up @@ -89,17 +91,16 @@ model_config: {
}
}
blocks {
name: 'edit_distance'
name: 'match_ratio'
inputs {
block_name: 'text'
}
keras_layer {
class_name: 'EditDistance'
st_params {
fields {
key: 'text_encoding'
value: { string_value: 'latin' }
}
class_name: 'OverlapFeature'
overlap {
separator: " "
default_value: "0"
methods: "query_common_ratio"
}
}
}
Expand All @@ -109,7 +110,7 @@ model_config: {
feature_group_name: 'features'
}
inputs {
block_name: 'edit_distance'
block_name: 'match_ratio'
}
keras_layer {
class_name: 'MLP'
Expand Down
41 changes: 40 additions & 1 deletion docs/source/kd.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

- label_is_logits: 目标是logits, 还是probs, 默认是logits

- loss_type: loss的类型, 可以是CROSS_ENTROPY_LOSS或者L2_LOSS
- loss_type: loss的类型, 可以是CROSS_ENTROPY_LOSS、L2_LOSS、BINARY_CROSS_ENTROPY_LOSS、KL_DIVERGENCE_LOSS、PAIRWISE_HINGE_LOSS、LISTWISE_RANK_LOSS等

- loss_weight: loss的权重, 默认是1.0

Expand Down Expand Up @@ -63,6 +63,45 @@ model_config {
}
```

除了常规的从teacher模型的预测结果里"蒸馏"知识到student模型,在搜推场景中更加推荐采用基于pairwise或者listwise的方式从teacher模型学习
其对不同item的排序(学习对item预估结果的偏序关系),示例如下:

- pairwise 知识蒸馏

```protobuf
kd {
loss_name: 'ctcvr_rank_loss'
soft_label_name: 'pay_logits'
pred_name: 'logits'
loss_type: PAIRWISE_HINGE_LOSS
loss_weight: 1.0
pairwise_hinge_loss {
session_name: "raw_query"
use_exponent: false
use_label_margin: true
}
}
```

- listwise 知识蒸馏

```protobuf
kd {
loss_name: 'ctcvr_rank_loss'
soft_label_name: 'pay_logits'
pred_name: 'logits'
loss_type: LISTWISE_RANK_LOSS
loss_weight: 1.0
listwise_rank_loss {
session_name: "raw_query"
temperature: 3.0
label_is_logits: true
}
}
```

可以为损失函数配置参数,配置方法参考[损失函数](models/loss.md)参数。

### 训练命令

训练命令不改变, 详细参考[模型训练](./train.md)
Expand Down
1 change: 1 addition & 0 deletions docs/source/models/cl4srec.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ model_config: {
- use_package_input: 当`package`的输入是动态的时,设置该输入占位符,表示当前`block`的输入由调用`package`时指定
- keras_layer: 加载由`class_name`指定的自定义或系统内置的keras layer,执行一段代码逻辑;[参考文档](../component/backbone.md#keraslayer)
- SeqAugment: 序列数据增强的组件,参数详见[参考文档](../component/component.md#id5)
- SeqAugmentOps: `class_name`指定为`SeqAugmentOps`可以使用自定义OP版本的序列数据增加组件,性能更好
- AuxiliaryLoss: 计算辅助任务损失函数的组件,参数详见[参考文档](../component/component.md#id7)
- concat_blocks: DAG的输出节点由`concat_blocks`配置项定义,如果不配置`concat_blocks`,框架会自动拼接DAG的所有叶子节点并输出。
- model_params:
Expand Down
33 changes: 28 additions & 5 deletions docs/source/models/loss.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,21 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2
| L2_LOSS | 平方损失 |
| SIGMOID_L2_LOSS | 对sigmoid函数的结果计算平方损失 |
| CROSS_ENTROPY_LOSS | log loss 负对数损失 |
| BINARY_CROSS_ENTROPY_LOSS | 仅用在知识蒸馏中的BCE损失 |
| KL_DIVERGENCE_LOSS | 仅用在知识蒸馏中的KL散度损失 |
| CIRCLE_LOSS | CoMetricLearningI2I模型专用 |
| MULTI_SIMILARITY_LOSS | CoMetricLearningI2I模型专用 |
| SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING | 自动负采样版本的多分类softmax_cross_entropy,用在二分类任务中 |
| BINARY_FOCAL_LOSS | 支持困难样本挖掘和类别平衡的focal loss |
| SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING | 自动负采样版本的多分类softmax_cross_entropy,用在二分类任务中 |
| BINARY_FOCAL_LOSS | 支持困难样本挖掘和类别平衡的focal loss |
| PAIR_WISE_LOSS | 以优化全局AUC为目标的rank loss |
| PAIRWISE_FOCAL_LOSS | pair粒度的focal loss, 支持自定义pair分组 |
| PAIRWISE_LOGISTIC_LOSS | pair粒度的logistic loss, 支持自定义pair分组 |
| PAIRWISE_HINGE_LOSS | pair粒度的hinge loss, 支持自定义pair分组 |
| JRC_LOSS | 二分类 + listwise ranking loss |
| F1_REWEIGHTED_LOSS | 可以调整二分类召回率和准确率相对权重的损失函数,可有效对抗正负样本不平衡问题 |
| ORDER_CALIBRATE_LOSS | 使用目标依赖关系校正预测结果的辅助损失函数,详见[AITM](aitm.md)模型 |
| F1_REWEIGHTED_LOSS | 可以调整二分类召回率和准确率相对权重的损失函数,可有效对抗正负样本不平衡问题 |
| ORDER_CALIBRATE_LOSS | 使用目标依赖关系校正预测结果的辅助损失函数,详见[AITM](aitm.md)模型 |
| LISTWISE_RANK_LOSS | listwise的排序损失 |
| LISTWISE_DISTILL_LOSS | 用来蒸馏给定list排序的损失函数,与listwise rank loss 比较类似 |

- 说明:SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
- 支持参数配置,升级为 [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317)
Expand Down Expand Up @@ -99,6 +104,16 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2
- margin: 当pair的logit之差减去该参数值后再参与计算,即正负样本的logit之差至少要大于margin,默认值为0
- temperature: 温度系数,logit除以该参数值后再参与计算,默认值为1.0

- PAIRWISE_HINGE_LOSS 的参数配置

- session_name: pair分组的字段名,比如user_id
- temperature: 温度系数,logit除以该参数值后再参与计算,默认值为1.0
- margin: 当pair的logit之差大于该参数值时,当前样本的loss为0,默认值为1.0
- ohem_ratio: 困难样本的百分比,只有部分困难样本参与loss计算,默认值为1.0
- label_is_logits: bool, 标记label是否为teacher模型的输出logits,默认为true
- use_label_margin: bool, 是否使用输入pair的label的diff作为margin,设置为true时`margin`参数不生效,默认为true
- use_exponent: bool, 是否对模型的输出做pairwise的指数变化,默认为false

备注:上述 PAIRWISE\_\*\_LOSS 都是在mini-batch内构建正负样本pair,目标是让正负样本pair的logit相差尽可能大

- BINARY_FOCAL_LOSS 的参数配置
Expand All @@ -115,6 +130,13 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2
- 参考论文:《 [Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model](https://arxiv.org/pdf/2208.06164.pdf)
- 使用示例: [dbmtl_with_jrc_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/dbmtl_on_taobao_with_multi_loss.config)

- LISTWISE_RANK_LOSS 的参数配置

- temperature: 温度系数,logit除以该参数值后再参与计算,默认值为1.0
- session_name: list分组的字段名,比如user_id
- label_is_logits: bool, 标记label是否为teacher模型的输出logits,默认为false
- scale_logits: bool, 是否需要对模型的logits进行线性缩放,默认为false

排序模型同时使用多个损失函数的完整示例:
[cmbf_with_multi_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/cmbf_with_multi_loss.config)

Expand Down Expand Up @@ -159,5 +181,6 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2
### 参考论文:

- 《 Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics 》
- [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/abs/2111.10603)
- [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/abs/2111.10603)
- [AITM: Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising](https://arxiv.org/pdf/2105.08489.pdf)
- [Pairwise Ranking Distillation for Deep Face Recognition](https://ceur-ws.org/Vol-2744/paper30.pdf)
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
Expand Down
12 changes: 6 additions & 6 deletions easy_rec/python/layers/keras/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, params, name='sequence_aug', reuse=None, **kwargs):
super(SeqAugmentOps, self).__init__(name=name, **kwargs)
self.reuse = reuse
self.seq_aug_params = params.get_pb_config()
self.seq_augment = custom_ops.my_seq_augment
self.seq_augment = custom_ops.seq_augment

def build(self, input_shape):
assert len(input_shape) >= 2, 'SeqAugmentOps must has at least two inputs'
Expand All @@ -60,11 +60,11 @@ def call(self, inputs, training=None, **kwargs):
assert isinstance(inputs, (list, tuple))
seq_input, seq_len = inputs[:2]

x = self.seq_augment(seq_input, seq_len, self.mask_emb,
self.seq_aug_params.crop_rate,
self.seq_aug_params.reorder_rate,
self.seq_aug_params.mask_rate)
return x
aug_seq, aug_len = self.seq_augment(seq_input, seq_len, self.mask_emb,
self.seq_aug_params.crop_rate,
self.seq_aug_params.reorder_rate,
self.seq_aug_params.mask_rate)
return aug_seq, aug_len


class TextNormalize(Layer):
Expand Down
10 changes: 5 additions & 5 deletions easy_rec/python/model/dssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ def build_predict_graph(self):
name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))

if self._model_config.simi_func == Similarity.COSINE:
user_tower_emb = self.norm(user_tower_emb)
item_tower_emb = self.norm(item_tower_emb)
temperature = self._model_config.temperature
else:
temperature = 1.0
user_tower_emb = self.norm(user_tower_emb)
item_tower_emb = self.norm(item_tower_emb)
temperature = self._model_config.temperature
else:
temperature = 1.0

user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature
if self._model_config.scale_simi:
Expand Down
Binary file added easy_rec/python/ops/1.12/libcustom_ops.so
Binary file not shown.
Binary file modified easy_rec/python/ops/1.12_pai/libcustom_ops.so
Binary file not shown.
Binary file added easy_rec/python/ops/1.15/libcustom_ops.so
Binary file not shown.
Binary file added easy_rec/python/ops/2.12/libcustom_ops.so
Binary file not shown.
7 changes: 2 additions & 5 deletions easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import threading
import time
import unittest
from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.platform import gfile

from easy_rec.python.main import predict
Expand Down Expand Up @@ -409,12 +409,9 @@ def test_highway(self):
'samples/model_config/highway_on_movielens.config', self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(
LooseVersion(tf.__version__) >= LooseVersion('2.0.0'),
'EditDistanceOp only work before tf version == 2.0')
def test_custom_op(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/mlp_on_movielens_with_custom_op.config',
'samples/model_config/cl4srec_on_taobao_with_custom_op.config',
self._test_dir)
self.assertTrue(self._success)

Expand Down
Loading

0 comments on commit 88adf24

Please sign in to comment.