Skip to content

Commit

Permalink
[feature]split_model_pai.py supports both tf1 and tf2 (#490)
Browse files Browse the repository at this point in the history
* split_model_pai.py supports TF2.x

* code style fixes

* get_variables_path fix
  • Loading branch information
eric-gecheng authored Oct 14, 2024
1 parent 488b3ab commit f4fa733
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
Expand Down
2 changes: 1 addition & 1 deletion easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import threading
import time
import unittest
from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.platform import gfile

from easy_rec.python.main import predict
Expand Down
13 changes: 11 additions & 2 deletions easy_rec/python/tools/split_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model.utils_impl import get_variables_path
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver

if tf.__version__ >= '2.0':
tf = tf.compat.v1
from tensorflow.python.saved_model.path_helpers import get_variables_path
from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
else:
from tensorflow.python.saved_model.utils_impl import get_variables_path

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_dir', '', '')
tf.app.flags.DEFINE_string('user_model_dir', '', '')
Expand Down Expand Up @@ -198,7 +204,10 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
graph = ops.get_default_graph()
importer.import_graph_def(inference_graph, name='')
for name in variables_to_keep:
variable = graph.get_tensor_by_name(name)
if tf.__version__ >= '2.0':
variable = _from_proto_fn(variable_protos[name.split(':')[0]])
else:
variable = graph.get_tensor_by_name(name)
graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable)
saver = tf_saver.Saver()
saver.restore(sess, get_variables_path(model_dir))
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ multi_line_output = 7
force_single_line = true
known_standard_library = setuptools
known_first_party = easy_rec
known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
no_lines_before = LOCALFOLDER
default_section = THIRDPARTY
skip = easy_rec/python/protos
Expand Down

0 comments on commit f4fa733

Please sign in to comment.