ヌルヌルコンピューター

機械学習、ニューラルネットワーク、その他色々と書くつもりですが三日坊主なので続かないかもしれません。あとまだ学士も取ってないへなちょこなので書いてある情報は正確性に欠いているかもしれません。てか多分間違いがいっぱいあります。何かありましたらご指摘お願いいたします。https://github.com/hirokik0811

複数マシンでの学習をコンパクトにするDeep Gradient Compression

今日大学で議論した論文について忘れないうちにまとめておこうと思う。

Deep Gradient Compression: Reducing the Communication Bandwidth for Distributed Training
https://openreview.net/pdf?id=SkhQHMW0W

11月の初めに匿名で発表されたものだ。現在double blind で審査されている。正直言って発表されてすぐであるからリバイズされておらず、ところどころに説明不足感が否めないというのが議論に参加した全員の感想だったが、しかしほとんど既存のテクニックの組み合わせだけでdistributed trainingにおいてボトルネックとなるgradientの伝達量が大幅に圧縮されるのは面白い。

複数マシンに学習タスクを並列処理させて学習スピードを上げようというのがdistributed Trainingで、これが最初に提唱されたのがGoogleによる2012年の
http://www.cs.toronto.edu/~ranzato/publications/DistBeliefNIPS2012_withAppendix.pdf
の論文(この研究中で開発されたDistBeliefがtensorflowの前身となっている)。

f:id:May-kwi:20171124150044p:plain
from http://www.andrewng.org/portfolio/large-scale-distributed-deep-networks/

まずweightなどの学習対象のパラメータを保存するサーバー1つを、個々にNNのレプリカを持った複数の子マシンに接続する。これら複数の子マシンそれぞれに学習データを分割して入力し、それぞれ学習結果のgradientをサーバーに送信することでサーバーがパラメータを更新、その新しいパラメータを送り返して子マシンのパラメータを更新する。これが論文内で提示された並列処理の方法だ。論文中では子マシンの持つネットワークをさらに小さな部品に分割して並列化している(この論文の詳しい要約もそのうち書きます)。

しかしながら、マシン同士の送受信というのは、データ量が多いとかえって学習を遅くしてしまう原因になる。実際に大きなサイズのデータを入力とし、レイヤー数が数十、百以上に上るようなNNを学習する場合、送信されるgradient の容量は軽く数百MBになってしまう。

そこでgradientの容量を減らす工夫をしなくてはならない。DGCの論文中で、これまで提唱されてきた方法として、
Gradient Quantization
元論文https://arxiv.org/pdf/1610.02132.pdf
(まだちゃんと読んでないですけど面白い技術なんで読んだらそのうち記事書きます)
Gradient Sparsification
元論文https://arxiv.org/pdf/1704.05021.pdf
の二つをあげている。

DGCは、Gradient Sparsificationに改良を加え、gradient matrixを要素選択によって疎にするだけでなく、元のgradientも子マシン側に蓄積し加算したうえで一定の間隔ごとにサーバーに送信することで、過度な情報の消失を防ぎ正確性を高めている。

さて、まずgradient matrixを疎にするわけだが、これは非常に簡単で、matrix のノルム(論文中ではl1-normを使ってるっぽい)のs%を閾値として、これより大きいgradientのみをmatrixに残す。

数式でいうと、

 G_{t+1} = G_t + {1 \over Nb} \sum_{k=0}^N \sum_{x \in X_{minibatch}} \nabla f(x, w_t)

でgradientを取得し、

 sparse G = G \odot Mask

(Mask は閾値以上のgradientの位置を1とするbinary matrix)
によって得られた疎なgradient matrixをサーバーに毎epoch送信する。
そして、

 w_{t+1} = w_t - \eta * sparse G

によってサーバーに格納されているweightを更新する。

さらに与えられた一定期間T中、元のgradientを加算しておき、
(数式でいうと

 sum G = {1 \over NbT}  \sum^N_{k=0} \sum_{\tau=0}^{T-1} \sum_{x \in X_{minibatch}} \nabla f(x, w_{t+\tau})

)
Tごとにこれをサーバー側に送信、

 w_{t+T} = w_t - \eta T sum G

によりweight を更新する。

しかし、論文によるとsparsityを99.9%近くするとこれだけではロスが1%以上になってしまうという。
そこで正確性を高めるためさらにmomentum correction, local gradient clipping, momentum factor masking, そしてwarmup training という四つのテクニックが用いている。これらについては次回の記事で話そうと思います。


つづく