東京に棲む日々

データ分析、統計、ITを勉強中。未だ世に出ず。

Gradient Boosting Treeを使ってみる - R{gbm} (Part.1)

予測モデリングのコンペで良く使われるらしいGBMを業務で使うことになったので、その使い方メモ。

Rのbgmを使う。

理論的なことは↓をじっくりといつか振り返ることにして、とりあえず使ってみる。
Ridgeway(2012), Generalized Boosted Models: A guide to the gbm package

 

データはUCIのページより、Adult Data Setというのを使用。

 

32,562行15列の、3.35Mくらいのデータ。

目的変数がincome_classで、水準が"<=50K"と">50K"の2水準。
よって、2値判別問題のデータ。

説明変数が以下の14変数
age
workclass
fnlwgt
education
education-num
marital-status
occupation
relationship
race
sex
capital-gain
capital-loss
hours-per-week
native-country


カテゴリカル変数で水準数が少ないものは他の水準に統合しておく。

workclass変数(9水準)。
"Never-worked"と"Without-pay"を、"?"と変更。

marital-status変数(7水準)。
"Married-AF-spouse"と"Married-civ-spouse"と"Married-spouse-absent"とを、"Married"と変更。

race変数(5水準)。
"Amer-Indian-Eskimo"を、"Other"と変更。

occupation変数(15水準)。
"Armed-Forces"と"Priv-house-serv"を、"?"と変更。


adult2_0に読み込む。

str(adult2_0)

'data.frame':   32561 obs. of  15 variables:
 $ age           : int  39 50 38 53 28 37 49 52 31 42 ...
 $ workclass     : Factor w/ 7 levels "?","Federal-gov",..: 7 6 4 4 4 4 4 6 4 4 ...
 $ fnlwgt        : int  77516 83311 215646 234721 338409 284582 160187 209642 45781 159449 ...
 $ education     : Factor w/ 16 levels "10th","11th",..: 10 10 12 2 10 13 7 12 13 10 ...
 $ education.num : int  13 13 9 7 13 14 5 9 14 13 ...
 $ marital.status: Factor w/ 5 levels "Divorced","Married",..: 3 2 1 2 2 2 2 2 3 2 ...
 $ occupation    : Factor w/ 13 levels "?","Adm-clerical",..: 2 4 6 6 9 4 8 4 9 4 ...
 $ relationship  : Factor w/ 6 levels "Husband","Not-in-family",..: 2 1 2 1 6 6 2 1 2 1 ...
 $ race          : Factor w/ 4 levels "Asian-Pac-Islander",..: 4 4 4 2 2 4 2 4 4 4 ...
 $ sex           : Factor w/ 2 levels "Female","Male": 2 2 2 2 1 1 1 2 1 2 ...
 $ capital.gain  : int  2174 0 0 0 0 0 0 0 14084 5178 ...
 $ capital.loss  : int  0 0 0 0 0 0 0 0 0 0 ...
 $ hours.per.week: int  40 13 40 40 40 40 16 45 50 40 ...
 $ native.country: Factor w/ 42 levels "?","Cambodia",..: 40 40 40 40 6 40 24 40 40 40 ...
 $ income_class  : Factor w/ 2 levels "<=50K",">50K": 1 1 1 1 1 1 1 2 2 2 ...

 

目的変数の水準"<=50K"を0、">50K"を1と数値に変換しておく。

icm <- as.numeric(adult2_0$income_class==">50K")

変数名をincomeとする。

adult2_1 <- data.frame(adult2_0, income=icm)

str(adult2_1)

'data.frame':   32561 obs. of  16 variables:
 $ age           : int  39 50 38 53 28 37 49 52 31 42 ...
 $ workclass     : Factor w/ 7 levels "?","Federal-gov",..: 7 6 4 4 4 4 4 6 4 4 ...
 $ fnlwgt        : int  77516 83311 215646 234721 338409 284582 160187 209642 45781 159449 ...
 $ education     : Factor w/ 16 levels "10th","11th",..: 10 10 12 2 10 13 7 12 13 10 ...
 $ education.num : int  13 13 9 7 13 14 5 9 14 13 ...
 $ marital.status: Factor w/ 5 levels "Divorced","Married",..: 3 2 1 2 2 2 2 2 3 2 ...
 $ occupation    : Factor w/ 13 levels "?","Adm-clerical",..: 2 4 6 6 9 4 8 4 9 4 ...
 $ relationship  : Factor w/ 6 levels "Husband","Not-in-family",..: 2 1 2 1 6 6 2 1 2 1 ...
 $ race          : Factor w/ 4 levels "Asian-Pac-Islander",..: 4 4 4 2 2 4 2 4 4 4 ...
 $ sex           : Factor w/ 2 levels "Female","Male": 2 2 2 2 1 1 1 2 1 2 ...
 $ capital.gain  : int  2174 0 0 0 0 0 0 0 14084 5178 ...
 $ capital.loss  : int  0 0 0 0 0 0 0 0 0 0 ...
 $ hours.per.week: int  40 13 40 40 40 40 16 45 50 40 ...
 $ native.country: Factor w/ 42 levels "?","Cambodia",..: 40 40 40 40 6 40 24 40 40 40 ...
 $ income_class  : Factor w/ 2 levels "<=50K",">50K": 1 1 1 1 1 1 1 2 2 2 ...
 $ income        : num  0 0 0 0 0 0 0 1 1 1 ...

 

incomeが1の割合の確認。

table(adult2_1$income)["1"]/sum(table(adult2_1$income))
0.2408096

 

学習、検証、テストデータに分けるため、行をシャッフルする。
adult2_2 <- adult2_1[sample(1:nrow(adult2_1)),]

 

train(学習と検証)、test(テスト)へデータを分割。

nrow(adult2_2) * 0.75
24420.75

trainは1から24420行目までを、24421行目から32561行目までの8141行をtestとする。

train0 <- adult2_2[1:24420,]          # 24420 obs. of 16 variables
test0 <- adult2_2[24421:32561,]          # 8141 obs. of 16 variables


incomeの割合の確認。

test。
table(train0$income)["1"]/sum(table(train0$income))
0.2410729

train。
table(test0$income)["1"]/sum(table(test0$income))
0.2400197

 

gbmの実行。


library(gbm)

gbmでなく、gbm.fitを使う。

 

説明変数と目的変数を分けておく。

x_variables <- train0[,1:14]
target <- train0$income

 

データの2/3を学習、1/3を検証とする。
for_train <- round(nrow(train0)*0.666)
16264

1行目から16264行目までが学習、残りが検証とgbm.fitでは処理される(たぶん...)。

 

mod1 <- gbm.fit(
     x=x_variables,
     y=target,
     distribution = "bernoulli",
     bag.fraction = 0.5,
     n.trees = 5000,
     interaction.depth = 7,
     n.minobsinnode = 100,
     shrinkage = 0.005,
    nTrain = for_train
)

実行すると、5000回のイテレーション結果が画面出力される。


distribution:
目的変数が2値(0/1)判別の場合は"bernoulli"と指定。

bag.fraction:
イテレーションで、オブザベーションの重複無しサンプリングを行う割合。Ridgeway(2012)によると、この処理を行うと、パフォーマンスがけっこう上がるとのこと。

n.tree:
イテレーション回数(合計tree数)。

interaction.depth:
各treeにおける分岐数。
6-8くらい良いとのこと。(同僚よりの情報)

n.minobsinnode:
分岐における最低オブザベーション数(行数)。これより少ないオブザベーショングループができてしまう分岐は行わない、と言うことだと思う。

shrinkage:
学習率。
デフォルトは0.01だが、0.001とか0.005とか小さくしたほうが良いとのこと。(同僚よりの情報)

nTrain:
学習データの行数。1行目からnTrain行までが学習データとなり、残りが検証データとなる。

 

 

変数重要度の表示。
summary(mod1) 

                          var    rel.inf
relationship     relationship 29.7713982
capital.gain     capital.gain 15.9408432
occupation         occupation 13.5176317
education           education  9.6065983
age                       age  6.0695892
education.num   education.num  5.0654549
native.country native.country  4.8932537
capital.loss     capital.loss  4.1501917
hours.per.week hours.per.week  3.9770395
fnlwgt                 fnlwgt  2.8375266
workclass           workclass  1.9443323
marital.status marital.status  1.7874420
sex                       sex  0.2690007
race                     race  0.1696981

 

検証データのBernoulli Deviance(Ridgeway(2012)に記載あり)を最小にするTree数が選ばれる。

 

では、この最適な数のTree数によるモデルで、予測を行う。

 

学習データの説明変数と目的変数。
x_variables_train <- x_variables[1:16264,]           # 先頭から16264行まで
target_train <- target[1:16264]

検証データの説明変数と目的変数。
x_variables_varidation <- x_variables[16265:nrow(x_variables),]          # 16265行以降
target_varidation <- target[16265:nrow(x_variables)]


学習データに対する予測値。
pred_mod1_train <- predict(mod1, newdata=x_variables_train, n.trees=gbm.perf(mod1, plot.it=FALSE),
type="response")

検証データに対する予測値。
pred_mod1_varidation <- predict(mod1, newdata=x_variables_varidation, n.trees=gbm.perf(mod1,
plot.it=FALSE), type="response")


テストデータの説明変数と目的変数。
test_x <- test0[,1:14]
test_y <- test0[,16]

テストデータに対する予測値。
pred_mod1_test <- predict(mod1, newdata=test_x, n.trees=gbm.perf(mod1, plot.it=FALSE),
type="response")

 

 

AUCを計算する。

gbm.roc.areaという関数があるようだが、ROCのAUCを計算する関数を過去に書いたことがあるので、それを使う。

highschoolstudent.hatenablog.com

 

ROCRパッケージを読み込み、fn_AUC関数を読んでおく。

############ ROC ############

library(ROCR)

# AUCを返す関数 - pred:予測値, obs:実測値
fn_AUC <- function(pred, obs){

     pred1 <- prediction(predictions=pred, labels=obs)
     auc.temp <- performance(pred1, "auc")
     auc <- unlist(auc.temp@y.values)
     return(auc)
}

#############################

 

学習データのAUC。
fn_AUC(pred=pred_mod1_train, obs=target_train)
0.9344421

検証データのAUC。
fn_AUC(pred=pred_mod1_varidation, obs=target_varidation)
0.9232262

テストデータのAUC。
fn_AUC(pred=pred_mod1_test, obs=test_y)
0.9190125 

 

 

最後にPartial Dependence Plotを書き、各変数と目的変数の関係を見ておく。

Partial Dependence Plotだが、正式名称かどうかは不明。同僚がそのように呼んでいる。
また、詳しい計算方法も把握してはいない。

 

ヒストグラム、棒グラフも同時にプロットする関数を定義しておく。

# gbm.plot(partial dependence plot)とhistgram/barplotを並べる関数
# mod_res:gbmの結果, num_trees:Treeの数, var_name:調べる変数名, x_data:元の説明変数が入った data.frame
fn_plot_gbm <- function(mod_res, num_trees, var_name, x_data){

     if(class(x_data[,var_name]) %in% c("numeric","integer")) {      # histの場合
          par(mfrow=c(2,1))
          plot(mod_res, i.var=var_name, ntrees=num_trees , type="response")
          hist(x_data[,var_name])

     } else if(class(x_data[,var_name])=="factor") { # barplotの場合
          par(mfrow=c(2,1))
          plot(mod_res, i.var=var_name, ntrees=num_trees , type="response")
          barplot(table(x_data[,var_name]), las=3)

     } else {
          stop("variable class must be numeric or factor.")
     }

     par(mfrow=c(1,1))
}


このプロットを見て、学習データのノイズに過剰に反応して、オーバーフィッティングが起きていないか調べる。

 

1)   age(数値)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="age", x_data=x_variables)

f:id:High_School_Student:20150627143242j:plain

60まではmonotone increaseに見える。60くらいから動きが変わる。


2)   workclass(カテゴリー)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="workclass", x_data=x_variables)

f:id:High_School_Student:20150627143342j:plain

 

3)   fnlwgt(数値)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="fnlwgt", x_data=x_variables)

f:id:High_School_Student:20150627143419j:plain

 500000くらいを境に外れ値。

 

4)   education(カテゴリー)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="education", x_data=x_variables)

f:id:High_School_Student:20150627143453j:plain


5)   education.num(数値)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="education.num", x_data=x_variables)

f:id:High_School_Student:20150627143526j:plain

 monotone increasingのようである。

 

6)   marital.status(カテゴリー)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="marital.status", x_data=x_variables)

f:id:High_School_Student:20150627143600j:plain


7)   occupation(カテゴリー)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="occupation", x_data=x_variables)

f:id:High_School_Student:20150627143636j:plain

 

8)   relationship(カテゴリー)

fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="relationship", x_data=x_variables)

f:id:High_School_Student:20150627143704j:plain

 

9)   race(カテゴリー)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="race", x_data=x_variables)

f:id:High_School_Student:20150627143733j:plain

 

10)   sex(カテゴリー)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="sex", x_data=x_variables)

f:id:High_School_Student:20150627143804j:plain

 

11)   capital.gain(数値)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="capital.gain", x_data=x_variables)

f:id:High_School_Student:20150627143836j:plain

10000くらいを境に外れ値。monotone increasingに見える。


12)   capital.loss(数値)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="capital.loss", x_data=x_variables)

f:id:High_School_Student:20150627143908j:plain

元の変数の意味自体が不明で残念であるが、1500から2000くらいの間で急に反応が起こる。

 

13)   hours.per.week(数値)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="hours.per.week", x_data=x_variables)

f:id:High_School_Student:20150627143951j:plain

80くらいの境で外れ値。monotone increasingっぽく見える。

 

14)   native.country(カテゴリー)
fn_plot_gbm(mod_res=mod1, num_trees=gbm.perf(mod1, plot.it=FALSE), var_name="native.country", x_data=x_variables)

f:id:High_School_Student:20150627144030j:plain

カテゴリカル変数の多水準問題(High Cardinality Problemと同僚は呼んでいる)。
学習データにおいて、沢山水準があるなかで、たまたまいずれかの水準の影響が大きくなってしまい、オーバーフィッテングを起こす可能性がある。
国内のデータを扱っていても、都道府県をそのまま扱うか迷うときがある。
何らかの方法で数値型変数に変換して、説明変数に入れるのが良いとのこと。(同僚曰く)

 

 次回は、Partial Dependence Plotで気づいたことを元に、いくつかの説明変数の変換方法を試してみる。