How to save/restore a model after training?
After you train a model in Tensorflow:
- How do you save a trained model?
- How do you later restore this saved model?
Tensorflow 2 Docs
Saving Checkpoints
Adapted from the docs
# -------------------------
# ----- Toy Context -----
# -------------------------
import tensorflow as tf
class Net(tf.keras.Model):
    """A simple linear model."""

    def __init__(self):
        super(Net, self).__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)


def toy_dataset():
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    return (
        tf.data.Dataset.from_tensor_slices(dict(x=inputs, y=labels)).repeat().batch(2)
    )


def train_step(net, example, optimizer):
    """Trains `net` on `example` using `optimizer`."""
    with tf.GradientTape() as tape:
        output = net(example["x"])
        loss = tf.reduce_mean(tf.abs(output - example["y"]))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss
# ----------------------------
# ----- Create Objects -----
# ----------------------------
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)
# ----------------------------
# ----- Train and Save -----
# ----------------------------
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
        print("loss {:1.2f}".format(loss.numpy()))
# ---------------------
# ----- Restore -----
# ---------------------
# In another script, re-initialize objects
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)
# Re-use the manager code above ^
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

for _ in range(50):
    example = next(iterator)
    # Continue training or evaluate etc.
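To see what `max_to_keep=3` does in the manager above, here is a minimal sketch. It assumes TensorFlow 2.x and stands in a single illustrative `step` variable and a temporary directory for the full model and `./tf_ckpts`. It saves five times, checks which checkpoints are still on disk, and restores the newest one:

```python
import os
import tempfile

import tensorflow as tf

# Illustrative stand-in for the training state: a single step counter.
step = tf.Variable(0)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(ckpt, tempfile.mkdtemp(), max_to_keep=3)

for _ in range(5):
    step.assign_add(1)
    manager.save()  # numbered from ckpt.save_counter: ckpt-1, ckpt-2, ...

# Only the most recent `max_to_keep` checkpoints remain (oldest first).
kept = manager.checkpoints
print([os.path.basename(p) for p in kept])

# Clobber the variable, then restore the newest checkpoint.
step.assign(0)
status = ckpt.restore(manager.latest_checkpoint)
status.assert_existing_objects_matched()  # raises if a tracked object was not found
print(int(step.numpy()))
```

Older checkpoints are deleted as new ones are written, so `latest_checkpoint` is all you need to resume.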
More links
- exhaustive and useful tutorial on saved_model -> https://www.tensorflow.org/guide/saved_model
- keras detailed guide to save models -> https://www.tensorflow.org/guide/keras/save_and_serialize
Checkpoints capture the exact value of all parameters (tf.Variable objects) used by a model. Checkpoints do not contain any description of the computation defined by the model and thus are typically only useful when source code that will use the saved parameter values is available.
The SavedModel format on the other hand includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint). Models in this format are independent of the source code that created the model. They are thus suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C# etc. TensorFlow APIs).
(Highlights are my own)
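The quote above has a practical consequence: a SavedModel can be loaded and called without the Python class that produced it. Here is a minimal TF 2.x sketch. The `Scaler` module, its variable `w` and the temporary export directory are made up for illustration:

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical toy module: multiplies its input by a trainable scalar."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    # The input_signature lets the traced graph be exported with the module.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Scaler(), export_dir)

# Loading needs no access to the Scaler class: both the computation
# and the variable value are stored in the SavedModel directory.
loaded = tf.saved_model.load(export_dir)
result = loaded(tf.constant([1.0, 2.0, 3.0])).numpy()
print(result)
```

A checkpoint, by contrast, would only give you back the value of `w`. You would still need `Scaler` itself to compute anything.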
From the docs (TensorFlow 1.x API):
Save
import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    # Save the variables to disk.
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)
Restore
tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
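If you are on TensorFlow 2.x but still need the Saver-based code above, the same calls are available under `tf.compat.v1`. The sketch below makes two assumptions: you are on TF 2.x, and you save to a temporary directory instead of `/tmp/model.ckpt`. It builds and saves `v1`/`v2` as in the docs snippet. It then inspects the checkpoint with `tf.train.list_variables` and `tf.train.load_variable`, without rebuilding any graph:

```python
import os
import tempfile

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Same variables and updates as the docs snippet, through the v1 compat API.
    v1 = tf.compat.v1.get_variable("v1", shape=[3], initializer=tf.compat.v1.zeros_initializer())
    v2 = tf.compat.v1.get_variable("v2", shape=[5], initializer=tf.compat.v1.zeros_initializer())
    inc_v1 = v1.assign(v1 + 1)
    dec_v2 = v2.assign(v2 - 1)
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run([inc_v1.op, dec_v2.op])
        save_path = saver.save(sess, os.path.join(tempfile.mkdtemp(), "model.ckpt"))

# Read the checkpoint directly: names -> shapes, then raw values.
shapes = dict(tf.train.list_variables(save_path))
saved_v1 = tf.train.load_variable(save_path, "v1")
saved_v2 = tf.train.load_variable(save_path, "v2")
print(shapes, saved_v1, saved_v2)
```

Listing the names this way helps when `saver.restore` complains about missing variables. Saver matches variables by name, so the names in the new graph must match the ones stored in the checkpoint.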
simple_save
Many good answers; for completeness I'll add my 2 cents: simple_save. Also a standalone code example using the tf.data.Dataset API.
Python 3 ; Tensorflow 1.14
import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )
Restoring:
restored_graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = restored_graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = restored_graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = restored_graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })
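Hard-coding 'dense/BiasAdd:0' relies on how the graph happened to name its ops. simple_save stores the `inputs`/`outputs` dicts as the default serving signature (`"serving_default"`). You can therefore read the tensor names back from the MetaGraphDef that the loader returns. Here is a sketch using TF 2.x's `tf.compat.v1` API on a made-up one-layer graph; the names `x`, `w` and `y` are illustrative only:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# simple_save refuses to overwrite, so point it at a directory that does not exist yet.
export_dir = os.path.join(tempfile.mkdtemp(), "simple")

# Build and export a tiny graph: y = x @ w, with w initialised to ones.
with tf.Graph().as_default() as g1:
    x = tf.compat.v1.placeholder(tf.float32, [None, 3], name="x")
    w = tf.compat.v1.get_variable("w", initializer=tf.ones([3, 1]))
    y = tf.matmul(x, w, name="y")
    with tf.compat.v1.Session(graph=g1) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.saved_model.simple_save(sess, export_dir, inputs={"x": x}, outputs={"y": y})

# Restore in a fresh graph and look the tensors up through the signature.
with tf.Graph().as_default() as g2:
    with tf.compat.v1.Session(graph=g2) as sess:
        meta_graph_def = tf.compat.v1.saved_model.loader.load(
            sess, [tf.saved_model.SERVING], export_dir
        )
        signature = meta_graph_def.signature_def["serving_default"]
        x_name = signature.inputs["x"].name    # keys are the dict keys given to simple_save
        y_name = signature.outputs["y"].name
        out = sess.run(y_name, feed_dict={x_name: np.ones((2, 3), np.float32)})
        print(x_name, y_name, out)
```

With this approach, renaming layers in the model code does not silently break the restore script, as long as the signature keys stay the same.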
Standalone example
The following code generates random data for the sake of the demonstration.
- We start by creating the placeholders. They will hold the data at runtime. From them, we create the Dataset and then its Iterator. We get the iterator's generated tensor, called input_tensor, which will serve as input to our model.
- The model itself is built from input_tensor: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.
- The loss is a softmax_cross_entropy_with_logits, optimized with Adam. After 2 epochs (of 2 batches each), we save the "trained" model with tf.saved_model.simple_save. If you run the code as is, then the model will be saved in a folder called simple/ in your current working directory.
- In a new graph, we then restore the saved model with tf.saved_model.loader.load. We grab the placeholders and logits with graph.get_tensor_by_name and the Iterator initializing operation with graph.get_operation_by_name.
- Lastly we run an inference for both batches in the dataset, and check that the saved and restored model both yield the same values. They do!
Code:
import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)

    return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a step of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=labels_tensor, name='xent'),
                name="mean-xent"
            )
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break
            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')
    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }
            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)
This will print:
$ python3 save_and_restore.py
Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595 0.12804556 0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045 -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792 -0.00602257 0.07465433 0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984 0.05981954 -0.15913513 -0.3244143 0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ]
Restored values: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358]

Inferences match: True
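With TF 2.x and Keras, the same "save, reload, compare predictions" check is much shorter. This sketch assumes a recent TF 2.x release that supports the native `.keras` file format. The small `Sequential` model and the random input are placeholders for your own model and data:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Placeholder model and data; substitute your own trained model.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
x = np.random.RandomState(0).randn(2, 4).astype("float32")
before = model.predict(x, verbose=0)

# Save architecture + weights to a single file, then load it back.
path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)
restored = tf.keras.models.load_model(path)
after = restored.predict(x, verbose=0)

match = np.allclose(before, after)
print("Inferences match: ", match)
```

Unlike the TF1 version, there are no placeholders or tensor names to look up. `load_model` rebuilds the layers from the saved config, so you can call `predict` directly.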