From ac9fb014111cde6bd48177acb274ab5bd31ced76 Mon Sep 17 00:00:00 2001 From: "weisu.yxd" Date: Thu, 14 Dec 2023 12:47:53 +0800 Subject: [PATCH] fix doc build problem --- docs/source/train.md | 50 +++++++++++++---------- easy_rec/python/model/multi_task_model.py | 4 +- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/docs/source/train.md b/docs/source/train.md index f233a89e0..cb78b546d 100644 --- a/docs/source/train.md +++ b/docs/source/train.md @@ -236,31 +236,39 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2 多目标学习任务中,人工指定多个损失函数的静态权重通常不能获得最好的效果。EasyRec支持损失函数权重自适应学习,示例如下: ```protobuf - losses { - loss_type: CLASSIFICATION - learn_loss_weight: true - } - losses { - loss_type: BINARY_FOCAL_LOSS - learn_loss_weight: true - binary_focal_loss { - gamma: 2.0 - alpha: 0.85 - } - } - losses { - loss_type: PAIRWISE_FOCAL_LOSS - learn_loss_weight: true - pairwise_focal_loss { - session_name: "client_str" - hinge_margin: 1.0 - } - } + loss_weight_strategy: Uncertainty + losses { + loss_type: CLASSIFICATION + learn_loss_weight: true + } + losses { + loss_type: BINARY_FOCAL_LOSS + learn_loss_weight: true + binary_focal_loss { + gamma: 2.0 + alpha: 0.85 + } + } + losses { + loss_type: PAIRWISE_FOCAL_LOSS + learn_loss_weight: true + pairwise_focal_loss { + session_name: "client_str" + hinge_margin: 1.0 + } + } ``` 通过`learn_loss_weight`参数配置是否需要开启权重自适应学习,默认不开启。开启之后,`weight`参数不再生效。 -参考论文:《Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics》 +- loss_weight_strategy: Uncertainty + - 表示通过不确定性来度量损失函数的权重;目前在`learn_loss_weight: true`时必须要设置该值 +- loss_weight_strategy: Random + - 表示损失函数的权重设定为归一化的随机数 + +参考论文: +- 《 Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics 》 +- 《 [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/abs/2111.10603) 》 ## 训练命令 diff --git a/easy_rec/python/model/multi_task_model.py b/easy_rec/python/model/multi_task_model.py index ade76d5ab..2479d7b2a 100644 --- a/easy_rec/python/model/multi_task_model.py +++ b/easy_rec/python/model/multi_task_model.py @@ -10,6 +10,7 @@ from easy_rec.python.model.rank_model import RankModel from easy_rec.python.protos import tower_pb2 from easy_rec.python.protos.loss_pb2 import LossType +from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -188,7 +189,8 @@ def get_learnt_loss(self, loss_type, name, value): else: return tf.exp(-uncertainty) * value + 0.5 * uncertainty else: - raise ValueError('Unsupported loss weight strategy: ' + strategy.Name) + strategy_name = EasyRecModel.LossWeightStrategy.Name(strategy) + raise ValueError('Unsupported loss weight strategy: ' + strategy_name) def build_loss_graph(self): """Build loss graph for multi task model."""