From 5cc4e8fe1e476d69834cda92705134fee2482aa1 Mon Sep 17 00:00:00 2001 From: "eric.gc" Date: Sat, 12 Oct 2024 15:17:43 +0800 Subject: [PATCH 1/3] split_model_pai.py supports TF2.x --- easy_rec/python/tools/split_model_pai.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/easy_rec/python/tools/split_model_pai.py b/easy_rec/python/tools/split_model_pai.py index d2a46ff32..3934f6c77 100644 --- a/easy_rec/python/tools/split_model_pai.py +++ b/easy_rec/python/tools/split_model_pai.py @@ -9,10 +9,16 @@ from tensorflow.python.framework import ops from tensorflow.python.framework.dtypes import _TYPE_TO_STRING from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model.utils_impl import get_variables_path from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import saver as tf_saver +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + from tensorflow.python.saved_model.path_helpers import get_variables_path + from tensorflow.python.ops.resource_variable_ops import _from_proto_fn +else: + from tensorflow.compat.v1.python.saved_model.utils_impl import get_variables_path + FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('model_dir', '', '') tf.app.flags.DEFINE_string('user_model_dir', '', '') @@ -198,7 +204,10 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names, graph = ops.get_default_graph() importer.import_graph_def(inference_graph, name='') for name in variables_to_keep: - variable = graph.get_tensor_by_name(name) + if tf.__version__ >= '2.0': + variable = _from_proto_fn(variable_protos[name.split(':')[0]]) + else: + variable = graph.get_tensor_by_name(name) graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable) saver = tf_saver.Saver() saver.restore(sess, get_variables_path(model_dir)) From 6483f7a99a5cda7f32805abd5d7ae88adc14d93c Mon Sep 17 00:00:00 2001 From: "eric.gc" Date: Sat, 12 Oct 2024 15:18:48 +0800 Subject: [PATCH 2/3] code style fixes --- easy_rec/python/compat/early_stopping.py | 2 +- easy_rec/python/test/train_eval_test.py | 2 +- setup.cfg | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/easy_rec/python/compat/early_stopping.py b/easy_rec/python/compat/early_stopping.py index fe4c12132..fc850fb62 100644 --- a/easy_rec/python/compat/early_stopping.py +++ b/easy_rec/python/compat/early_stopping.py @@ -21,9 +21,9 @@ import os import threading import time -from distutils.version import LooseVersion import tensorflow as tf +from distutils.version import LooseVersion from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import init_ops diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 0f7b82a28..ca29fc89c 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -7,11 +7,11 @@ import threading import time import unittest -from distutils.version import LooseVersion import numpy as np import six import tensorflow as tf +from distutils.version import LooseVersion from tensorflow.python.platform import gfile from easy_rec.python.main import predict diff --git a/setup.cfg b/setup.cfg index 337833a0f..b43211827 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ multi_line_output = 7 force_single_line = true known_standard_library = setuptools known_first_party = easy_rec -known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml +known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml no_lines_before = LOCALFOLDER default_section = THIRDPARTY skip = easy_rec/python/protos From 8df7de29f38784bd70f5054b533d81e6c65b4203 Mon Sep 17 00:00:00 2001 From: "eric.gc" Date: Mon, 14 Oct 2024 14:26:00 +0800 Subject: [PATCH 3/3] get_variables_path fix --- easy_rec/python/tools/split_model_pai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/easy_rec/python/tools/split_model_pai.py b/easy_rec/python/tools/split_model_pai.py index 3934f6c77..bdb2087de 100644 --- a/easy_rec/python/tools/split_model_pai.py +++ b/easy_rec/python/tools/split_model_pai.py @@ -17,7 +17,7 @@ from tensorflow.python.saved_model.path_helpers import get_variables_path from tensorflow.python.ops.resource_variable_ops import _from_proto_fn else: - from tensorflow.compat.v1.python.saved_model.utils_impl import get_variables_path + from tensorflow.python.saved_model.utils_impl import get_variables_path FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('model_dir', '', '')