Skip to content

Commit

Permalink
DAT model supports amm_loss_weight (amm_u_weight / amm_i_weight)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-gecheng committed Jan 17, 2025
1 parent 8e2def1 commit a6da5c8
Show file tree
Hide file tree
Showing 13 changed files with 32 additions and 20 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile_tf112
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ RUN cd /EasyRec && python setup.py install
RUN rm -rf /EasyRec
RUN python -c "import easy_rec; import pyhive; import datahub; import kafka"

COPY docker/hadoop_env.sh /opt/hadoop_env.sh
COPY docker/hadoop_env.sh /opt/hadoop_env.sh
2 changes: 1 addition & 1 deletion docker/Dockerfile_tf115
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ RUN cd /EasyRec && pip install .
RUN rm -rf /EasyRec
RUN python -c "import easy_rec; easy_rec.help(); import pyhive; import datahub; import kafka"

COPY docker/hadoop_env.sh /opt/hadoop_env.sh
COPY docker/hadoop_env.sh /opt/hadoop_env.sh
2 changes: 1 addition & 1 deletion docker/Dockerfile_tf212
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ RUN cd /EasyRec && python setup.py install
RUN rm -rf /EasyRec
# RUN python -c "import easy_rec; easy_rec.help(); import pyhive; import datahub; import kafka"

COPY docker/hadoop_env.sh /opt/hadoop_env.sh
COPY docker/hadoop_env.sh /opt/hadoop_env.sh
14 changes: 8 additions & 6 deletions docs/source/models/dssm_derivatives.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,26 @@ model_config:{
features: {
input_names: 'user_id'
feature_type: IdFeature
embedding_dim: 32
embedding_dim: 32 # user_id特征embedding维度
hash_bucket_size: 100000
}
features: {
input_names: 'adgroup_id'
feature_type: IdFeature
embedding_dim: 32
embedding_dim: 32 # item_id特征embedding维度
hash_bucket_size: 100000
}
.
.
.
feature_groups: {
group_name: 'user_id_augment'
group_name: 'user_id_augment' # 增加user_augment特征组,对user_id特征进行embedding作为辅助向量
feature_names: 'user_id'
wide_deep:DEEP
}
feature_groups: {
group_name: 'item_id_augment'
group_name: 'item_id_augment' # 增加item_augment特征组,对item_id特征进行embedding作为辅助向量
feature_names: 'adgroup_id'
wide_deep:DEEP
}
Expand All @@ -137,19 +137,21 @@ model_config:{
user_tower {
id: "user_id"
dnn {
hidden_units: [ 128, 32]
hidden_units: [ 128, 32] # 输出维度需要保证和item_augment特征组的embedding维度一致
# dropout_ratio : [0.1, 0.1, 0.1, 0.1]
}
}
item_tower {
id: "adgroup_id"
dnn {
hidden_units: [ 128, 32]
hidden_units: [ 128, 32] # 输出维度需要保证和user_augment特征组的embedding维度一致
}
}
simi_func: COSINE
temperature: 0.01
l2_regularization: 1e-6
amm_i_weight: 0.5 # AMM损失权重
amm_u_weight: 0.5
}
```

Expand Down
18 changes: 10 additions & 8 deletions easy_rec/python/model/match_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,16 @@ def _build_list_wise_loss_graph(self):
k in self._prediction_dict.keys() for k in
['augmented_p_u', 'augmented_p_i', 'augmented_a_u', 'augmented_a_i']
]):
self._loss_dict['amm_loss_u'] = tf.reduce_mean(
tf.square(self._prediction_dict['augmented_a_u'] -
self._prediction_dict['augmented_p_i'][:batch_size]) *
sample_weights) / tf.reduce_mean(sample_weights)
self._loss_dict['amm_loss_i'] = tf.reduce_mean(
tf.square(self._prediction_dict['augmented_a_i'][:batch_size] -
self._prediction_dict['augmented_p_u']) *
sample_weights) / tf.reduce_mean(sample_weights)
self._loss_dict[
'amm_loss_u'] = self._model_config.amm_u_weight * tf.reduce_mean(
tf.square(self._prediction_dict['augmented_a_u'] -
self._prediction_dict['augmented_p_i'][:batch_size]) *
sample_weights) / tf.reduce_mean(sample_weights)
self._loss_dict[
'amm_loss_i'] = self._model_config.amm_i_weight * tf.reduce_mean(
tf.square(self._prediction_dict['augmented_a_i'][:batch_size] -
self._prediction_dict['augmented_p_u']) *
sample_weights) / tf.reduce_mean(sample_weights)

else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))
Expand Down
4 changes: 4 additions & 0 deletions easy_rec/python/protos/dat.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ message DAT {
optional Similarity simi_func = 4 [default=COSINE];
required bool ignore_in_batch_neg_sam = 5 [default = false];
optional float temperature = 6 [default = 1.0];
// loss weight for amm_i
required float amm_i_weight = 7 [default = 0.5];
// loss weight for amm_u
required float amm_u_weight = 8 [default = 0.5];
}
1 change: 1 addition & 0 deletions easy_rec/python/tools/faiss_index_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import faiss
import numpy as np
import tensorflow as tf

from easy_rec.python.utils import io_util

logging.basicConfig(
Expand Down
1 change: 1 addition & 0 deletions easy_rec/python/tools/hit_rate_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import print_function

import sys

import tensorflow as tf

from easy_rec.python.utils import io_util
Expand Down
1 change: 1 addition & 0 deletions easy_rec/python/tools/split_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver

from easy_rec.python.utils import io_util

if tf.__version__ >= '2.0':
Expand Down
1 change: 1 addition & 0 deletions easy_rec/python/tools/split_pdn_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tensorflow.python.saved_model.utils_impl import get_variables_path
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver

from easy_rec.python.utils import io_util

FLAGS = tf.app.flags.FLAGS
Expand Down
2 changes: 1 addition & 1 deletion scripts/build_docker_tf112.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ then
exit 1
fi

sudo docker build --network=host . -f docker/Dockerfile_tf112 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py27-tf1.12-${version}
sudo docker build --network=host . -f docker/Dockerfile_tf112 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py27-tf1.12-${version}
2 changes: 1 addition & 1 deletion scripts/build_docker_tf115.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ then
exit 1
fi

sudo docker build --network=host . -f docker/Dockerfile_tf115 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-${version}
sudo docker build --network=host . -f docker/Dockerfile_tf115 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py36-tf1.15-${version}
2 changes: 1 addition & 1 deletion scripts/build_docker_tf212.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ then
exit 1
fi

sudo docker build --network=host . -f docker/Dockerfile_tf212 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py38-tf2.12-${version}
sudo docker build --network=host . -f docker/Dockerfile_tf212 -t mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:py38-tf2.12-${version}

0 comments on commit a6da5c8

Please sign in to comment.