是否可以用现有图形中的常量替换占位符?

问题描述:

我有一个经过训练的模型的冻结图,它有一个 tf.placeholder,我总是向它提供相同的值.

I have a frozen graph of a trained model, it has one tf.placeholder which I always feed the same value to.

我想知道是否可以用 tf.constant 代替它.如果以某种方式 - 任何示例将不胜感激!

I was wondering if it's possible to replace it with tf.constant instead. If it is somehow - any examples would be appreciated!

以下是代码的外观,以帮助可视化问题

我正在使用(由其他人)预先训练的模型来运行推理.该模型在本地存储为带有 .pb 扩展名的冻结图形文件.

I am using a pre-trained (by other people) model to run inference. The model is stored locally as a frozen graph file with .pb extension.

代码如下:

# load graph
graph = load_graph('frozen.pb')
session = tf.Session(graph=graph)

# Get input and output tensors
images_placeholder = graph.get_tensor_by_name("input:0")
output = graph.get_tensor_by_name("output:0")
phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")

feed_dict = {images_placeholder: images, phase_train_placeholder: False}

result = session.run(output, feed_dict=feed_dict)

问题是我总是出于我的目的提供 phase_train_placeholder: False,所以我想知道是否可以消除该占位符并将其替换为类似 tf.constant(False,dtype=bool, shape=[])

The problem is that I always feed phase_train_placeholder: False for my purposes, so I was wondering if it's possible to eliminate that placeholder and replace it with something like tf.constant(False, dtype=bool, shape=[])

所以我没有设法找到任何合适的方法,而是通过重建图形定义并替换我需要的节点以一种hacky 的方式设法做到了来代替.灵感来自这个代码.

So I didn't manage to find any proper way, but managed to do it in a hacky way, by rebuilding the graph def and substituting the node I needed to substitute. Inspired by this code.

这是代码(超级hacky,使用风险自负):

Here is the code (super hacky, use at your own risk):

INPUT_GRAPH_DEF_FILE = 'path/to/file'
OUTPUT_GRAPH_DEF_FILE = 'another/one'

# Get NodeDef of a constant tensor we want to put in place of 
# the placeholder. 
# (There is probably a better way to do this)
example_graph = tf.Graph()
with tf.Session(graph=example_graph):
    c = tf.constant(False, dtype=bool, shape=[], name='phase_train')
    for node in example_graph.as_graph_def().node:
        if node.name == 'phase_train':
            c_def = node

# load our graph
graph = load_graph(INPUT_GRAPH_DEF_FILE)
graph_def = graph.as_graph_def()

# Create new graph, and rebuild it from original one
# replacing phase train node def with constant
new_graph_def = graph_pb2.GraphDef()
for node in graph_def.node:
    if node.name == 'phase_train':
        new_graph_def.node.extend([c_def])
    else:
        new_graph_def.node.extend([copy.deepcopy(node)])

# save new graph
with tf.gfile.GFile(OUTPUT_GRAPH_DEF_FILE, "wb") as f:
    f.write(new_graph_def.SerializeToString())