読者です 読者をやめる 読者になる 読者になる

About connecting the dots.

statistics/machine learning adversaria.

実装して理解するオンライン学習器(2) - Confidence-Weighted

前回からだいぶ間が空きましたが,その間なんもやってなかったので,いい加減まとめてエントリにしておきます.本当はSCWまでやってからにしたかったんですが,あきらめてCWだけで...

元ネタは前回と同じくICMLの以下の論文です.

Jialei Wang, Peilin Zhao, and Steven C. Hoi. Exact soft confidence-weighted learning. In Proc. of ICML 2012, pages 121–128, 2012.

Confidence-Weighted

モデル

オンライン学習器なので,線形モデルでかつデータ追加ごとに逐次学習を進めていくというモデルになります.CWの特徴は,各パラメタについて平均だけでなく分散も同時に求める点にあります.分散が小さければ小さいほど,より精度の高いパラメタ推定ができている,という理屈になります.

(\mu_{t+1}, \sum_{t+1})=\rm{argmax}_{\mu, \sum} \it{D}_{KL} (\mathcal{N}(\mu, \sum), \mathcal{N}(\mu_t, \sum_t))

\it{D}_{KL}はKLダイバージェンスですね*1.こんな感じで,新しいデータが与えられるごとに,KLダイバージェンスを最小にするような:\mu_{t+1}, \sum_{t+1}を求めていく形になります.

詳細な式展開は論文に譲りますが,最終的にはもう少しシンプルな形の閉形式*2であらわすことができます.あと,こちらでも更新式について書かれています.

ということで,最終的にはPassiveAggressiveと同じような形での実装が可能になります.

実装

ということで,実装式は以下の通りです.パラメタとして\etaがあるので,この値を変えることで,モデルの精度が多少変わります.

#!/usr/bin/env python
#-*-coding:utf-8-*-

from math import sqrt
import numpy as np
from scipy.stats import norm

class ConfidenceWeighted():
    def __init__(self, feat_dim, eta=0.90):
        self.t = 0
        self.m = np.ones(feat_dim)
        self.s = np.diag([1.0]*feat_dim)
        self.eta = eta
        self.phi = norm.cdf(self.eta)**(-1)
        self.psi = 1.0+(self.phi**2)/2.0
        self.zeta = 1.0+self.phi**2

    def predict(self, feats):
        return np.dot(self.m, feats)

    def update(self, y, feats):
        # parameter calculation
        v = np.dot(np.dot(feats, self.s), feats)
        m = y*(np.dot(self.m, feats))
        part = sqrt((m**2)*(self.phi**4)/4.0+v*(self.phi**2)*self.zeta)
        alpha = max(0.0, 1.0/(v*self.zeta)*(-m*self.psi+part))
        u = 0.25*((-alpha*v*self.phi+sqrt((alpha**2)*(v**2)*(self.phi**2)+4.0*v))**2)
        beta = (alpha*self.phi)/(sqrt(u)+v*alpha*self.phi)
        # update parameters
        self.t += 1
        self.m += alpha*y*np.dot(self.s, feats)
        self.s -= beta*np.dot(np.matrix(np.dot(self.s, feats).T*feats), self.s)
        return 1 if np.dot(self.m, feats) > 0 else 0

検証

前回と同じく,libsvmのテストデータから,a1aの訓練データテストデータを持ってきて使いました*3

まずはオンライン学習をさせて行ったときの精度の変化です.どのモデルでもほとんど変わらず,しかも精度も低いですね... これならPAのときのほうが精度が良いというションボリな感じの結果です.

http://f.st-hatena.com/images/fotolife/S/SAM/20150215/20150215141307_original.png

そしてテストデータに対しての予測は,こちらはそれなりに高くて70%強というところでしょうか.こちらもPAのほうが高いという...普通はCWのほうが精度がいいはずなので,どこか間違えてるのかもしれません.いろいろションボリですが,まぁ仕方なし.

http://f.st-hatena.com/images/fotolife/S/SAM/20150215/20150215141308_original.png

そんな感じのCW編でした.次はSCWにいきたいものです.

*1:一言でいうと,KLダイバージェンスは分布間の距離みたいなものです.この値が小さいほど,似た分布であると考えることができます.

*2:簡単な加減乗除や関数だけで表せる形の式のことを指します.この形にできれば計算しやすいってことです.詳細はこちらの説明をどうぞ.

*3:前処理等のためにヘルパークラスをいくつか作ってgithubにあげてあります.