J'ai deux problèmes avec la compréhension du résultat de l'arbre de décision de scikit-learn. Par exemple, c'est l'un de mes arbres de décision:
Ma question est la suivante: comment utiliser l'arbre?
La première question est la suivante: si un échantillon remplit la condition, il passe à la branche GAUCHE (s'il existe), sinon il passe DROITE . Dans mon cas, si un échantillon avec X [7]> 63521.3984. Ensuite, l'échantillon ira dans la boîte verte. Correct?
La deuxième question est la suivante: lorsqu'un échantillon atteint le nœud feuille, comment savoir à quelle catégorie il appartient? Dans cet exemple, j'ai trois catégories à classer. Dans la case rouge, il y a respectivement 91, 212 et 113 échantillons satisfaits à la condition. Mais comment décider de la catégorie? Je sais qu'il existe une fonction clf.predict (sample) pour indiquer la catégorie. Puis-je faire cela à partir du graphique ??? Merci beaucoup.
La ligne value
dans chaque case vous indique le nombre d'échantillons de ce nœud dans chaque catégorie, dans l'ordre. C'est pourquoi, dans chaque case, les nombres dans value
s'additionnent au nombre affiché dans sample
. Par exemple, dans votre case rouge, 91 + 212 + 113 = 416. Cela signifie donc que si vous atteignez ce nœud, il y avait 91 points de données dans la catégorie 1, 212 dans la catégorie 2 et 113 dans la catégorie 3.
Si vous deviez prédire le résultat d'un nouveau point de données qui a atteint cette feuille dans l'arbre de décision, vous prédiriez la catégorie 2, car c'est la catégorie la plus courante pour les échantillons à ce nœud.
Première question: Oui, votre logique est correcte. Le nœud gauche est vrai et le nœud droit est faux. Cela peut être contre-intuitif; true peut correspondre à un échantillon plus petit.
Deuxième question: Ce problème est mieux résolu en visualisant l'arbre sous forme de graphique avec pydotplus. L'attribut 'class_names' de tree.export_graphviz () ajoutera une déclaration de classe à la classe majoritaire de chaque nœud. Le code est exécuté dans un bloc-notes iPython.
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf2 = tree.DecisionTreeClassifier()
clf2 = clf2.fit(iris.data, iris.target)
with open("iris.dot", 'w') as f:
f = tree.export_graphviz(clf, out_file=f)
import os
os.unlink('iris.dot')
import pydotplus
dot_data = tree.export_graphviz(clf2, out_file=None)
graph2 = pydotplus.graph_from_dot_data(dot_data)
graph2.write_pdf("iris.pdf")
from IPython.display import Image
dot_data = tree.export_graphviz(clf2, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True, # leaves_parallel=True,
special_characters=True)
graph2 = pydotplus.graph_from_dot_data(dot_data)
## Color of nodes
nodes = graph2.get_node_list()
for node in nodes:
if node.get_label():
values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')];
color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],}
values = color[values.index(max(values))]; # print(values)
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color)
node.set_fillcolor(color )
#
Image(graph2.create_png() )
En ce qui concerne la détermination de la classe à la feuille, votre exemple n'a pas de feuilles avec une seule classe, comme le fait l'ensemble de données iris. Ceci est courant et peut nécessiter un ajustement excessif du modèle pour atteindre un tel résultat. Une distribution discrète des classes est le meilleur résultat pour de nombreux modèles à validation croisée.
Profitez du code!
Selon le livre "Learning scikit-learn: Machine Learning in Python", l'arbre de décision représente une série de décisions basées sur les données de formation.
! ( http://i.imgur.com/vM9fJLy.png )
Pour classer une instance, nous devons répondre à la question à chaque nœud. Par exemple, le sexe est-il <= 0,5? (parlons-nous d'une femme?). Si la réponse est oui, vous allez au nœud enfant gauche dans l'arborescence; sinon vous allez au bon nœud enfant . Vous continuez à répondre aux questions (était-elle dans la troisième classe?, Était-elle dans la première classe?, Et avait-elle moins de 13 ans?), Jusqu'à ce que vous atteigniez une feuille. Lorsque vous y êtes, la prédiction correspond à la classe cible qui a le plus d'instances .
Ajoutez feature_names = X.columns à tree.export_graphviz où X correspond aux données d'entraînement.
Mon code est le suivant
with open("lectureGini.txt", "w") as f:
f = tree.export_graphviz(lectureGini, out_file=f,feature_names=X.columns)
# copy contents of file LectureGini.txt into WebGraphviz - http://webgraphviz.com/
lectureGini est la sortie de mon DecisionTreeClassifier
C'est une méthode simple que j'ai découverte qui pourrait être ajoutée à tous les exemples Web de l'indice Gini que j'avais recherchés. Tous les exemples Web expliquaient très bien la méthode, mais aucun ne montrait comment trouver les catégories. Je n'ai pas encore installé Graphviz, donc j'exporte un fichier texte depuis jupyter et je copie le texte dans le Webgraphwiz