参考
必要パッケージインストール
pip install scikit-learn pip install dtreeviz
dtreevizにはGraphvizが必要なので
https://graphviz.gitlab.io/_pages/Download/Download_windows.html
ユーザー環境変数PATHに、Graphvizの実行ファイル置き場を追加。
C:\Program Files (x86)\Graphviz2.38\bin
インポート
import pandas as pd from sklearn import tree from dtreeviz.trees import dtreeviz
モデル作成と学習
分類器の場合
pandasでcsvファイルを読み出して学習データとする想定。 y_trainは、カテゴリ名に対して0以上の整数を割り当てた列。pd.DataFrame().get_dummies()ではなく、pd.Series().map()で作成すればよい。
# 学習データ x_train = 説明変数の列たち y_train = 目的変数 # 分類器作成 clf = tree.DecisionTreeClassifier(max_depth=深さ) # 学習 clf = clf.fit(x_train, y_train)
可視化
dtreevizを使う。
viz = dtreeviz( clf, x_train, y_train, target_name=y_trainの列名, feature_names=x_train.columns, class_names=y_trainの数値に対応するカテゴリ名の配列 ) viz.view()