好きな事で生きていく

統計的機械学習、因果推論、マーケティングサイエンス

2値分類で使われる損失関数、Sigmoid Cross Entropyは何を表しているのか

Deepなモデルは、最適なパラメータを学習するために損失関数を設定して、モデルの予測値と正解値の誤差を最小化します。

回帰系の問題設定でよく使われる二乗誤差などMSE系列は、回帰分析における最小二乗法とそっくりなので何をやりたいかはイメージがつくとして、分類問題におけるcross entropy系は関数を見ただけでは何をやりたいのかがわからないので、整理しました。

多値分類問題で使うsoftmax cross entropyは、あるクラスに属するか否かの分類問題を解いています。2値分類も同じくクラス0に属するのかクラス1に属するのかの分類問題を解くわけですが、クラス1に属する確率を推定するというように回帰問題としても捉えられるので、要するにモデルの最終的な出力を2変数ベクトル(2値分類)にするべきかスカラー([0,1]をとる回帰問題)にするべきかちょっとわからなくなりました。

結論から言うとどちらでやっても問題ないのですが、sigmoid cross entropyをベルヌーイ分布から導出して、確率論や最尤法的観点から損失関数を整理します。

 

0か1の値をとる、あるデータ x_iがパラメータ \thetaから発生するとします。 このような場合、確率分布はベルヌーイ分布を用います。

\begin{align}
p(x_i | \theta) = \theta ^{x_i} (1 - \theta)^{1 - x_i}
\end{align}

パラメータ \theta X=1となる確率を表し、例えば、歪みのないコインなら \theta = 1/2となり、式(1)に代入すれば、 x=1となる確率 p(X = 1 | 0.5) x=0となる確率 p(X = 0 | 0.5)も1/2となることがわかります。

 対数尤度は以下のように変形出来ます。

\begin{align}
l(\theta) = x_i \log \theta + (1 - x_i) \log (1 - \theta)
\end{align}

 最尤法ではデータxが手元にあった場合、対数尤度を最大化するパラメータ \thetaを求めますが、機械学習では観測値 y_iがあった場合の対数尤度を最大化するパラメータ \thetaを求めます。パラメータ \thetaとはこの場合予測値 f(x_i)なので、損失を最小化するようにマイナスをかけて式(2)を変形させ、データ全体の対数尤度を足し合わせると、最小化すべき損失関数は

\begin{align}
Loss(\mathbf{x}, \mathbf{y}) &= - \sum_{i} l(f(x_i), y_i) \\
&= - \sum_i \{y_i \log f(x_i) + (1 - y_i) \log (1 - f(x_i))\}
\end{align}

となり、これがSigmoid Cross Entropyとなります。 f(x_i) x_iをモデルに入力した時に出力される、モデルの予測値です。この誤差に学習率やら学習関数やらを通したものを逆伝搬させて、 f(x)の値を決めるパラメータ Wを更新していくわけです。

式(4)の f(x_i)はベルヌーイ分布のパラメータ \theta、つまり y=1となる確率に相当すると考えられるので、0と1の間にある事が望ましいです。

 

式(4)は、モデルの出力値 f(x)と観測値 yがベクトルであってもスカラー(値)であっても計算が可能です。ベクトルを用いる場合は各次元毎にlossを計算して平均をとれば、ある1つのデータに対する対数尤度(のようなもの)が計算出来ます。観測値 \mathbf{y}はどれか1つの要素が1でそれ以外の要素は0となるので、式(4)より観測値が1の要素の対数尤度のみが損失となり、観測値が0の要素は0が掛けられるので消えます。

従って、回帰としてスカラーで解いた場合でも、分類としてベクトルで解いた場合でも、観測値 yに対応した f(x)の対数尤度が損失として逆伝搬されます。

また、ベクトルが何次元であっても観測値 y=1となる次元は1つだけなので、逆伝搬する対数尤度も y=1に対応する次元の f(x)の対数尤度のみです。

ただ、ベクトル毎のlossを足し合わせる場合、それだけ値が大きくなるので、同じ学習率を使用した場合スカラーよりパラメータの更新量が大きくなる、、はずです。pytorchのBinary Cross Entropyの関数を見た所、size_averageという引数がベクトルの各要素のlossを足し合わせるのか平均をとるのかをコントロールしているようでした。

 

結論をまとめると、2値分類問題はモデルの出力をベクトル(2クラス分類)にしても、スカラー(クラス1に所属する確率を出力する回帰問題)にしても、やっていることは理論的に変わりません。