How to load pickle files with tensorflow's tf.data API

Problem description:

My data is stored in multiple pickle files on disk. I want to use tensorflow's tf.data.Dataset to load the data into my training pipeline. My code is as follows:

import glob

import tensorflow as tf

def _parse_file(path):
    image, label = ...  # *load pickle file*
    return image, label

paths = glob.glob('*.pkl')
print(len(paths))
dataset = tf.data.Dataset.from_tensor_slices(paths)
dataset = dataset.map(_parse_file)
iterator = dataset.make_one_shot_iterator()

The problem is that I don't know how to implement the _parse_file function. The argument to this function, path, is of tensor type. I tried:

def _parse_file(path):
    with tf.Session() as s:
        p = s.run(path)
        image, label = pickle.load(open(p, 'rb'))
    return image, label

and got the error message:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
     [[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

After some searching on the Internet I still have no idea how to do it. I would be grateful to anyone providing a hint.

I have solved this myself. I should use tf.py_func as in this doc. The function passed to dataset.map runs while the graph is being built, so at that point path is only a symbolic tensor with no value to fetch; tf.py_func wraps an ordinary Python function so that it receives the actual string at run time.
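
For reference, here is a minimal sketch of that approach. It assumes each .pkl file stores an (image, label) pair where image is a float numpy array and label an integer; the dtypes and the set_shape values below are hypothetical and should be adapted to your data:

import glob
import pickle

import numpy as np
import tensorflow as tf

def _load_pickle(path):
    # Inside py_func, `path` arrives as a plain bytes object,
    # so it can be opened like an ordinary file path.
    with open(path, 'rb') as f:
        image, label = pickle.load(f)
    return image.astype(np.float32), np.int64(label)

def _parse_file(path):
    image, label = tf.py_func(_load_pickle, [path], [tf.float32, tf.int64])
    # py_func discards static shape information; restore it if known
    # (the shapes below are placeholders).
    image.set_shape([224, 224, 3])
    label.set_shape([])
    return image, label

paths = glob.glob('*.pkl')
dataset = tf.data.Dataset.from_tensor_slices(paths)
dataset = dataset.map(_parse_file)
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()

Note that in TensorFlow 2.x tf.py_func is replaced by tf.py_function, and the one-shot iterator is unnecessary because a Dataset can be iterated directly.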