From f4fa7330285e8aae2b1cd5bf94e298f0d944c9bc Mon Sep 17 00:00:00 2001 From: Eric Ge Date: Mon, 14 Oct 2024 16:06:54 +0800 Subject: [PATCH] [feature]split_model_pai.py supports both tf1 and tf2 (#490) * split_model_pai.py supports TF2.x * code style fixes * get_variables_path fix --- easy_rec/python/compat/early_stopping.py | 2 +- easy_rec/python/test/train_eval_test.py | 2 +- easy_rec/python/tools/split_model_pai.py | 13 +++++++++++-- setup.cfg | 2 +- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/easy_rec/python/compat/early_stopping.py b/easy_rec/python/compat/early_stopping.py index fe4c12132..fc850fb62 100644 --- a/easy_rec/python/compat/early_stopping.py +++ b/easy_rec/python/compat/early_stopping.py @@ -21,9 +21,9 @@ import os import threading import time -from distutils.version import LooseVersion import tensorflow as tf +from distutils.version import LooseVersion from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import init_ops diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 0f7b82a28..ca29fc89c 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -7,11 +7,11 @@ import threading import time import unittest -from distutils.version import LooseVersion import numpy as np import six import tensorflow as tf +from distutils.version import LooseVersion from tensorflow.python.platform import gfile from easy_rec.python.main import predict diff --git a/easy_rec/python/tools/split_model_pai.py b/easy_rec/python/tools/split_model_pai.py index d2a46ff32..bdb2087de 100644 --- a/easy_rec/python/tools/split_model_pai.py +++ b/easy_rec/python/tools/split_model_pai.py @@ -9,10 +9,16 @@ from tensorflow.python.framework import ops from tensorflow.python.framework.dtypes import _TYPE_TO_STRING from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model.utils_impl import get_variables_path from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import saver as tf_saver +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + from tensorflow.python.saved_model.path_helpers import get_variables_path + from tensorflow.python.ops.resource_variable_ops import _from_proto_fn +else: + from tensorflow.python.saved_model.utils_impl import get_variables_path + FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('model_dir', '', '') tf.app.flags.DEFINE_string('user_model_dir', '', '') @@ -198,7 +204,10 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names, graph = ops.get_default_graph() importer.import_graph_def(inference_graph, name='') for name in variables_to_keep: - variable = graph.get_tensor_by_name(name) + if tf.__version__ >= '2.0': + variable = _from_proto_fn(variable_protos[name.split(':')[0]]) + else: + variable = graph.get_tensor_by_name(name) graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable) saver = tf_saver.Saver() saver.restore(sess, get_variables_path(model_dir)) diff --git a/setup.cfg b/setup.cfg index 337833a0f..b43211827 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ multi_line_output = 7 force_single_line = true known_standard_library = setuptools known_first_party = easy_rec -known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml +known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml no_lines_before = LOCALFOLDER default_section = THIRDPARTY skip = easy_rec/python/protos