diff --git a/easy_rec/python/tools/split_model_pai.py b/easy_rec/python/tools/split_model_pai.py index d2a46ff32..624f7cf8a 100644 --- a/easy_rec/python/tools/split_model_pai.py +++ b/easy_rec/python/tools/split_model_pai.py @@ -65,7 +65,7 @@ def extract_sub_graph(graph_def, dest_nodes, variable_protos): node_seq = {} seq = 0 nodes_to_keep = set() - variables_to_keep = set() + variables_to_keep = dict() for node in graph_def.node: n = _node_name(node.name) @@ -81,11 +81,11 @@ def extract_sub_graph(graph_def, dest_nodes, variable_protos): n = next_to_visit[0] if n in variable_protos: - proto = variable_protos[n] + proto = variable_protos[n][0] next_to_visit.append(_node_name(proto.initial_value_name)) next_to_visit.append(_node_name(proto.initializer_name)) next_to_visit.append(_node_name(proto.snapshot_name)) - variables_to_keep.add(proto.variable_name) + variables_to_keep[proto.variable_name] = (proto, variable_protos[n][1]) del next_to_visit[0] if n in nodes_to_keep: @@ -137,7 +137,7 @@ def load_meta_graph_def(model_dir): tf.logging.info('%s' % proto.variable_name) variable_node_name = _node_name(proto.variable_name) if variable_node_name not in variable_protos: - variable_protos[variable_node_name] = proto + variable_protos[variable_node_name] = (proto, key) # parse signature info for SavedModel for sig_name in signatures: @@ -197,8 +197,9 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names, with sess.graph.as_default(): graph = ops.get_default_graph() importer.import_graph_def(inference_graph, name='') - for name in variables_to_keep: - variable = graph.get_tensor_by_name(name) + for name, (proto, key) in variables_to_keep.items(): + from_proto = ops.get_from_proto_function(key) + variable = from_proto(proto) graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable) saver = tf_saver.Saver() saver.restore(sess, get_variables_path(model_dir))