About connecting the dots.

data science related trivial things

プロ野球の投手力をMCMC+階層ベイズモデルで計算してみる

大分間が空きましたが,師走で本業が忙しくて,なかなか時間が取れませんでした.その間に時期外れになってしまった気もしますが,今さら流行に乗ってMCMC+BUGSでモデリングしてみました.ネタは,これもまた旬を逃した2013プロ野球です.

お題

よくいわれる話としてプロ野球の投手の成績というのは,実際にはチームの打線や守備に大きく左右されるわけです.最近はセイバーメトリクスなんかも盛り上がってますが,どうやってチームの実力から分離したピュアな投手の実力を測定するか,というのが大きなテーマだったりします.そんなわけで,投手と,その上位階層にあるチームを階層ベイズでモデリングしたらうまいこといかないかな,というのが今回のテーマです.

定式化

ざっくりとモデルをお絵描きすると,以下のような形になります.個人差と,その上位にあるチーム差で勝率を予測して,そこから結果の勝数が得られるというモデルです.

ちなみにWHIPは「1イニングあたり何人の走者を出したか」,DIPSは「投手がコントロールできる部門の指標」という,投手個人の指標です.またOPSは「出塁率と長打率を足し合わせた値」で,本来個人毎の指標ですが,チーム全体の平均値を取っています.DERは「本塁打以外に打たれた打球に対して,何割がアウトにできたか」,DEFは守備率を勝手に略称化しただけのもので,「選手が守備に関わった回数のうち失策をしなかった回数」です.

f:id:SAM:20131220101132j:plain

JAGS+MCMCによるモデリング

ということで,MACではWinBUGSが使えないので,JAGSを使ってMCMCのモデリングを行います.コードは以下の通りです.データはこちらこちらで公開されているものを使用させていただきました.今回の分析用にまとめなおしたデータ自体は
こちらこちらにおいてあります.

# ライブラリ読み込み
library(rjags)
library(R2jags)
library(R2WinBUGS)

# データロード
data.pitcher <- read.csv("data/pitcher.txt", sep="\t")
data.team <- read.csv("data/team.txt", sep="\t")

# モデルの記述
team_player_model <- function() {
  for (i in 1:N.pitcher) {
    WIN[i] ~ dbin(q[i], GAME[i])
    logit(q[i]) <- beta1 + beta2 * WHIP[i] + beta3 * DIPS[i] + rp[i] + rt[TEAM[i]]
  }
  beta1 ~ dnorm(0, 1.0E-4)
  beta2 ~ dnorm(0, 1.0E-4)
  beta3 ~ dnorm(0, 1.0E-4)
  for (i in 1:N.pitcher) {
    rp[i] ~ dnorm(0, tau[1])
  }
  for (j in 1:N.team) {
    rt[j] <- gamma1 + gamma2 * OPS[j] + gamma3 * DEF[j] + gamma4 * DER[j] + r[j]
  }
  gamma1 ~ dnorm(0, 1.0E-4)
  gamma2 ~ dnorm(0, 1.0E-4)
  gamma3 ~ dnorm(0, 1.0E-4)
  gamma4 ~ dnorm(0, 1.0E-4)
  for (i in 1:N.team) {
    r[i] ~ dnorm(0, tau[2])
  }
  for (k in 1:N.tau) {
    tau[k] <- 1.0 / (s[k] * s[k])
    s[k] ~ dunif(0, 1.0E+4)
  }
}

# モデルの書き出し
file.bugs <- file.path("script/baseball.bug")
R2WinBUGS::write.model(team_player_model, file.bugs)

# JAGSに渡すデータ
list.data <- list(
  WIN       = data.pitcher$win,
  GAME      = data.pitcher$game,
  WHIP      = data.pitcher$WHIP,
  DIPS      = data.pitcher$DIPS,
  TEAM      = data.pitcher$team_no,
  OPS       = data.team$OPS,
  DEF       = data.team$DEF,
  DER       = data.team$DER,  
  N.pitcher = nrow(data.pitcher),
  N.team    = nrow(data.team),
  N.tau     = 2
)

# JAGSに渡す初期値
inits <- list(
  beta1 = 0,
  beta2 = 0,
  beta3 = 0,
  gamma1 = 0,
  gamma2 = 0,
  gamma3 = 0,
  gamma4 = 0  
)

# パラメタ
params <- c("beta1", "beta2", "beta3", "gamma1", "gamma2", "gamma3", "gamma4")

# R内でのモデルの定義
model <- jags.model(
  file = file.bugs,
  data = list.data,
  inits = list(inits, inits, inits),
  n.chain = 3
)

# burn-inのためのカラまわし MCMC 計算
update(model, 100)

# MCMC 計算で事後分布からサンプリング,その結果をうけとる
post.list <- coda.samples(
  model,
  params,
  thin = 3, n.iter = 1500
)
summary(post.list)

# グラフ
plot(post.list)

# R-hatの算出(jagsパッケージでないとできないので,念のために実施)
if (1) {
  init <- function() {
    list(
      beta1 = 0,
      beta2 = 0,
      beta3 = 0,
      gamma1 = 0,
      gamma2 = 0,
      gamma3 = 0,
      gamma4 = 0  
    )
  }
  post.jags <- jags(data = list.data,
                    inits = init,
                    parameters.to.save = params,
                    n.iter = 1600,
                    model.file = file.bugs,
                    n.chains = 3,
                    n.thin = 3,
                    n.burnin = 100)
  print(post.jags)
}

# WHIPから勝利数を予測
beta1_smp <- c(post.list[,"beta1"][[1]], post.list[,"beta1"][[2]], post.list[,"beta1"][[3]])
beta2_smp <- c(post.list[,"beta2"][[1]], post.list[,"beta2"][[2]], post.list[,"beta2"][[3]])
beta3_smp <- c(post.list[,"beta3"][[1]], post.list[,"beta3"][[2]], post.list[,"beta3"][[3]])
gamma1_smp <- c(post.list[,"gamma1"][[1]], post.list[,"gamma1"][[2]], post.list[,"gamma1"][[3]])
gamma2_smp <- c(post.list[,"gamma2"][[1]], post.list[,"gamma2"][[2]], post.list[,"gamma2"][[3]])
gamma3_smp <- c(post.list[,"gamma3"][[1]], post.list[,"gamma3"][[2]], post.list[,"gamma3"][[3]])
gamma4_smp <- c(post.list[,"gamma4"][[1]], post.list[,"gamma4"][[2]], post.list[,"gamma4"][[3]])
beta1_median <- summary(post.list)$quantile[,"50%"][[1]]
beta2_median <- summary(post.list)$quantile[,"50%"][[2]]
beta3_median <- summary(post.list)$quantile[,"50%"][[3]]
gamma1_median <- summary(post.list)$quantile[,"50%"][[4]]
gamma2_median <- summary(post.list)$quantile[,"50%"][[4]]
gamma3_median <- summary(post.list)$quantile[,"50%"][[4]]
gamma4_median <- summary(post.list)$quantile[,"50%"][[4]]

plot(data.pitcher$WHIP, data.pitcher$winning_rate,
     xlim=c(0.8, 1.6),
     ylim=c(0, 1),
     xlab="WHIP",
     ylab="Winning rate")
par(new= TRUE)
for(i in sample(length(beta1_smp), 100)) {
  curve(
    1 / (1+exp(-(
    beta1_smp[i]+beta2_smp[i]*x+beta3_smp[i]*mean(data.pitcher$DIPS)+
      gamma1_smp[i]+gamma2_smp[i]*mean(data.team$OPS)+gamma3_smp[i]*mean(data.team$DER)+gamma4_smp[i]*mean(data.team$DEF)))),
    from=0.8,
    to=1.6,
    ylim=c(0, 1),
    col=rgb(0.5, 0.5, 0.5, alpha=0.3),
    xlab = "",
    ylab = ""
  )
  par(new= TRUE)
}  

結果

推定されたパラメータは以下の通りです.

> summary(post.list)

Iterations = 1103:2600
Thinning interval = 3 
Number of chains = 3 
Sample size per chain = 500 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

            Mean       SD Naive SE Time-series SE
beta1    8.67937 74.86187 1.932925       1.933725
beta2   -2.19403  1.02049 0.026349       0.045567
beta3   -0.51830  0.25645 0.006622       0.011543
gamma1  13.35963 75.33018 1.945017       1.945216
gamma2  -0.02595  0.05329 0.001376       0.001862
gamma3 -12.68709 45.35444 1.171047       2.015237
gamma4  -7.78060 13.10682 0.338417       0.541950

2. Quantiles for each variable:

            2.5%       25%       50%       75%     97.5%
beta1  -139.9095 -40.53518   9.11984 59.542346 151.18492
beta2    -4.1664  -2.86411  -2.22075 -1.533479  -0.17760
beta3    -1.0175  -0.69404  -0.51936 -0.336695  -0.02209
gamma1 -135.3303 -35.20230  10.92927 62.833525 158.25999
gamma2   -0.1471  -0.05414  -0.02456  0.008814   0.07215
gamma3  -97.8243 -44.09265 -15.06952 16.472134  80.79574
gamma4  -32.2971 -15.66428  -8.16691 -0.095224  18.64349

そして,WHIPと勝率をプロットして,事後シミュレーションを行ったのが下の図になります.はっきりいって,ろくろくまともに予測できていません orz... 正直原因はよくわかっていないのですが,サンプル数が少ないのか,はずれ値に影響されすぎているのか... もう少しちゃんと勉強して出直してきます.

f:id:SAM:20131220103118j:plain