带你从零掌握迭代器及构建最简DataLoader

0 摘要

    本文本意是写 pytorch 中 DataLoader 源码学习心得,但是发现自己对迭代器和生成器的掌握比较水,不够牢固,而我也没有搜到能够解决我所有疑问的解答文章,因此诞生了这篇文章。通过本文你将能够零基础深入掌握 python 迭代器相关知识、并且能够一步步理解 DataLoader 的实现原理以及背后涉及的设计模式

    本文最终目的是通过源码学习自己实现一个功能比较完善的 DataLoader 类,为了达到这个目的,本文写作流程是:

  • 先深入浅出分析 python 中迭代器、生成器等实现原理,包括 Iterable、Iterator、for .. in ..、__getitem__、yield 生成器 5个部分

  • 再实现了一个最简单版本的 DataLoader,目的是理解 DataLoader 与 Dataset、Sampler、BatchSampler和 collate_fn 之间的调用关系

  • 最后对该实现进行深入全面分析,读者可以清晰的理解每个类的作用

    但是 DataLoader 功能其实非常复杂,故本文属于系列文章的第一篇,后面文章会不断完善、调整,最终实现 DataLoader 所有功能。或者说本文是后续文章的基础,如果基础内容没有理解非常透彻,后面的多进程、分布式版本就更难以理解了。

    虽然本文比较简单,但是由于涉及到代码,故为了方便,有必要的读者可以 clone rep 进行学习(需要特意说明的是:rep 里面代码是学习目的的,质量不高,不要要求那么多)

github:  https://github.com/hhaAndroid/miniloader

由于本人水平有限,某些环节理解可能有偏颇,欢迎指正。手机对于代码显示效果不太好,建议电脑端阅读。

 

1 python 迭代器深入浅出理解

 

1.1 可迭代对象 Iterable

    可迭代对象 Iterable:表示该对象可迭代,其并不是指某种具体数据类型。简单来说只要是实现了 `__iter__` 方法的类就是可迭代对象

from collections.abc import Iterable, Iteratorclass A(object):    def __init__(self):        self.a = [1, 2, 3]    def __iter__(self):        # 此处返回啥无所谓        return self.acls_a = A()#  Trueprint(isinstance(cls_a, Iterable))

    但是对象如果是 Iterable 的,看起来好像也没有特别大的用途,因为你依然无法迭代,实际上 Iterable 仅仅是提供了一种抽象规范接口:

for a in cls_a:    print(a)
# 程序报错,要理解这个错误的含义TypeError: iter() returned non-iterator of type 'list'

    我们可以检查下 Iterable 接口:

class Iterable(metaclass=ABCMeta):
# 如果实现了这个方法,那么就是 Iterable @abstractmethod def __iter__(self): while False: yield None
@classmethod def __subclasshook__(cls, C): if cls is Iterable: return _check_methods(C, "__iter__") return NotImplemented

看起来实现 Iterable 接口用途不大,其实不是的,其有很多用途的,例如简化代码等,在后面的高级语法糖中会频繁用到,后面会分析。

1.2 迭代器 Iterator

    迭代器 Iterator:其和 Iterable 之间是一个包含与被包含的关系,如果一个对象是迭代器 Iterator,那么这个对象肯定是可迭代 Iterable;但是反过来,如果一个对象是可迭代 Iterable,那么这个对象不一定是迭代器 Iterator,可以通过接口协议看出:

class Iterator(Iterable):
# 迭代具体实现 @abstractmethod def __next__(self): 'Return the next item from the iterator. When exhausted, raise StopIteration' raise StopIteration
    # 返回自身,因为自身有 __next__ 方法(如果自身没有 __next__,那么返回自身没有意义) def __iter__(self):        return self         @classmethod def __subclasshook__(cls, C): if cls is Iterator: return _check_methods(C, '__iter__', '__next__') return NotImplemented

可以发现:实现了 `__next__` 和 `__iter__` 方法的类才能称为迭代器,就可以被 for 遍历了

class A(object):    def __init__(self):        self.index = -1        self.a = [1, 2, 3]
    #必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历    #因为本类自身实现了 __next__,所以通常都是返回 self 对象即可 def __iter__(self): return self
def __next__(self): self.index += 1 if self.index < len(self.a): return self.a[self.index] else:            #抛异常,for 内部会自动捕获,表示迭代完成 raise StopIteration("遍历完了")cls_a = A()print(isinstance(cls_a, Iterable)) # Trueprint(isinstance(cls_a, Iterator)) # Trueprint(isinstance(iter(cls_a), Iterator)) # True
for a in cls_a: print(a)# 打印 1 2 3

再次明确,一个对象如果要是 Iterator ,那么必须要实现 `__next__` 和 `__iter__` 方法,但是要理解其内部迭代流程,还需要理解 for .. in .. 流程。

1.3 for .. in .. 本质流程

    for .. in .. 也就是常见的迭代操作了,其被 python 编译器编译后,实际上代码是:

# 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象cls_a = iter(cls_a)while True:    try:        # 然后调用对象的 __next__ 方法,不断返回元素        value = next(cls_a)        print(value)    # 如果迭代完成,则捕获异常即可    except StopIteration:        break

可以看出,任何一个对象如果要能够被 for 遍历,必须要实现  `__iter__` 和 `__next__` 方法,缺一不可

    明白了上述流程,那么迭代器对象 A,我们可以采用如下方式进行遍历:

myiter = iter(cls_a)print(next(myiter))print(next(myiter))print(next(myiter))# 因为遍历完了,故此时会出现错误:StopIteration: 遍历完了print(next(myiter))

我们再来思考 python 内置对象 list 为啥可以被迭代

b=list([1,2,3])print(isinstance(b, Iterable)) # Trueprint(isinstance(b, Iterator)) # False

    可以发现 list 类型是可迭代对象,但是其不是迭代器(即 list 没有 `__next__` 方法),那为啥 for .. in .. 可以迭代呢?

    原因是 list 内部的 `__iter__` 方法内部返回了具备 `__next__` 方法的类,或者说调用 iter() 后返回的对象本身就是一个迭代器,当然可以 for 循环了

b=list([1,2,3])print(dir(b)) # 可以发现其存在 __iter__ 方法,不存在 __next__
b=iter(b) # 调用 list 内部的 __iter__,返回了具备 __next__ 的对象print(isinstance(b, Iterable)) # Trueprint(isinstance(b, Iterator)) # Trueprint(dir(b)) # 同时具备 __iter__ 和 __next__ 方法

基于上述理解我们可以对 A 类代码进行改造,使其更加简单:

class A(object):    def __init__(self):        self.a = [1, 2, 3]    # 我们内部又调用了 list 对象的 __iter__ 方法,故此时返回的对象是迭代器对象    def __iter__(self):        return iter(self.a)
cls_a = A()print(isinstance(cls_a, Iterable)) # Trueprint(isinstance(cls_a, Iterator)) # False
for a in cls_a: print(a)# 输出:1 2 3

    此时我们就实现了仅仅实现 Iterable 规范接口,但是又具备了 for .. in .. 功能,代码是不是比最开始的实现简单很多?这种写法应用也非常广泛,因为其不需要自己再次实现 `__next__` 方法。

    如果你想理解的更加透彻,那么可以看下面例子:

# 仅仅实现 __iter__ class A(object):    def __init__(self):        self.b = B()
def __iter__(self): return self.b
# 仅仅实现 __next__class B(object): def __init__(self): self.index = -1 self.a = [1, 2, 3]
def __next__(self): self.index += 1 if self.index < len(self.a): return self.a[self.index] else: # 内部会自动捕获,表示迭代完成 raise StopIteration("遍历完了")

cls_a = A()cls_b = B()print(isinstance(cls_a, Iterable)) # Trueprint(isinstance(cls_a, Iterator)) # Falseprint(isinstance(cls_b, Iterable)) # Falseprint(isinstance(cls_b, Iterator)) # False
print(type(iter(cls_a))) # B 对象print(isinstance(iter(cls_a), Iterator)) # False
for a in cls_a: print(a)
# 输出:1 2 3

    自此我们知道了:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 `__iter__` 和 `__next__` 方法(对象必然是 Iterator),还是只实现 `__iter__`(不是 Iterator),但是内部间接返回了具备 `__next__` 对象的类,都是可行的

    但是除了这两种实现,还有其他高级语法糖,可以进一步精简代码。

1.4  __ getitem__ 理解

    上面说过 for .. in .. 的本质就是调用对象的 `__iter__` 和 `__next__` 方法,但是有一种更加简单的写法,你通过仅仅实现 `__getitem__` 方法就可以让对象实现迭代功能。实际上任何一个类,如果实现了`__getitem__` 方法,那么当调用 iter(类实例) 时候会自动具备`__iter__` 和 `__next__`方法,从而可迭代了。

    通过下面例子可以看出,`__getitem__` 实际上是属于 __iter__` 和 `__next__` 方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。

class A(object):    def __init__(self):        self.a = [1, 2, 3]
def __getitem__(self, item):        return self.a[item]        cls_a = A()print(isinstance(cls_a, Iterable)) # Falseprint(isinstance(cls_a, Iterator)) # Falseprint(dir(cls_a)) # 仅仅具备 __getitem__ 方法
cls_a = iter(cls_a)print(dir(cls_a)) # 具备 __iter__ 和 __next__ 方法
print(isinstance(cls_a, Iterable)) # Trueprint(isinstance(cls_a, Iterator)) # True
# 等价于 for .. in ..while True: try:        # 然后调用对象的 __next__ 方法,不断返回元素 value = next(cls_a) print(value) # 如果迭代完成,则捕获异常即可 except StopIteration: break
# 输出:1 2 3

而且 `__getitem__` 还可以通过索引直接访问元素,非常方便

a[0] # 1  a[4] # 错误,索引越界

如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 `__len__` 方法即可

class A(object):    def __init__(self):        self.a = [1, 2, 3]
def __getitem__(self, item): return self.a[item]
def __len__(self): return len(self.a)
cls_a = A() print(len(cls_a)) # 3

    到目前为止,我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。

1.5 yield 生成器

    生成器是一个在行为上和迭代器非常类似的对象,二者功能上差不多,但是生成器更优雅,只需要用关键字 yield 来返回,作用于函数上叫生成器函数,函数被调用时会返回一个生成器对象,生成器本质就是迭代器,其最大特点是代码简洁。

def func():    for a in [1, 2, 3]:        yield a
cls_g = func()print(isinstance(cls_g, Iterator)) # Trueprint(dir(cls_g)) # 自动具备 __iter__ 和 __next__ 方法
for a in cls_g: print(a)
# 输出: 1 2 3
# 一种更简单的写法是用 ()cls_g = (i for i in [1,2,3])

    直观感觉和 `__getitem__` 一样,也是高级语法糖,但是比 `__getitem__` 更加简单,更加好用。

    使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。假设创建一个包含10万个元素的列表,如果用 list 返回不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了,这种场景就适合采用生成器,在迭代过程中推算出后续元素,而不需要一次性全部算出。

1.6 小结

带你从零掌握迭代器及构建最简DataLoader

  •  list set dict等内置对象都是容器 container 对象,容器是一种把多个元素组织在一起的数据结构,可以逐个迭代获取其中的元素。容器可以用 in 来判断容器中是否包含某个元素。大多数容器都是可迭代对象,可以使用某种方式访问容器中的每一个元素。

  • 在迭代对象基础上,如果实现了 `__next__`  方法则是迭代器对象,该对象在调用 next()  的时候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。

  • 对于采用语法糖 `__getitem__` 实现的迭代器对象,其本身实例既不是可迭代对象,更不是迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。

  • 生成器是一种特殊迭代器,但是不需要像迭代器一样实现`__iter__`和`__next__`方法,只需要使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入 yield 关键字实现。

  • 对于在类的 `__iter__` 方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。

     

2 DataLoader 最简版本 V1

    这里说的最简版本是指:没有任何花哨、高级实现技巧,仅仅以实现最基础功能为目的。具体来说是包括必备的5个对象:Dataset、Sampler、BatchSampler、DataLoader 和 collate_fn。其作用可以简要描述为如下:

  • Dataset 提供整个数据集的随机访问功能,每次调用都返回单个对象,例如一张图片和对应 target 等等

  • Sampler 提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引,常用子类是 SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引

  • BatchSampler 内部调用 Sampler 实例,输出指定 `batch_size` 个索引,然后将索引作用于 Dataset 上从而输出 `batch_size` 个数据对象,例如 batch 张图片和 batch 个 target

  • collate_fn 用于将 batch 个数据对象在 batch 维度进行聚合,生成 (b,...) 格式的数据输出,如果待聚合对象是 numpy,则会自动转化为 tensor,此时就可以输入到网络中了

迭代一次伪代码如下(非迭代器版本):

class DataLoader(object):    def __init__(self):        # 假设数据长度是100,batch_size 是4        self.dataset = [[img0, target0], [img1, target1], ..., [img99, target99]]        # 假设 sampler 是 SequentialSampler,那么实际上就是 [0,1,...,99] 列表而已        # 如果 sampler 是 RandomSampler,那么可能是 [30,1,34,2,6,...,0] 列表        self.sampler = [0, 1, 2, 3, 4, ..., 99]        self.batch_size = 4        self.index = 0
def collate_fn(self, data): # batch 维度聚合数据 batch_img = torch.stack(data[0], 0) batch_target = torch.stack(data[1], 0) return batch_img, batch_target
def __next__(self): # 0.batch_index 输出,实际上就是 BatchSampler 做的事情 i = 0 batch_index = [] while i < self.batch_size: # 内部会调用 sampler 对象取单个索引 batch_index.append(self.sampler[self.index]) self.index += 1 i += 1
# 1.得到 batch 个数据了,调用 dataset 对象 data = [self.dataset[idx] for idx in batch_index]
# 2. 调用 collate_fn 在 batch 维度拼接输出 batch_data = self.collate_fn(data) return batch_data
def __iter__(self): return self

    以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。

2.1 整体对象理解

    首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器,你必须要先理解第一小节内容,否则本节内容会比较难理解,具体为:

  •  Dataset 通过实现 `__getitem__` 方法使其可迭代

  •  Sampler 对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 `__iter__` 内部返回迭代器,RandomSampler 在 `__iter__` 内部通过 yield 关键字返回迭代器

  •  BatchSampler 也是在 `__iter__` 内部通过 yield 关键字返回迭代器

  •  DataLoader 通过直接实现 `__next__` 和 `__iter__` 变成迭代器

    注意除了 DataLoader 本身是迭代器外,其余对象本身不是迭代器,但是都能被 for .. in .. 迭代。下面一个简单例子证明:

from simplev1_datatset import SimpleV1Datasetfrom libv1 import SequentialSampler, RandomSampler from collections import Iterator, Iterable   
simple_dataset = SimpleV1Dataset()  dataloader = DataLoader(simple_dataset, batch_size=2, collate_fn=default_collate)
print(isinstance(simple_dataset, Iterable)) # Falseprint(isinstance(simple_dataset, Iterator)) # Falseprint(isinstance(iter(simple_dataset), Iterator)) # True
print(isinstance(SequentialSampler(simple_dataset), Iterable)) # Trueprint(isinstance(SequentialSampler(simple_dataset), Iterator)) # Falseprint(isinstance(iter(SequentialSampler(simple_dataset)), Iterator)) # True
# BatchSampler 和 RandomSampler 内部实现结构一样,结果也是一样print(isinstance(RandomSampler(simple_dataset), Iterable)) # Trueprint(isinstance(RandomSampler(simple_dataset), Iterator)) # Falseprint(isinstance(iter(RandomSampler(simple_dataset)), Iterator)) # True
print(isinstance(dataloader, Iterator)) # True

    在 DataLoader 中主要涉及3个类,其内部实例传递关系如下:

带你从零掌握迭代器及构建最简DataLoader

    由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。

    需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。

2.2 DataLoader 运行流程

    最简单版本 DataLoader,具备如下功能:

  • Dataset 内部返回需要是 numpy 或者 tensor 对象

  • Sampler 直接 SequentialSampler 和 RandomSampler

  • BatchSampler 已经实现

  • collate_fn 仅仅考虑了 numpy 或者 tensor 对象

  • 仅仅支持 num_works=0 即单进程

看起来功能非常单一,但是其实已经搭建起了整个框架,理解了这个最简框架才能去理解高级实现,其核心运行逻辑为:

def __next__(self):    # 返回 batch 个索引    index = next(self.batch_sampler)    # 利用索引去取数据    data = [self.dataset[idx] for idx in index    # batch 维度聚合    data = self.collate_fn(data)    return data

然后为了方便大家理解,特意绘制了如下代码运行流程图:

带你从零掌握迭代器及构建最简DataLoader

    还是那句话:一定要对第1小节内容非常熟悉,否则里面这么多迭代器、生成器的调用,可能会把你绕晕。详细代码描述如下:

  1. `self.batch_sampler = iter(batch_sampler)`。在 DataLoader 的类初始化,需要得到 BatchSampler 的迭代器对象

  2. `index = next(self.batch_sampler)`。对于每次迭代,DataLoader 对象首先会调用 BatchSampler 的迭代器进行下一次迭代,具体是调用 BatchSampler 对象的  `__iter__`  方法

  3. 而 BatchSampler 对象的 `__iter__` 方法实际上是需要依靠 Sampler 对象进行迭代输出索引,Sampler 对象也是一个迭代器,当迭代 `batch_size` 次后就可以得到 `batch_size` 个数据索引

  4. `data = [self.dataset[idx] for idx in index]`。有了 batch 个索引就可以通过不断调用  dataset 的 `__getitem__` 方法返回数据对象,此时 data 就包含了 batch 个对象

  5. `data = self.collate_fn(data)`。将 batch 个对象输入给聚合函数,在第0个维度也就是 batch 维度进行聚合,得到类似 (b,...) 的对象

  6. 不断重复1-5步,就可以不断的输出一个一个 batch 的数据了

以上就是完整流程,如果理解有困难,你可以先看下一小结的代码实现,然后再返回去理解

2.3 最简V1版本源代码

 

(1) Dataset

class Dataset(object):    # 只要实现了 __getitem__ 方法就可以变成迭代器    def __getitem__(self, index):          raise NotImplementedError     # 用于获取数据集长度     def __len__(self):         raise NotImplementedError

(2) Sampler

class Sampler(object):    def __init__(self, data_source):        pass  
def __iter__(self): raise NotImplementedError
def __len__(self): raise NotImplementedError

 

class SequentialSampler(Sampler): 
def __init__(self, data_source): super(SequentialSampler, self).__init__(data_source) self.data_source = data_source
def __iter__(self): # 返回迭代器,不然无法 for .. in .. return iter(range(len(self.data_source)))
def __len__(self): return len(self.data_source)
class BatchSampler(Sampler):      def __init__(self, sampler, batch_size, drop_last):         self.sampler = sampler         self.batch_size = batch_size         self.drop_last = drop_last 
def __iter__(self): batch = [] # 调用 sampler 内部的迭代器对象 for idx in self.sampler: batch.append(idx) # 如果已经得到了 batch 个 索引,则可以通过 yield # 关键字生成生成器返回,得到迭代器对象 if len(batch) == self.batch_size: yield batch batch = [] if len(batch) > 0 and not self.drop_last: yield batch
def __len__(self): if self.drop_last: # 如果最后的索引数不够一个 batch,则抛弃 return len(self.sampler) // self.batch_size         else:              return  (len(self.sampler) + self.batch_size - 1) // self.batch_size

(3) DataLoader

class DataLoader(object):    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,                 batch_sampler=None, collate_fn=None, drop_last=False)        self.dataset = dataset    
# 因为这两个功能是冲突的,假设 shuffle=True, # 但是 sampler 里面是 SequentialSampler,那么就违背设计思想了 if sampler is not None and shuffle: raise ValueError('sampler option is mutually exclusive with ' 'shuffle')
if batch_sampler is not None: # 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler # 和 drop_last 四个参数就不能传入 # 因为这4个参数功能和 batch_sampler 功能冲突了 if batch_size != 1 or shuffle or sampler is not None or drop_last: raise ValueError('batch_sampler option is mutually exclusive ' 'with batch_size, shuffle, sampler, and ' 'drop_last') batch_size = None drop_last = False
if sampler is None: if shuffle: sampler = RandomSampler(dataset) else: sampler = SequentialSampler(dataset)
# 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类 if batch_sampler is None: batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.batch_size = batch_size self.drop_last = drop_last self.sampler = sampler self.batch_sampler = iter(batch_sampler)
if collate_fn is None: collate_fn = default_collate self.collate_fn = collate_fn
# 核心代码 def __next__(self): index = next(self.batch_sampler) data = [self.dataset[idx] for idx in index] data = self.collate_fn(data) return data
# 返回自身,因为自身实现了 __next__ def __iter__(self): return self

(4) collate_fn

def default_collate(batch):     elem = batch[0]      elem_type = type(elem)      if isinstance(elem, torch.Tensor):          return torch.stack(batch, 0)      elif elem_type.__module__ == 'numpy':          return default_collate([torch.as_tensor(b) for b in batch])      else:          raise NotImplementedError

(5) 调用完整例子

class SimpleV1Dataset(Dataset):     def __init__(self):          # 伪造数据          self.imgs = np.arange(0, 16).reshape(8, 2)  
def __getitem__(self, index): return self.imgs[index]
def __len__(self): return self.imgs.shape[0]

from simplev1_datatset import SimpleV1Dataset simple_dataset = SimpleV1Dataset() dataloader = DataLoader(simple_dataset, batch_size=2, collate_fn=default_collate)for data in dataloader: print(data)

 

3 总结

    本文是最小 DataLoader 系列文章的第一篇,重点是分析了 python 中迭代器相关知识,然后构建一个最简单的 DataLoader 类,用于加深到 DataLoader 流程的理解,功能比较简单。

    后面慢慢完善,希望最终能实现完整功能。

github: https://github.com/hhaAndroid/miniloader

 

推荐阅读

PyTorch 源码解读之 torch.autograd

PyTorch 源码解读之 BN & SyncBN