#字母预测:输入a预测出b,输入b预测出c,输入c预测出d,输入d预测出e,输入e预测出a
#10000 a
#01000 b
#00100 c
#00010 d
#00001 e
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense,SimpleRNN
import matplotlib.pyplot as plt
import os
input_word='abcde'
w_to_id={'a':0,'b':1,'c':2,'d':3,'e':4,}
id_to_onehot={0:[1.,0.,0.,0.,0.],1:[0.,1.,0.,0.,0.],2:[0.,0.,1.,0.,0.],3:[0.,0.,0.,1.,0.],4:[0.,0.,0.,0.,1.]}
x_train=[id_to_onehot[w_to_id['a']],id_to_onehot[w_to_id['b']],id_to_onehot[w_to_id['c']],id_to_onehot[w_to_id['d']],id_to_onehot[w_to_id['e']]]
y_train=[w_to_id['b'],w_to_id['c'],w_to_id['d'],w_to_id['e'],w_to_id['a']]
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
#使x_train符合SimpleRNN输入要求:[送入样本数,循环核时间展开步数,每个时间步输入特征个数]
#此处整个数据集送入,所以送入研样本数为len(x_train);输入1个字母出结果,循环核时间展开步数为1;表示独热码有5个输入特征,每个时间步输入特征个数为5
x_train=np.reshape(x_train,(len(x_train),1,5))
y_train=np.array(y_train)
model=tf.keras.Sequential([SimpleRNN(3),Dense(5,activation='softmax')])
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
checkpoint_save_path='./checkpoint/rnn_onehot_lprel.ckpt'
if os.path.exists(checkpoint_save_path+'.index'):
print('-------------------load the model--------------')
model.load_weights(checkpoint_save_path)
cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,
save_best_only=True,
monitor='loss') #由于fit没有给出测试集,不计算测试集准确率,根据loss,保存最优模型
history=model.fit(x_train,y_train,batch_size=32,epochs=50,callbacks=[cp_callback])
model.summary()
# print(model.trainable_variables)
file = open('./rnn_weights.txt', 'w')
for v in model.trainable_variables:
file.write(str(v.name) + '
')
file.write(str(v.shape) + '
')
file.write(str(v.numpy()) + '
')
file.close()
############################################### show ###############################################
# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.title('Training Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.title('Training Loss')
plt.legend()
plt.show()
preNum=int(input('input the number of test alphabet'))
for i in range(preNum):
alphabet1=input('input test alphabet:')
alphabet=[id_to_onehot[w_to_id[alphabet1]]]
#使alphabet符合SimpleRNN输入要求[送入样本数,循环核时间展开步数,每个时间步输入特征个数]
#使此处验证效果送入了1个样本,送入样本数为1;输入1个字母出结果,所以循环核时间展开步数为1;独热码有5个输入特征,每个时间步输入特征个数为5
alphabet=np.reshape(alphabet,(1,1,5))
reseult=model.predict([alphabet])
pred=tf.argmax(reseult,axis=1)
pred=int(pred)
tf.print(alphabet1 + '->' + input_word[pred])