Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature]split_model_pai.py supports both tf1 and tf2 #490

Merged
merged 3 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
Expand Down
2 changes: 1 addition & 1 deletion easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import threading
import time
import unittest
from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.platform import gfile

from easy_rec.python.main import predict
Expand Down
13 changes: 11 additions & 2 deletions easy_rec/python/tools/split_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model.utils_impl import get_variables_path
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver

if tf.__version__ >= '2.0':
tf = tf.compat.v1
from tensorflow.python.saved_model.path_helpers import get_variables_path
from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
else:
from tensorflow.python.saved_model.utils_impl import get_variables_path

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_dir', '', '')
tf.app.flags.DEFINE_string('user_model_dir', '', '')
Expand Down Expand Up @@ -198,7 +204,10 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
graph = ops.get_default_graph()
importer.import_graph_def(inference_graph, name='')
for name in variables_to_keep:
variable = graph.get_tensor_by_name(name)
if tf.__version__ >= '2.0':
variable = _from_proto_fn(variable_protos[name.split(':')[0]])
else:
variable = graph.get_tensor_by_name(name)
graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable)
saver = tf_saver.Saver()
saver.restore(sess, get_variables_path(model_dir))
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ multi_line_output = 7
force_single_line = true
known_standard_library = setuptools
known_first_party = easy_rec
known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
no_lines_before = LOCALFOLDER
default_section = THIRDPARTY
skip = easy_rec/python/protos
Expand Down
Loading