工作と競馬2

電子工作、プログラミング、木工といった工作の記録記事、競馬に関する考察記事を掲載するブログ

決定木分析 自分用メモ

参考

pythondatascience.plavox.info

qiita.com

必要パッケージインストール

pip install scikit-learn
pip install dtreeviz

dtreevizにはGraphvizが必要なので

https://graphviz.gitlab.io/_pages/Download/Download_windows.html

ユーザー環境変数PATHに、Graphvizの実行ファイル置き場を追加。

C:\Program Files (x86)\Graphviz2.38\bin

インポート

import pandas as pd
from sklearn import tree
from dtreeviz.trees import dtreeviz

モデル作成と学習

分類器の場合

pandasでcsvファイルを読み出して学習データとする想定。 y_trainは、カテゴリ名に対して0以上の整数を割り当てた列。pd.DataFrame().get_dummies()ではなく、pd.Series().map()で作成すればよい。

# 学習データ
x_train = 説明変数の列たち
y_train = 目的変数

# 分類器作成
clf = tree.DecisionTreeClassifier(max_depth=深さ)

# 学習
clf = clf.fit(x_train, y_train)

可視化

dtreevizを使う。

viz = dtreeviz(
    clf,
    x_train, 
    y_train,
    target_name=y_trainの列名,
    feature_names=x_train.columns,
    class_names=y_trainの数値に対応するカテゴリ名の配列
) 

viz.view()