From b0ea0f3ff723c1bd9e4232e35b5e7640e4cb0b19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 3 Aug 2022 14:35:32 +0800 Subject: [PATCH] support split partition variables --- easy_rec/python/tools/split_model_pai.py | 13 +++++++------ setup.cfg | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/easy_rec/python/tools/split_model_pai.py b/easy_rec/python/tools/split_model_pai.py index d2a46ff32..624f7cf8a 100644 --- a/easy_rec/python/tools/split_model_pai.py +++ b/easy_rec/python/tools/split_model_pai.py @@ -65,7 +65,7 @@ def extract_sub_graph(graph_def, dest_nodes, variable_protos): node_seq = {} seq = 0 nodes_to_keep = set() - variables_to_keep = set() + variables_to_keep = dict() for node in graph_def.node: n = _node_name(node.name) @@ -81,11 +81,11 @@ def extract_sub_graph(graph_def, dest_nodes, variable_protos): n = next_to_visit[0] if n in variable_protos: - proto = variable_protos[n] + proto = variable_protos[n][0] next_to_visit.append(_node_name(proto.initial_value_name)) next_to_visit.append(_node_name(proto.initializer_name)) next_to_visit.append(_node_name(proto.snapshot_name)) - variables_to_keep.add(proto.variable_name) + variables_to_keep[proto.variable_name] = (proto, variable_protos[n][1]) del next_to_visit[0] if n in nodes_to_keep: @@ -137,7 +137,7 @@ def load_meta_graph_def(model_dir): tf.logging.info('%s' % proto.variable_name) variable_node_name = _node_name(proto.variable_name) if variable_node_name not in variable_protos: - variable_protos[variable_node_name] = proto + variable_protos[variable_node_name] = (proto, key) # parse signature info for SavedModel for sig_name in signatures: @@ -197,8 +197,9 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names, with sess.graph.as_default(): graph = ops.get_default_graph() importer.import_graph_def(inference_graph, name='') - for name in variables_to_keep: - variable = graph.get_tensor_by_name(name) + for name, (proto, key) in variables_to_keep.items(): + from_proto = ops.get_from_proto_function(key) + variable = from_proto(proto) graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable) saver = tf_saver.Saver() saver.restore(sess, get_variables_path(model_dir)) diff --git a/setup.cfg b/setup.cfg index 00e26bc28..e6d30fc7d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ multi_line_output = 7 force_single_line = true known_standard_library = setuptools known_first_party = easy_rec -known_third_party = absl,common_io,future,google,graphlearn,matplotlib,nni,numpy,odps,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml +known_third_party = absl,common_io,distutils,future,google,graphlearn,matplotlib,nni,numpy,odps,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml no_lines_before = LOCALFOLDER default_section = THIRDPARTY skip = easy_rec/python/protos