#-*- coding:utf-8 -*-
import os
import tensorflow as tf
import cv2
'''
文件目录为
chiwawa/
xx.jpg
xx.jpg
.....
japandog/
xx.jpg
xx.jpg
.....
'''
cwd = 'f:/py/tfrecord/'
classes={'chiwawa','japandog'} # 需要存入的标签,尽量与文件名一致,方便操作
sess = tf.Session()
writer = tf.python_io.TFRecordWriter("f:/py/tfrecord/train.tfrecords") # 建立一个writer
for index, name in enumerate(classes):
class_path = cwd + name + "/" # 构建文件路径
for img_name in os.listdir(class_path): # 遍历目录下的文件
img_path = class_path + img_name # 构建具体每一张图片的路径
image = cv2.imread(img_path) # 读取图片
# 获取图片的宽,高和通道数
img_w = image.shape[0]
img_h = image.shape[1]
img_c = image.shape[2]
# tf读取图片
img = tf.read_file(img_path)
img = tf.image.decode_jpeg(img)
# img = tf.image.resize_images(img,(224, 224)) 改变大小
img_raw = sess.run(tf.cast(img,tf.uint8)).tostring() #将图片转化为原生bytes
label = name.encode('utf-8') #将标签转化为bytes
'''
以下是Example类的常用固定格式,但要注意第一个features有s,对应的是tf.train.Features
tf.train.Features里的feature是没有s的,bytes_list对应的是tf.train.BytesList,
int64_list对应的是tf.train.Int64List,输入的value的格式也要一致,可输入的格式有int,float,bytes
label和img_raw的格式是bytes,宽、高、通道数的格式是int
'''
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'img_w': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_w])),
'img_h': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_h])),
'img_c': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_c]))
}))
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()