From 11b08e1ec97874d629976cedf764fcf6cca86364 Mon Sep 17 00:00:00 2001 From: "eric.gc" Date: Thu, 14 Nov 2024 17:44:45 +0800 Subject: [PATCH] DAT unit test --- README.md | 2 +- easy_rec/python/compat/early_stopping.py | 2 +- easy_rec/python/model/rank_model.py | 16 ++++++++++++---- easy_rec/python/test/train_eval_test.py | 8 +++++++- requirements/docs.txt | 2 +- samples/model_config/dat_on_taobao.config | 3 +-- setup.cfg | 2 +- 7 files changed, 24 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index ecab7df50..e011a1d01 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ If EasyRec is useful for your research, please cite: ### Join Us -- DingDing Group: 32260796. click [this url](https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload?action=joingroup&code=v1,k1,MwaiOIY1Tb2W+onmBBumO7sQsdDOYjBmv6FXC6wTGns=&_dt_no_comment=1&origin=11#/ ) or scan QrCode to join![dinggroup1.png](docs/images/qrcode/dinggroup1.png) +- DingDing Group: 32260796. click [this url](https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload?action=joingroup&code=v1,k1,MwaiOIY1Tb2W+onmBBumO7sQsdDOYjBmv6FXC6wTGns=&_dt_no_comment=1&origin=11#/) or scan QrCode to join![dinggroup1.png](docs/images/qrcode/dinggroup1.png) - DingDing Group2: 37930014162, click [this url](https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload?action=joingroup&code=v1,k1,1ppFWEXXNPyxUClHh77gCmpfB+JcPhbFv6FXC6wTGns=&_dt_no_comment=1&origin=11#/) or scan QrCode to join![dinggroup2.png](docs/images/qrcode/dinggroup2.png) - Email Group: easy_rec@service.aliyun.com. diff --git a/easy_rec/python/compat/early_stopping.py b/easy_rec/python/compat/early_stopping.py index fe4c12132..fc850fb62 100644 --- a/easy_rec/python/compat/early_stopping.py +++ b/easy_rec/python/compat/early_stopping.py @@ -21,9 +21,9 @@ import os import threading import time -from distutils.version import LooseVersion import tensorflow as tf +from distutils.version import LooseVersion from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import init_ops diff --git a/easy_rec/python/model/rank_model.py b/easy_rec/python/model/rank_model.py index 640f52502..dc3771daf 100644 --- a/easy_rec/python/model/rank_model.py +++ b/easy_rec/python/model/rank_model.py @@ -158,7 +158,9 @@ def build_rtp_output_dict(self): 'failed to build RTP rank_predict output: classification model ' + "expect 'probs' prediction, which is not found. Please check if" + ' build_predict_graph() is called.') - elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}: + elif loss_types & { + LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS + }: if 'y' in self._prediction_dict: forwarded = self._prediction_dict['y'] else: @@ -379,7 +381,9 @@ def _build_metric_impl(self, metric.recall_at_topk.topk) elif metric.WhichOneof('metric') == 'mean_absolute_error': label = tf.to_float(self._labels[label_name]) - if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}: + if loss_type & { + LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS + }: metric_dict['mean_absolute_error' + suffix] = metrics_tf.mean_absolute_error( label, self._prediction_dict['y' + suffix]) @@ -391,7 +395,9 @@ def _build_metric_impl(self, assert False, 'mean_absolute_error is not supported for this model' elif metric.WhichOneof('metric') == 'mean_squared_error': label = tf.to_float(self._labels[label_name]) - if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}: + if loss_type & { + LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS + }: metric_dict['mean_squared_error' + suffix] = metrics_tf.mean_squared_error( label, self._prediction_dict['y' + suffix]) @@ -403,7 +409,9 @@ def _build_metric_impl(self, assert False, 'mean_squared_error is not supported for this model' elif metric.WhichOneof('metric') == 'root_mean_squared_error': label = tf.to_float(self._labels[label_name]) - if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}: + if loss_type & { + LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS + }: metric_dict['root_mean_squared_error' + suffix] = metrics_tf.root_mean_squared_error( label, self._prediction_dict['y' + suffix]) diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 83656f2a0..a82455cff 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -7,11 +7,11 @@ import threading import time import unittest -from distutils.version import LooseVersion import numpy as np import six import tensorflow as tf +from distutils.version import LooseVersion from tensorflow.python.platform import gfile from easy_rec.python.main import predict @@ -1286,6 +1286,12 @@ def test_xdeefm_backbone_on_taobao(self): self._test_dir) self.assertTrue(self._success) + @unittest.skipIf(gl is None, 'graphlearn is not installed') + def test_dat_on_taobao(self): + self._success = test_utils.test_single_train_eval( + 'samples/model_config/dat_on_taobao.config', self._test_dir) + self.assertTrue(self._success) + if __name__ == '__main__': tf.test.main() diff --git a/requirements/docs.txt b/requirements/docs.txt index 596bd527b..9e81da2c6 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -5,4 +5,4 @@ recommonmark==0.6.0 sphinx==5.1.1 sphinx_markdown_tables==0.0.17 sphinx_rtd_theme -tensorflow-probability==0.11.0 \ No newline at end of file +tensorflow-probability==0.11.0 diff --git a/samples/model_config/dat_on_taobao.config b/samples/model_config/dat_on_taobao.config index e972149d7..2a7070557 100644 --- a/samples/model_config/dat_on_taobao.config +++ b/samples/model_config/dat_on_taobao.config @@ -19,7 +19,7 @@ train_config { } save_checkpoints_steps: 4000 sync_replicas: false - num_steps: 100000 + num_steps: 100 } eval_config { @@ -129,7 +129,6 @@ data_config { label_fields: 'clk' batch_size: 512 - num_epochs: 10000 prefetch_size: 32 input_type: CSVInput diff --git a/setup.cfg b/setup.cfg index d8ed85f21..f0223c47a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ multi_line_output = 7 force_single_line = true known_standard_library = setuptools known_first_party = easy_rec -known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,scipy,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,tensorflow_probability,yaml +known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,scipy,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,tensorflow_probability,yaml no_lines_before = LOCALFOLDER default_section = THIRDPARTY skip = easy_rec/python/protos