Tensorflow模型的保存和加载

问题描述:

如何像我们在do keras中一样使用模型图保存张量流模型。
我们可以保存整个模型(权重和图)并稍后导入,而不是在预测文件中再次定义整个图

How can save a tensorflow model with model graph like we do in do keras. Instead of defining the whole graph again in prediction file, can we save whole model ( weight and graph) and import it later

在Keras中:

checkpoint = ModelCheckpoint('RightLane-{epoch:03d}.h5',monitor='val_loss', verbose=0,  save_best_only=False, mode='auto')

将给出一个h5可用于预测的文件

model = load_model("RightLane-030.h5")

如何在本机张量流中执行相同操作

方法1:在一个文件中冻结图形和权重(可能无法进行重新训练)



此选项显示如何将图形和权重保存在一个文件中。它的预期用例是在训练模型后部署/共享模型。为此,我们将使用protobuf(pb)格式。

Method 1: Freeze graph and weights in one file (retraining might not be possible)

This option shows how to save the graph and weights in one file. Its intended use case is for deploying/sharing a model after it has been trained. To this end, we will use the protobuf (pb) format.

鉴于tensorflow会话(和图形),您可以使用

Given a tensorflow session (and graph), you can generate a protobuf with

# freeze variables
output_graph_def = tf.graph_util.convert_variables_to_constants(
                               sess=sess,
                               input_graph_def =sess.graph.as_graph_def(),
                               output_node_names=['myMode/conv/output'])

# write protobuf to disk
with tf.gfile.GFile('graph.pb', "wb") as f:
    f.write(output_graph_def.SerializeToString())

其中 output_node_names 期望图形结果节点的名称字符串列表(请参见 tensorflow文档)。

where output_node_names expects a list of name strings for the result nodes of the graph (cf. tensorflow documentation).

然后,您可以加载protobuf并获得其权重为perfo的图形rm前进很容易。

Then, you can load the protobuf and get the graph with its weight to perform forward passes easily.

with tf.gfile.GFile(path_to_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    return graph



方法2:恢复元数据和检查点(易于重新训练)



如果您希望能够继续训练模型,则可能需要恢复完整图形,即权重以及损失函数,一些梯度信息(对于Adam optimiser

Method 2: Restoring metagraph and checkpoint (easy retraining)

If you want to be able to continue training the model, you might need to restore the full graph, i.e. the weights but also the loss function, some gradient informations (for Adam optimiser for instance), etc.

您需要使用tensorflow生成的元文件和检查点文件

You need the meta and the checkpoint files generated by tensorflow when you use

saver = tf.train.Saver(...variables...)
saver.save(sess, 'my-model')

这将生成两个文件, my-model my -model.meta

This will generate two files, my-model and my-model.meta.

从这两个文件中,您可以加载图形wi th:

From these two files, you can load the graph with:

  new_saver = tf.train.import_meta_graph('my-model.meta')
  new_saver.restore(sess, 'my-model')

有关更多详细信息,请参见官方文档

For more details, you can look at the official documentation.