Skip to content

Commit

Permalink
DAT unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-gecheng committed Nov 14, 2024
1 parent 35681fc commit 11b08e1
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ If EasyRec is useful for your research, please cite:

### Join Us

- DingDing Group: 32260796. click [this url](https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload?action=joingroup&code=v1,k1,MwaiOIY1Tb2W+onmBBumO7sQsdDOYjBmv6FXC6wTGns=&_dt_no_comment=1&origin=11#/ ) or scan QrCode to join![dinggroup1.png](docs/images/qrcode/dinggroup1.png)
- DingDing Group: 32260796. click [this url](https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload?action=joingroup&code=v1,k1,MwaiOIY1Tb2W+onmBBumO7sQsdDOYjBmv6FXC6wTGns=&_dt_no_comment=1&origin=11#/) or scan QrCode to join![dinggroup1.png](docs/images/qrcode/dinggroup1.png)
- DingDing Group2: 37930014162, click [this url](https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload?action=joingroup&code=v1,k1,1ppFWEXXNPyxUClHh77gCmpfB+JcPhbFv6FXC6wTGns=&_dt_no_comment=1&origin=11#/) or scan QrCode to join![dinggroup2.png](docs/images/qrcode/dinggroup2.png)
- Email Group: [email protected].

Expand Down
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
Expand Down
16 changes: 12 additions & 4 deletions easy_rec/python/model/rank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def build_rtp_output_dict(self):
'failed to build RTP rank_predict output: classification model ' +
"expect 'probs' prediction, which is not found. Please check if" +
' build_predict_graph() is called.')
elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
elif loss_types & {
LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
}:
if 'y' in self._prediction_dict:
forwarded = self._prediction_dict['y']
else:
Expand Down Expand Up @@ -379,7 +381,9 @@ def _build_metric_impl(self,
metric.recall_at_topk.topk)
elif metric.WhichOneof('metric') == 'mean_absolute_error':
label = tf.to_float(self._labels[label_name])
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
if loss_type & {
LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
}:
metric_dict['mean_absolute_error' +
suffix] = metrics_tf.mean_absolute_error(
label, self._prediction_dict['y' + suffix])
Expand All @@ -391,7 +395,9 @@ def _build_metric_impl(self,
assert False, 'mean_absolute_error is not supported for this model'
elif metric.WhichOneof('metric') == 'mean_squared_error':
label = tf.to_float(self._labels[label_name])
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
if loss_type & {
LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
}:
metric_dict['mean_squared_error' +
suffix] = metrics_tf.mean_squared_error(
label, self._prediction_dict['y' + suffix])
Expand All @@ -403,7 +409,9 @@ def _build_metric_impl(self,
assert False, 'mean_squared_error is not supported for this model'
elif metric.WhichOneof('metric') == 'root_mean_squared_error':
label = tf.to_float(self._labels[label_name])
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
if loss_type & {
LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
}:
metric_dict['root_mean_squared_error' +
suffix] = metrics_tf.root_mean_squared_error(
label, self._prediction_dict['y' + suffix])
Expand Down
8 changes: 7 additions & 1 deletion easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import threading
import time
import unittest
from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.platform import gfile

from easy_rec.python.main import predict
Expand Down Expand Up @@ -1286,6 +1286,12 @@ def test_xdeefm_backbone_on_taobao(self):
self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dat_on_taobao(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/dat_on_taobao.config', self._test_dir)
self.assertTrue(self._success)


if __name__ == '__main__':
tf.test.main()
2 changes: 1 addition & 1 deletion requirements/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ recommonmark==0.6.0
sphinx==5.1.1
sphinx_markdown_tables==0.0.17
sphinx_rtd_theme
tensorflow-probability==0.11.0
tensorflow-probability==0.11.0
3 changes: 1 addition & 2 deletions samples/model_config/dat_on_taobao.config
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ train_config {
}
save_checkpoints_steps: 4000
sync_replicas: false
num_steps: 100000
num_steps: 100
}

eval_config {
Expand Down Expand Up @@ -129,7 +129,6 @@ data_config {

label_fields: 'clk'
batch_size: 512
num_epochs: 10000
prefetch_size: 32
input_type: CSVInput

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ multi_line_output = 7
force_single_line = true
known_standard_library = setuptools
known_first_party = easy_rec
known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,scipy,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,tensorflow_probability,yaml
known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,scipy,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,tensorflow_probability,yaml
no_lines_before = LOCALFOLDER
default_section = THIRDPARTY
skip = easy_rec/python/protos
Expand Down

0 comments on commit 11b08e1

Please sign in to comment.