从零开始学习MXnet(五)MXnet的黑科技之显存节省大法

  写完发现名字有点拗口。。- -#

  大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉得自己被全世界针对了。。遇到这种情况怎么办? 请使用MXnet的天奇大法带你省显存! 鲁迅曾经说过:你不去试试,怎么会知道自己的idea真的是这么糟糕呢?

  首先是传送门附上 mxnet-memonger,相应的paper也是值得一看的 Training Deep Nets with Sublinear Memory Cost

  实际上repo和paer里面都说的很清楚了,这里简单提一下吧。

  一、Why?

  节省显存的原理是什么呢?我们知道,我们在训练一个网络的时候,显存是用来保存中间的结果的,为什么需要保存中间的结果呢,因为在BP算梯度的时候,我们是需要当前层的值和上一层回传的梯度一起才能计算得到的,所以这看来显存是无法节省的?当然不会,简单的举个例子:一个3层的神经网络,我们可以不保存第二层的结果,在BP到第二层需要它的结果的时候,可以通过第一层的结果来计算出来,这样就节省了不少内存。  提醒一下,这只是我个人的理解,事实上这篇paper一直没有去好好的读一下,有时间在再个笔记。不过大体的意思差不多就是这样。

  

  二、How?

  怎么做呢?分享一下我的trick吧,我一般会在symbol的相加的地方如data = data+ data0这种后面加上一行 data._set_attr(force_mirroring='True'),为什么这么做大家可以去看看repo的readme,symbol的地方处理完以后,只有如下就可以了,searchplan会返回一个可以节省显存的的symbol给你,其它地方完全一样。

  

 1 import mxnet as mx
 2 import memonger
 3 
 4 # configure your network
 5 net = my_symbol()
 6 
 7 # call memory optimizer to search possible memory plan.
 8 net_planned = memonger.search_plan(net)
 9 
10 # use as normal
11 model = mx.FeedForward(net_planned, ...)
12 model.fit(...)

  PS:使用的时候要注意,千万不要在又随机性的层例如dropout后面加上mirror,因为这个结果,再算一次就和上一次不同了,会让你的symbol的loss变得很奇怪。。

 

三、总结

  天奇大法吼啊!