Stochastic roundingについて

目次

これはrioyokotalab Advent Calendar 2020 6日目の記事です.

Stochastic roundingとは

計算機で小数点数を扱う場合,「丸め」が必要となります.
一般的なものですと,最近接偶数方向丸め(RN)などの決定的な丸めがありますが,今回紹介するStochastic rounding(SR)は非決定的な丸め方法です.
日本語では何と言うんでしょうかね,確率的勾配降下法(SGD)に倣って「確率的丸め」でしょうか?
ではこのSRがどういった丸めかというと,名前の通り,丸め対象の仮数部を+1するかを確率的に決めます.
SRでは次の2つのmodeがあります.

  • mode 1 : 切り捨てるbit列を見て丸めとして仮数部+1を行うかを変える
  • mode 2 : 一律1/2で丸めとして仮数部+1を行うかを変える

mode 2はサイコロを振って偶数が出たら仮数部+1,奇数が出たらそのままとすればいいため,特に説明はいらないかと思います.
mode 1は切り落とすbit列と同じ長さの一様乱数を生成し,それより切り落とすbit列のほうが値として大きい又は等しければ仮数部+1を,そうでないならそのままとすることで実装することができます.

どう計算すればいいかは分かったところで,これを使うと何が嬉しいの?という話です.
例えば16bit固定小数点数+SRでCNNの学習をして,RNで丸めを行った場合と比較して精度が良いという話があったり[1],内積計算などのもっと基本的な演算で精度良く計算できたり[2]みたいな論文がここ数年ほどで出てきています.
また,Graphcore IPU等,実際のハードウェアに搭載している例もあります.
どれも低精度演算と組み合わせて使っている場合が多いように思えます.

SRの精度評価

実装

SRは本質的に乱数を必要とします.
CPUで計算するのであれば適当な疑似乱数生成関数でよいのですが,GPU上で計算したくなることも考えて,今回2通りの単純な乱数の生成方法を検討しました.

  1. 丸めを行う浮動小数点数の,切り落とさない方の仮数部を乱数bit列と解釈して使用する
  2. 1よりは演算量が増えるが,xorshiftのような簡単な疑似乱数アルゴリズムを使用する
1は,浮動小数点数の仮数部の0 or 1はiidのように振る舞うということが知られており,これを用いた方法です.
しかし,これは丸めを行う数が常に同じ場合は常に同じbit列となるため,期待通り動作しなくなります.
というわけで今回は2のxorshiftを用いて実装しました.

精度評価

今回精度評価にはベクトルの内積計算を用いました.
想定としてはNVIDIAのTF32(指数部8bit,仮数部10bit)を作る際にSRを用い,これをTensorコアのような混合精度演算で内積計算をしたらどうなるかというものです.
ですので,行う計算としては,TF32への丸め関数をr(x)として,

float sum = 0;
for (unsigned i = 0; i < N; i++) {
    sum += r(a[i]) * r(b[i])
}
のような計算とします.
丸め関数r(x)は3通り,「mode 1でのSR (sr-m1)」,「mode 2でのSR (sr-m2)」,「最近接丸め(rn)」とし,それぞれの精度を比較しました.
また,N次元入力ベクトル\(a, b\)の各要素の値は2通りで試しました.
精度の計算は,倍精度での内積計算との相対残差です.

1. ベクトルの全要素を固定値で初期化

入力ベクトル\(a\)の全要素を1,\(b\)の全要素を 0.1とします.
この場合で,ベクトル長Nを変化させたときの相対残差は下のグラフです.

mode 1のSRの精度がNが小さい時はRNと比較して良く,Nが大きくなるとNRと同程度に悪くなることがわかります.
この初期値の決め方では,RNは\(r(b[i])\)の丸め結果が常に一定となります.
一方でSR mode 1では,切り落とされる仮数部の大きさから決まる確率によって丸め方向が変わり,期待値的に内積が計算されるためにRNより精度が良くなります.

2. ベクトルの全要素を乱数で初期化

ベクトル\(a, b\)の全要素を[0-1]の一様乱数で初期化し,内積計算の相対残差を計算したグラフがこちらです.

ベクトルの初期値を固定した場合と違い,RNの精度がSR mode 1程度に良くなっています.
これは,固定値初期化とは違い,RNでの丸め方向が片方向に固定されないために起こる訳ですが, 仮数部の乱数性によりRNでも丸め方向に確率性が入った,とも考えられます.
決定的な動作をする計算機にも乱数性はとても重要なことがわかります.

おわり

浮動小数点数の仮数部の乱数性というのは,これが原因で浮動小数点数の圧縮が難しかったりするのですが,うまく付き合えれば計算精度を上げることもでき楽しい性質かなと思います.

参考文献

  • [1] Suyog Gupta, Ankur Agrawal, Kailash Gopalakrishnan, Pritish Narayanan, Deep Learning with Limited Numerical Precision, PMLR 2015
  • [2] Connolly Michael P., Higham Nicholas J., Mary Theo, Stochastic Rounding and its Probabilistic Backward Error Analysis, 2020
カテゴリー:その他
記事作成日:2020-12-06