xgboostを使う

概要

xgboostを触る用事ができたので少しだけ触ってみた。

試しにTwitterからプロ野球に関するTweetを取得し、ハッシュタグを正解データとしてどの球団の内容を呟いているかのラベルを付け、Tweetの本文からそれがプロ野球のどの球団のTweetかを特定するモデルを構築する、みたいなことをしてみる。

当初は同様の方法でプログラミング言語の分類をしようかと思ったんだけど、サンプル数集めるのに時間がかかりそうだったので断念。プロ野球ならシーズン中に1週間も収集処理を回しておけばそこそこデータが集まるので題材としては便利である。

@CretedDate 2016/08/14
@Versions xgboost0.6, python3.5, mecab0.996

Tweetの取得

TwitterのSearchAPIで各球団名や有名選手名、「野球」「ナイター」「二塁打」「本塁打」などの野球に関連するTweetだけが引っかかりそうな単語で検索し、出現するハッシュタグをカウントする。

取得したタグの一覧がこちら。

197,590tweetほど収集して(重複Tweet及びRTは除外)、出現するタグをカウントすると、上位は下記のようになった。

(carp,5681)
(baystars,5266)
(giants,5181)
(トレクル,4858)
(npb,4592)
(TreCru,4293)
(dragons,4244)
(swallows,3656)
(sbhawks,3584)
(Giants,3275)
(lovefighters,3233)
(chibalotte,2904)
(seibulions,2838)
(プロ野球,2583)
(Lions,2327)
(巨人,2250)
(kyojin,2173)
(カープ,1996)
(阪神,1735)
(tigers,1445)

野球関連のタグがちゃんと取れている。なにやらトレクル(スマホゲーム)が上位に引っかかってしまっているけど、これはおそらく適当に設定した何かの検索ワードが用語的に被ってしまったのだと思われる。

Twitter4jの検索APIの叩き方についてはこちらのページにまとめたので実際に動かしてみたい方は参照されたし。

こうやって得たタグの一覧を手動でチャカチャカして、下記のようなタグと球団との関連表を作る。中日、ヤクルト、日本ハム、広島といった野球以外の意味でも用いられそうなタグは入れていない。

広島:carp,カープ,広島東洋カープ
横浜:baystars,ベイスターズ,横浜ベイスターズ
巨人:giants,Giants,巨人,kyojin,ジャイアンツ
中日:dragons,中日ドラゴンズ,ドラゴンズ,Dragons
ヤクルト:swallows
ソフトバンク:sbhawks,ソフトバンクホークス
日ハム:lovefighters,fighters,日ハム
ロッテ:chibalotte
西武:seibulions,Lions,lions
阪神:tigers,阪神,hanshin,タイガース,阪神タイガース
オリックス:Orix_Buffaloes,buffaloes,バファローズ
楽天:楽天イーグルス,rakuteneagles,eagles

上記のタグを使ってさらにtweetを集めた上で、下記のようなコードでチーム別にtweetをまとめてファイルを作成する。Twitter4jで集めたデータはArray[Status]の形でシリアライズされて保存されているものとする。

package jp.mwsoft.twitter4j.example

import resource._
import java.io.FileInputStream
import java.io.ObjectInputStream
import twitter4j.Status
import java.io.File
import scala.collection.mutable.HashSet
import scala.collection.mutable.HashMap
import java.io.Writer
import java.io.BufferedWriter
import java.io.FileWriter

object Wakati extends App {

  // ハッシュタグとチームの関連表
  val tagTeams = Map(
    "carp" -> "carp",
    "カープ" -> "carp",
    "広島東洋カープ" -> "carp",
    "baystars" -> "baystars",
    "ベイスターズ" -> "baystars",
    "横浜ベイスターズ" -> "baystars",
    "giants" -> "giants",
    "巨人" -> "giants",
    "kyojin" -> "giants",
    "ジャイアンツ" -> "giants",
    "dragons" -> "dragons",
    "中日ドラゴンズ" -> "dragons",
    "ドラゴンズ" -> "dragons",
    "swallows" -> "swallows",
    "スワローズ" -> "swallows",
    "sbhawks" -> "hawks",
    "ソフトバンクホークス" -> "hawks",
    "lovefighters" -> "fighters",
    "fighters" -> "fighters",
    "日ハム" -> "fighters",
    "chibalotte" -> "marines",
    "seibulions" -> "lions",
    "lions" -> "lions",
    "tigers" -> "tigers",
    "阪神" -> "tigers",
    "hanshin" -> "tigers",
    "タイガース" -> "tigers",
    "阪神タイガース" -> "tigers",
    "orix_buffaloes" -> "buffaloes",
    "buffaloes" -> "buffaloes",
    "バファローズ" -> "buffaloes",
    "楽天イーグルス" -> "eagles",
    "rakuteneagles" -> "eagles",
    "eagles" -> "eagles")

  // Twitter4jから集めてシリアライズして保存してあったファイルを読み込む
  def readObject(path: File): Array[Status] = {
    (for (is <- managed(new ObjectInputStream(new FileInputStream(path))))
      yield is.readObject().asInstanceOf[Array[Status]]).opt.get
  }

  // idの重複があるのでHashMapで全Tweetを保持する(メモリに乗り切らないほど集めた時は書き直し)
  val tweets = new HashMap[Long, Status]()
  for (path <- new File("tweets").listFiles.slice(0, 10000)) {
    println(path)
    val fileTweets = readObject(path)
    for (tweet <- fileTweets if !tweet.isRetweet)
      tweets += tweet.getId -> tweet
  }

  // 各球団ごとにファイルのwriterを作る
  val writers = new HashMap[String, Writer]()
  // tweetからURLは除去する
  def removeUrl(tweet: String): String = tweet.replaceAll("http(s{0,1})://[a-zA-Z0-9_/\\-\\.]+\\.([A-Za-z/]{2,5})[a-zA-Z0-9_/\\&\\?\\=\\-\\.\\~\\%]*", "")
  // tweetからタグは除去する
  def tweetWithoutTags(status: Status): String =
    status.getHashtagEntities.foldLeft(removeUrl(status.getText))((tweet, tag) => tweet.replaceAllLiterally("#" + tag.getText, ""))
  // 対応するチームのファイルにtweetを書き出す
  def writeFile(team: String, status: Status) {
    if (!writers.contains(team)) writers += team -> new BufferedWriter(new FileWriter("team_tweets/" + team + ".txt"))
    writers(team).write(tweetWithoutTags(status) + "\n")
  }
  // tweetとtagの内容を確認して、指定チームを探して書き出す
  // 複数チームがタグに入っていることもあるけど、その場合は最初にヒットしたチームに入れている
  def writeTeamTweet(status: Status, tags: Iterator[String]) {
    val tag = tags.next()
    if (tagTeams.contains(tag))
      writeFile(tagTeams(tag), status)
    else if (tags.hasNext)
      writeTeamTweet(status, tags)
  }

  new File("team_tweets").mkdir()
  for ((id, status) <- tweets) {
    val tags = status.getHashtagEntities.map(_.getText.toLowerCase).iterator
    if (tags.hasNext) writeTeamTweet(status, tags)
  }

  // 後処理
  for ((team, writer) <- writers) writer.close()
}

タグ自体は取り除き、1tweet=1行になるように改行を除去し、URLも邪魔になりそうなので除去している。

これでteam_tweetsというディレクトリ配下に、チームごとのtweetが出力された。

形態素解析

生成したファイルをわかち書きする。モデルに利用する単語は各タグに対してtf/idfで上位N語としてみる。

下記は形態素解析して単語を抜き出す処理をPythonで書いたもの。先ほどの処理で作ったteam_tweetsディレクトリ配下の各ファイルを読み込み、wakatiディレクトリ配下に分かち書き結果を出力する。

import os, re, MeCab

# 分かち書きしたファイル格納用のディレクトリ
if not os.path.exists('wakati'):
    os.mkdir('wakati')

tagger = MeCab.Tagger()
tagger.parseToNode('') # おまじない

def lineWakatiWriter(line, writer):
    node = tagger.parseToNode(line)
    line = ''
    while node:
        # 名詞で、2文字以上で、asciiキャラ以外の文字が入っている単語だけ利用
        if node.feature.startswith('名詞') and len(node.surface) > 1 and not all(ord(c) < 128 for c in node.surface):
            line += unicodedata.normalize('NFKC', node.surface) + ' '
        node = node.next
    if len(line) > 0:
        writer.write(line + '\n')

for file in os.listdir('team_tweets'):
    with open('team_tweets/' + file, 'rt') as reader, open('wakati/' + file, 'wt') as writer:
        for line in reader:
            lineWakatiWriter(line, writer)

tf-idfで効いてそうな単語を見る

tf-idfで特徴的な言葉だけを見て、何が効いていそうか様子を見ておく。なお、楽天とオリックスのtweet数が十分に取れなかったのでこの2チームは省いて10チームで分析することとした。「楽天」「オリックス」で検索すると企業関連の言葉も出てきてしまうので取りづらかった。

# scikit-learnを利用してtf-idfを抽出
from sklearn.feature_extraction.text import TfidfVectorizer
tfidf_vectorizer = TfidfVectorizer(input='filename', max_df=0.8, min_df=1, norm='l2')
files = ['wakati/' + path for path in os.listdir('wakati')]
tfidf = tfidf_vectorizer.fit_transform(files)

# 各チームごとに上位10語だけ取る
for i in range(tfidf.shape[0]):
    df = pd.DataFrame(np.array([tfidf_vectorizer.get_feature_names(), tfidf[i].toarray()[0, :]]).T, columns=['word', 'score'])
    words = df.sort_values('score', ascending=False).word[0:10]
    print(files[i], words.values)

実行結果はこちら。ちゃんとそれっぽい言葉が取れてる。これならxgboostに食わせればちゃんとした判定器作れそう。

wakati/fighters.txt ['札幌ドーム' '北海道日本ハムファイターズ' 'ファイターズ' '大谷翔平' 'むほ' '杉谷' '西武ライオンズ' '在籍' 'レアード' '陽岱鋼']
wakati/hawks.txt ['福岡ソフトバンクホークス' '明石' '千賀' '福岡' 'ヤフオク' 'ドーム' '五十嵐' '柳田' '武田' '摂津']
wakati/dragons.txt ['中日ドラゴンズ' 'ナゴヤドーム' '川柳' '谷繁' '中日スポーツ' '吉見' 'ジョーダン' '近藤' '平田' 'ビシエド']
wakati/tigers.txt ['タイガース' '阪神タイガース' '鳥谷' 'ゴメス' '福留' '岩貞' '兵庫県' '金本' '原口' '京セラドーム大阪']
wakati/lions.txt ['ライオンズ' '山川' '西武プリンスドーム' '野上' '栗山' 'メヒア' '埼玉西武ライオンズ' '牧田' '山川穂高' '西武プリンス']
wakati/baystars.txt ['横浜denaベイスターズ' 'dena' '梶谷' 'ベイ' '今永' 'ニコ生' '白崎' '桑原' '倉本' '山口']
wakati/marines.txt ['西野' 'マリーンズ' '涌井' '千葉ロッテマリーンズ' '荻野' '平沢' '細谷' '益田' 'オーオー' '伊東']
wakati/carp.txt ['新井' '東京都' '石原' '黒田' '巨人戦' 'ルナ' '広島カープ' '神宮' 'エルドレッド' 'まくら']
wakati/swallows.txt ['東京ヤクルトスワローズ' '神宮球場' 'スワローズ' '西田' '神宮' 'ツバメ' 'バレンティン' '山田哲人' '山中' '真中']
wakati/giants.txt ['東京都' '巨人戦' '脇谷' '調査兵団' '田口' '長野' '村田' '小林' '坂本' '巨人ファン']

このまま利用しようとすると、次元数が10 * 19947とかになるので、ちょっと大きい気もする。

len(tfidf_vectorizer.get_feature_names())
  #=> 19947

無駄は少し省こう。上位の5000語とかにしておく。

# max_features=5000で再度tf-idf
tfidf_vectorizer = TfidfVectorizer(input='filename', max_df=0.8, min_df=1, norm='l2', max_features=5000)
files = ['wakati/' + path for path in os.listdir('wakati')]
tfidf = tfidf_vectorizer.fit_transform(files)

# 各チームごとに上位10語だけ取る
for i in range(tfidf.shape[0]):
    df = pd.DataFrame(np.array([tfidf_vectorizer.get_feature_names(), tfidf[i].toarray()[0, :]]).T, columns=['word', 'score'])
    words = df.sort_values('score', ascending=False).word[0:10]
    print(files[i], words.values)

5000語に減らしても取得できる単語に特に問題はなさそう。

wakati/fighters.txt ['札幌ドーム' '北海道日本ハムファイターズ' 'ファイターズ' '大谷翔平' 'むほ' '杉谷' '西武ライオンズ' '在籍' 'レアード' '陽岱鋼']
wakati/hawks.txt ['福岡ソフトバンクホークス' '明石' '千賀' '福岡' 'ヤフオク' 'ドーム' '五十嵐' '柳田' '武田' '摂津']
wakati/dragons.txt ['中日ドラゴンズ' 'ナゴヤドーム' '川柳' '谷繁' '中日スポーツ' '吉見' 'ジョーダン' '近藤' '平田' 'ビシエド']
wakati/tigers.txt ['タイガース' '阪神タイガース' '鳥谷' 'ゴメス' '福留' '岩貞' '兵庫県' '金本' '原口' '京セラドーム大阪']
wakati/lions.txt ['ライオンズ' '山川' '西武プリンスドーム' '野上' '栗山' 'メヒア' '埼玉西武ライオンズ' '牧田' '山川穂高' 'ポーリーノ']
wakati/baystars.txt ['横浜denaベイスターズ' 'dena' '梶谷' 'ベイ' '今永' 'ニコ生' '白崎' '桑原' '倉本' '山口']
wakati/marines.txt ['西野' 'マリーンズ' '涌井' '千葉ロッテマリーンズ' '荻野' '平沢' '細谷' '益田' 'オーオー' '伊東']
wakati/carp.txt ['新井' '東京都' '石原' '黒田' '巨人戦' 'ルナ' '広島カープ' '神宮' 'エルドレッド' 'まくら']
wakati/swallows.txt ['東京ヤクルトスワローズ' '神宮球場' 'スワローズ' '西田' '神宮' 'ツバメ' 'バレンティン' '山田哲人' '山中' '真中']
wakati/giants.txt ['東京都' '巨人戦' '脇谷' '調査兵団' '田口' '長野' '村田' '小林' '坂本' '巨人ファン']

xgboostに食わせる形式のデータを作成する(binary)

全球団の分類をする前に、単純なbinary classificationを試してみる。カープだけ分類するモデルを作ってみよう。

上位5000の単語で辞書を作って各単語をIDにして、単純にcarpなら1、それ以外は0のラベルを貼り、各tweetをxgboostに読み込める形式に変換する。

# 辞書を作っておく
dic = dict((i, s) for i, s in enumerate(tfidf_vectorizer.get_feature_names()))
reverse_dic = dict((s, i) for i, s in enumerate(tfidf_vectorizer.get_feature_names()))

# 5000件あることを確認
len(dic)
  #=> 5000

def mkLine(line):
    words = filter(lambda w: w in reverse_dic, [word for word in line.split(' ')])
    ids = sorted([reverse_dic[word] for word in set(words)])
    return ' '.join(['{0}:1.0'.format(id) for id in ids])

def fileWriter(file, writer):
    # 今回はカープ分類器なので、カープだけ1で他は0のラベルを振る
    label = 1 if file == 'carp.txt' else 0
    with open('wakati/' + file, 'rt') as reader:
        team = file
        for line in reader:
            line = mkLine(line)
            if len(line) > 0:
                writer.write('{0} {1}\n'.format(label, line))

with open('tweets.txt', 'wt') as writer:
    for file in os.listdir('wakati'):
        fileWriter(file, writer)

これでこんな感じのファイルが出来上がる。

0 1990:1.0 6253:1.0
0 4295:1.0
0 2552:1.0
0 1155:1.0 6137:1.0 6437:1.0 9135:1.0 9313:1.0
0 8757:1.0
0 1859:1.0
0 1277:1.0 1567:1.0 1585:1.0
0 8305:1.0
1 6806:1.0 7912:1.0
1 6625:1.0
1 7770:1.0 8683:1.0

これで広島カープ分類器を作る準備ができた。

広島東洋カープ分類問題

上で作ったデータで学習をしてみる。80%を訓練データ、20%をテストデータに設定してさらっと。

import numpy as np
import xgboost as xgb
from sklearn.datasets import load_svmlight_file

x, y = load_svmlight_file("tweets.txt")
rnd = np.random.random(len(y))
train_x = x[rnd < 0.8]
train_y = y[rnd < 0.8]
test_x = x[rnd >= 0.8]
test_y = y[rnd >= 0.8]
xg_train = xgb.DMatrix(train_x, label=train_y)
xg_test = xgb.DMatrix(test_x, label=test_y)

# setup parameters for xgboost
param = {
    'objective':'binary:logistic',
    'eta': 0.1,
    'max_depth': 5,
    'nthread': 4
}

watchlist = [ (xg_train,'train'), (xg_test, 'test') ]

num_round = 100
bst = xgb.train(param, xg_train, num_round, watchlist );
# get prediction
pred = bst.predict( xg_test )

print ('predicting, classification error=%f' % (sum( int(pred[i]) != test_y[i] for i in range(len(test_y))) / float(len(test_y)) ))
  #=> predicting, classification error=0.155499

# probabilityが出るので0.5以上は広島、0.5未満はそれ以外と判定されたことにする
((bst.predict( xg_test ) + 0.5).astype(np.int) == test_y).sum() / len(test_y)
  #=> 0.87757923845990216

何やら87.7%が分類に成功した。1語しか取れてないTweetが多い中でこの数字は割と良いのではないだろうか。

どんなモデルになったか確認してみる。

bst.dump_model('model.txt')

下記のようなTree上の構造が表示される。

booster[0]:
0:[f3329<-9.53674e-07] yes=1,no=2,missing=1
    1:[f4096<-9.53674e-07] yes=3,no=4,missing=3
        3:[f4993<-9.53674e-07] yes=7,no=8,missing=7
            7:[f2993<-9.53674e-07] yes=13,no=14,missing=13
                13:[f1745<-9.53674e-07] yes=21,no=22,missing=21
                    21:leaf=-0.148161
                    22:leaf=0.10137
                14:[f1518<-9.53674e-07] yes=23,no=24,missing=23
                    23:leaf=0.162712
                    24:leaf=-0.1
            8:[f3455<-9.53674e-07] yes=15,no=16,missing=15
                15:[f3612<-9.53674e-07] yes=25,no=26,missing=25
                    25:leaf=0.115789
                    26:leaf=-0.0666667
                      ・
                      ・
                      ・
                      ・

条件に対してyesならこっちに分岐、noならこっちに分岐という感じの記述。

これでも意味合いはわからなくないのだけどちょっと見づらい。

get_fscore()で重要な要素が取れるらしい。

import operator
sorted(bst.get_fscore().items(), key=operator.itemgetter(1), reverse=True)[0:10]
    #=> [('f3329', 33),
    #=>  ('f4993', 31),
    #=>  ('f4096', 25),
    #=>  ('f3455', 23),
    #=>  ('f1745', 20),
    #=>  ('f1496', 19),
    #=>  ('f2993', 18),
    #=>  ('f3496', 17),
    #=>  ('f1976', 17),
    #=>  ('f2943', 15)]

idのままだとわかりづらいので、実際に効いていた単語の一覧を出してみる。

for (f, c) in important_features:
    print( dic[int(f[1:])] )
        #=> 新井
        #=> 黒田
        #=> 石原
        #=> 月間
        #=> ルナ
        #=> ヘーゲンズ
        #=> 広島カープ
        #=> 杉山
        #=> 九里
        #=> 巨人戦

新井、黒田、石原、ルナ、ヘーゲンズは広島の中心選手。九里も最近はよく投げてる中継ぎ投手。杉山は・・・なんだろう。ちょうどTweetを取得した時期にバレンティンのバットで石原が負傷していて、ちょっと前に中日の杉山も同様の怪我をしているので、タイミング的に広島のTweetに杉山が良く出てきたとかだろうか。ここはTweetの取得期間が短い(5日間くらい)のが原因な気がする。

Tree構造をビジュアライズすることも可能。その場合はgraphvizが必要。依存ライブラリの都合でpipだけでは動かないことがあるのでcondaが楽。

$ conda install graphviz

plot_treeメソッドを利用。

xgb.plot_tree(bst)

下記のようなツリーが描画される(クリックで拡大)

tree

ツリーの一番上はf3329(新井)。新井さんはやはり強い。2段目は4096(石原)と2143(倉本)。3段目は4993(黒田)、2339(北條)、3496(杉山)。杉山を除くと順調に広島の選手の名前が並んでいる。

全球団分類問題

続いて全球団の分類。

カープのみの分類は二値分類なので、objectiveのところはbinary:logisticを利用した。今回は多値分類としてmulti:softmaxを利用する。また、ラベルは先ほどのような0/1ではなく、各チームのIDを付加する。

# チームの一覧を作る(10球団使う)
teams = dict([(f[f.index('/')+1:], i) for i, f in enumerate(files)])
    #=> {'baystars.txt': 6,
    #=>  'carp.txt': 8,
    #=>  'dragons.txt': 3,
    #=>  'fighters.txt': 1,
    #=>  'giants.txt': 10,
    #=>  'hawks.txt': 2,
    #=>  'lions.txt': 5,
    #=>  'marines.txt': 7,
    #=>  'swallows.txt': 9,
    #=>  'tigers.txt': 4}

def fileWriter(file, writer):
    # チームごとにラベルを貼る
    label = teams[file]
    with open('wakati/' + file, 'rt') as reader:
        team = file
        for line in reader:
            line = mkLine(line)
            if len(line) > 0:
                writer.write('{0} {1}\n'.format(label, line))

with open('tweets.txt', 'wt') as writer:
    for file in os.listdir('wakati'):
        fileWriter(file, writer)

これで先程はカープなら1、それ以外は0としていたラベルに、チームごとの数値が割り振られた。

こうしてできたファイルをxgboostに読み込ませてみる。

multiclassにするにあたって前回との指定パラメータの違いとして、objectiveにmulti:softmaxが指定されているのと、分類するクラスの数をあらかじめ明記しないといけないのでnum_classに10を指定する。

また、iterationを100回にしていたのを3000回に増やしているのと(だいたい3000くらいでサチった)、trainの際にverbose_eval=100を指定して、100回iterationするごとに現状のtrainの状況を伝えるverboseが出るようにしている(そうしないと毎回traceされて3000行も標準出力されることになって邪魔)。

import numpy as np
import xgboost as xgb
from sklearn.datasets import load_svmlight_file

x, y = load_svmlight_file("tweets.txt")
rnd = np.random.random(len(y))
train_x = x[rnd < 0.8]
train_y = y[rnd < 0.8]
test_x = x[rnd >= 0.8]
test_y = y[rnd >= 0.8]
xg_train = xgb.DMatrix(train_x, label=train_y)
xg_test = xgb.DMatrix(test_x, label=test_y)

# setup parameters for xgboost
param = {
    'objective':'multi:softmax',
    'eta': 0.1,
    'max_depth': 5,
    'nthread': 4,
    'num_class': 10
}

watchlist = [ (xg_train,'train'), (xg_test, 'test') ]

num_round = 3000
bst = xgb.train(param, xg_train, num_round, watchlist, verbose_eval=100)
# get prediction
pred = bst.predict( xg_test )

この結果を使ってpredictをする。前回のbinary:logitではprobabilityが出たので便宜的に0.5以上ならカープと判定する処理で確認した。

今回はmulti:softmaxでpredictの結果はclassとして出てくる。

# predictするとtweetごとの推測したclassが出てくる
bst.predict( xg_test )
  #=> array([ 0.,  6.,  5., ...,  9.,  9.,  8.], dtype=float32)

# 正答率の確認
sum(bst.predict( xg_test ) == test_y) / len(test_y)
  #=> 0.70986728216964801

正答率が70%程度。classの数が多い割にはそこそこ仕事はしてくれているっぽい。

チームごとの正答率も見てみる。

reverse_teams = dict((i, t)for t, i in teams.items())
result = bst.predict( xg_test )
for i in range(10):
    print(reverse_teams[i], sum( (result == i) & (test_y == i) ) / sum(test_y == i))
        #=> fighters.txt 0.704081632653
        #=> hawks.txt 0.647727272727
        #=> dragons.txt 0.752453653217
        #=> tigers.txt 0.77308707124
        #=> lions.txt 0.725328947368
        #=> baystars.txt 0.6875
        #=> marines.txt 0.700900900901
        #=> carp.txt 0.687640449438
        #=> swallows.txt 0.647925033467
        #=> giants.txt 0.748659003831

ホークスの0.647が最小で、ジャイアンツの0.748が最大。そこまで大きく苦手な球団はない模様。たいしてパラメータ調整もせずにざっくり実行した割にはよく出来た子だ。

重要語も出してみる。

for (f, c) in important_features:
    print( dic[int(f[1:])] )
        #=> 栗山
        #=> ファイターズ
        #=> 巨人戦
        #=> 伊東
        #=> ライオンズ
        #=> 中日ドラゴンズ
        #=> 阪神タイガース
        #=> ゴメス
        #=> 福留
        #=> 押し出し

1位は栗山監督だった。

生成したモデルのセーブとロード

最後に、今回生成したモデルを保存し、再読込してみる。

まずは保存。保存されるデータはバイナリ。C++側でdmlcのioの機能を使って結果をシリアライズして出力してるっぽい。

bst.save_model('model.bin')

保存したモデルをloadしてみる。

bst2 = xgb.Booster({'nthread':4})
bst2.load_model("model.bin")
bst2.predict( xg_test )
    #=> array([ 0.,  6.,  5., ...,  9.,  9.,  8.], dtype=float32)

ちゃんと先ほどと同じ実行結果になった。

ちなみに今回出力したモデルは57MBとそれなりのサイズになっており、loadする際もそれなりに実行時間がかかった。