Loading multiple .npz files with multiple threads

Problem description:

I have several .npz files. All .npz files have the same structure: each of them contains just two variables, always with the same variable names. So far, I have simply been looping over all the .npz files, retrieving the two variable values and appending them to some global variables:

import numpy as np

# Let's assume there are 100 npz files
x_train = []
y_train = []
for npz_file_number in range(100):
    data = dict(np.load('{0:04d}.npz'.format(npz_file_number)))
    x_train.append(data['x'])
    y_train.append(data['y'])

It takes a while, and the bottleneck is the CPU. The order in which the x and y values are appended to x_train and y_train does not matter.

Is there any way to load several .npz files in a multi-threaded way?

I was surprised by the comments of @Brent Washburne and decided to try it out myself. I think the general problem is two-fold:

Firstly, reading data is often IO-bound, so writing multi-threaded code often does not yield large performance gains. Secondly, shared-memory parallelization in Python is inherently difficult because of the design of the language itself (the global interpreter lock); there is much more overhead compared to native C.
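
For comparison, a thread-based reader is only a small change using concurrent.futures; this is just a sketch, and since decompressing the arrays is CPU-bound, the GIL usually limits how much it can gain (the process-based version below is what is actually timed):

from concurrent.futures import ThreadPoolExecutor

import numpy as np

def read_x(path):
    # open the compressed archive and return only the 'x' array
    with np.load(path) as data:
        return data['x']

def threaded_read(files):
    # threads share memory and avoid pickling the results,
    # but the CPU-bound decompression still runs under the GIL
    with ThreadPoolExecutor() as executor:
        return list(executor.map(read_x, files))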

But let's see what we can do.

# some imports
import numpy as np
import glob
from multiprocessing import Pool
import os

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)
    for i in range(100):
        x = np.random.rand(10000, 50)
        file_path = os.path.join(tmp_dir, '%05d.npz' % i)
        np.savez_compressed(file_path, x=x)

def read_x(path):
    with np.load(path) as data:
        return data["x"]

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    with Pool() as pool:
        x_list = pool.map(read_x, files)
    return x_list

Okay, enough preparation. Let's get the timings.

files = glob.glob(os.path.join(tmp_dir, '*.npz'))

%timeit x_serial = serial_read(files)
# 1 loops, best of 3: 7.04 s per loop

%timeit x_parallel = parallel_read(files)
# 1 loops, best of 3: 3.56 s per loop

np.allclose(x_serial, x_parallel)
# True

It actually looks like a decent speedup. I am using two real and two hyper-threading cores.
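
Pool() with no arguments starts one worker per logical core reported by the OS; you can check that number on your own machine:

import multiprocessing
print(multiprocessing.cpu_count())  # e.g. 4 for two physical cores with hyper-threading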

To run and time everything at once, you can execute this script:

from __future__ import print_function
from __future__ import division

# some imports
import numpy as np
import glob
import sys
import multiprocessing
import os
import timeit

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)
    for i in range(100):
        x = np.random.rand(10000, 50)
        file_path = os.path.join(tmp_dir, '%05d.npz' % i)
        np.savez_compressed(file_path, x=x)

def read_x(path):
    data = dict(np.load(path))
    return data['x']

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    pool = multiprocessing.Pool(processes=4)
    x_list = pool.map(read_x, files)
    # release the worker processes once the map has finished
    pool.close()
    pool.join()
    return x_list


files = glob.glob(os.path.join(tmp_dir, '*.npz'))
#files = files[0:5] # to test on a subset of the npz files

# Timing:
timeit_runs = 5

timer = timeit.Timer(lambda: serial_read(files))
print('serial_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
      timeit_runs))
# 1 loops, best of 3: 7.04 s per loop

timer = timeit.Timer(lambda: parallel_read(files))
print('parallel_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
      timeit_runs))
# 1 loops, best of 3: 3.56 s per loop

# Examples of use:
x = serial_read(files)
print('len(x): {0}'.format(len(x))) # len(x): 100
print('len(x[0]): {0}'.format(len(x[0]))) # len(x[0]): 10000
print('len(x[0][0]): {0}'.format(len(x[0][0]))) # len(x[0][0]): 50
print('x[0][0]: {0}'.format(x[0][0])) # prints the 50 values of the first row of x[0]
print('x[0].nbytes: {0} MB'.format(x[0].nbytes / 1e6)) # 4.0 MB
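
The original question loads both x and y from every file; the same pattern extends directly. A sketch, assuming each .npz contains both an x and a y array as in the question's loop:

import multiprocessing
import numpy as np

def read_xy(path):
    # load both arrays from one archive
    with np.load(path) as data:
        return data['x'], data['y']

def parallel_read_xy(files):
    # read all files in parallel, then split the (x, y) pairs
    pool = multiprocessing.Pool(processes=4)
    pairs = pool.map(read_xy, files)
    pool.close()
    pool.join()
    x_train = [x for x, y in pairs]
    y_train = [y for x, y in pairs]
    return x_train, y_train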