如何在 tensorflow 上加载和使用保存的模型?

问题描述:

我发现了两种在 Tensorflow 中保存模型的方法:tf.train.Saver()SavedModelBuilder.但是,在以第二种方式加载后,我找不到有关使用该模型的文档.

I have found 2 ways to save a model in Tensorflow: tf.train.Saver() and SavedModelBuilder. However, I can't find documentation on using the model after it being loaded the second way.

注意:我想使用 SavedModelBuilder 方式,因为我用 Python 训练模型,并将在服务时以另一种语言(Go)使用它,而且似乎 SavedModelBuilder 在这种情况下是唯一的方法.

Note: I want to use SavedModelBuilder way because I train the model in Python and will use it at serving time in another language (Go), and it seems that SavedModelBuilder is the only way in that case.

这对 tf.train.Saver() 很有效(第一种方式):

This works great with tf.train.Saver() (first way):

model = tf.add(W * x, b, name="finalnode")

# save
saver = tf.train.Saver()
saver.save(sess, "/tmp/model")

# load
saver.restore(sess, "/tmp/model")

# IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.

model = graph.get_tensor_by_name("finalnode:0")
sess.run(model, {x: [5, 6, 7]})

tf.saved_model.builder.SavedModelBuilder() 定义在 Readme 但在使用 tf.saved_model.loader.load(sess, [], export_dir)) 加载模型后,我找不到有关获取的文档回到节点(参见上面代码中的finalnode")

tf.saved_model.builder.SavedModelBuilder() is defined in the Readme but after loading the model with tf.saved_model.loader.load(sess, [], export_dir)), I can't find documentation on getting back at the nodes (see "finalnode" in the code above)

缺少的是签名

# Saving
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map= {
        "model": tf.saved_model.signature_def_utils.predict_signature_def(
            inputs= {"x": x},
            outputs= {"finalnode": model})
        })
builder.save()

# loading
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["tag"], export_dir)
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7, 8]}))