Puis-je extraire les règles de décision sous-jacentes (ou "chemins de décision") d'un arbre formé dans un arbre de décision sous forme de liste textuelle?
Quelque chose comme:
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
Merci de votre aide.
Je crois que cette réponse est plus correcte que les autres réponses ici:
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print "{}if {} <= {}:".format(indent, name, threshold)
recurse(tree_.children_left[node], depth + 1)
print "{}else: # if {} > {}".format(indent, name, threshold)
recurse(tree_.children_right[node], depth + 1)
else:
print "{}return {}".format(indent, tree_.value[node])
recurse(0, 1)
Cela affiche une fonction Python valide. Voici un exemple de sortie pour une arborescence qui tente de renvoyer son entrée, un nombre compris entre 0 et 10.
def tree(f0):
if f0 <= 6.0:
if f0 <= 1.5:
return [[ 0.]]
else: # if f0 > 1.5
if f0 <= 4.5:
if f0 <= 3.5:
return [[ 3.]]
else: # if f0 > 3.5
return [[ 4.]]
else: # if f0 > 4.5
return [[ 5.]]
else: # if f0 > 6.0
if f0 <= 8.5:
if f0 <= 7.5:
return [[ 7.]]
else: # if f0 > 7.5
return [[ 8.]]
else: # if f0 > 8.5
return [[ 9.]]
Voici quelques points d'achoppement que je vois dans d'autres réponses:
tree_.threshold == -2
pour décider si un nœud est une feuille n'est pas une bonne idée. Et si c'est un vrai noeud de décision avec un seuil de -2? Au lieu de cela, vous devriez regarder tree.feature
ou tree.children_*
.features = [feature_names[i] for i in tree_.feature]
plante avec ma version de sklearn, car certaines valeurs de tree.tree_.feature
sont -2 (spécifiquement pour les nœuds terminaux).J'ai créé ma propre fonction pour extraire les règles des arbres de décision créés par sklearn:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})
# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)
Cette fonction commence par les nœuds (identifiés par -1 dans les tableaux enfants), puis recherche les parents de manière récursive. J'appelle cela une "lignée" de nœud. En cours de route, je récupère les valeurs nécessaires pour créer la logique if/then/else SAS:
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = 'l'
else:
parent = np.where(right == child)[0].item()
split = 'r'
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
for child in idx:
for node in recurse(left, right, child):
print node
Les ensembles de nuplets ci-dessous contiennent tout ce dont j'ai besoin pour créer SAS instructions if/then/else. Je n'aime pas utiliser les blocs do
dans SAS, c'est pourquoi je crée une logique décrivant le chemin complet d'un nœud. L'entier unique après les nuplets est l'ID du nœud terminal dans un chemin. Tous les tuples précédents se combinent pour créer ce nœud.
In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6
J'ai modifié le code soumis par Zelazny7 pour imprimer un pseudocode:
def get_code(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node):
if (threshold[node] != -2):
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node])
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node])
print "}"
else:
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)
si vous appelez get_code(dt, df.columns)
sur le même exemple, vous obtiendrez:
if ( col1 <= 0.5 ) {
return [[ 1. 0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0. 1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1. 0.]]
} else {
return [[ 0. 1.]]
}
}
}
Il existe une nouvelle méthode DecisionTreeClassifier
, decision_path
, dans la version 0.18.0 . Les développeurs fournissent une vaste (bien documentée) procédure pas à pas .
La première section de code de la procédure pas à pas qui imprime la structure arborescente semble être OK. Cependant, j'ai modifié le code dans la deuxième section pour interroger un échantillon. Mes changements notés avec # <--
Edit Les modifications marquées par # <--
dans le code ci-dessous ont depuis été mises à jour dans le lien pas à pas après que les erreurs ont été signalées dans les demandes d'extraction # 8653 et # 10951 . C'est beaucoup plus facile à suivre maintenant.
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
node_indicator.indptr[sample_id + 1]]
print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
if leave_id[sample_id] == node_id: # <-- changed != to ==
#continue # <-- comment out
print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--
else: # < -- added else to iterate through decision nodes
if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
threshold_sign = "<="
else:
threshold_sign = ">"
print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
% (node_id,
sample_id,
feature[node_id],
X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
threshold_sign,
threshold[node_id]))
Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here
Changez le sample_id
pour voir les chemins de décision pour les autres exemples. Je n'ai pas interrogé les développeurs à propos de ces modifications, je me suis simplement montré plus intuitif lorsque j'ai suivi l'exemple.
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()
Vous pouvez voir un arbre de digraphe. Alors, clf.tree_.feature
et clf.tree_.value
sont des tableaux de noeuds séparant les entités et des valeurs de noeuds Vous pouvez vous référer à plus de détails à partir de cette source github .
Les codes ci-dessous sont mon approche sous anaconda python 2.7 plus un nom de paquetage "pydot-ng" pour créer un fichier PDF avec des règles de décision. J'espère que c'est utile.
from sklearn import tree
clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)
feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)
def output_pdf(clf_, name):
from sklearn import tree
from sklearn.externals.six import StringIO
import pydot_ng as pydot
dot_data = StringIO()
tree.export_graphviz(clf_, out_file=dot_data,
feature_names=feature_names,
class_names=class_name,
filled=True, rounded=True,
special_characters=True,
node_ids=1,)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("%s.pdf"%name)
output_pdf(clf_, name='filename%s'%n)
Juste parce que tout le monde a été si utile, je vais juste ajouter une modification aux belles solutions de Zelazny7 et de Daniele. Celui-ci est destiné à Python 2.7, avec des onglets pour le rendre plus lisible:
def get_code(tree, feature_names, tabdepth=0):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node, tabdepth=0):
if (threshold[node] != -2):
print '\t' * tabdepth,
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node], tabdepth+1)
print '\t' * tabdepth,
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node], tabdepth+1)
print '\t' * tabdepth,
print "}"
else:
print '\t' * tabdepth,
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)
Ceci s'appuie sur la réponse de @paulkernfeld. Si vous avez un cadre de données X avec vos entités et un cadre de données cible y avec vos réponses et que vous souhaitez avoir une idée de la valeur y indiquée dans quel nœud (et également de le tracer en conséquence), vous pouvez procéder comme suit:
def tree_to_code(tree, feature_names):
codelines = []
codelines.append('def get_cat(X_tmp):\n')
codelines.append(' catout = []\n')
codelines.append(' for codelines in range(0,X_tmp.shape[0]):\n')
codelines.append(' Xin = X_tmp.iloc[codelines]\n')
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
#print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
recurse(tree_.children_left[node], depth + 1)
codelines.append( '{}else: # if Xin["{}"] > {}\n'.format(indent, name, threshold))
recurse(tree_.children_right[node], depth + 1)
else:
codelines.append( '{}mycat = {}\n'.format(indent, node))
recurse(0, 1)
codelines.append(' catout.append(mycat)\n')
codelines.append(' return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
codelines.append('node_ids = get_cat(X)\n')
return codelines
mycode = tree_to_code(clf,X.columns.values)
# now execute the function and obtain the dataframe with all nodes
exec(''.join(mycode))
node_ids = [int(x[0]) for x in node_ids.values]
node_ids2 = pd.DataFrame(node_ids)
print('make plot')
import matplotlib.cm as cm
colors = cm.Rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
#plt.figure(figsize=cm2inch(24, 21))
for i in list(set(node_ids)):
plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))
mytitle = ['y colored by node']
plt.title(mytitle ,fontsize=14)
plt.xlabel('my xlabel')
plt.ylabel(tagname)
plt.xticks(rotation=70)
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
plt.tight_layout()
plt.show()
plt.close
pas la version la plus élégante mais elle fait le travail ...
Scikit learn a introduit une nouvelle méthode délicieuse appelée export_text
dans la version 0.21 (mai 2019) pour extraire les règles d'un arbre. Documentation ici . Il n'est plus nécessaire de créer une fonction personnalisée.
Une fois que vous avez adapté votre modèle, il vous suffit de deux lignes de code. Commencez par importer export_text
:
from sklearn.tree.export import export_text
Deuxièmement, créez un objet qui contiendra vos règles. Pour rendre les règles plus lisibles, utilisez l'argument feature_names
et transmettez une liste de vos noms d'entités. Par exemple, si votre modèle s'appelle model
et que vos entités sont nommées dans un cadre de données appelé X_train
, vous pouvez créer un objet appelé tree_rules
:
tree_rules = export_text(model, feature_names=list(X_train))
Ensuite, imprimez ou enregistrez tree_rules
. Votre sortie ressemblera à ceci:
|--- Age <= 0.63
| |--- EstimatedSalary <= 0.61
| | |--- Age <= -0.16
| | | |--- class: 0
| | |--- Age > -0.16
| | | |--- EstimatedSalary <= -0.06
| | | | |--- class: 0
| | | |--- EstimatedSalary > -0.06
| | | | |--- EstimatedSalary <= 0.40
| | | | | |--- EstimatedSalary <= 0.03
| | | | | | |--- class: 1
Je suis passé par là, mais j'avais besoin que les règles soient écrites dans ce format
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
J'ai donc adapté la réponse de @paulkernfeld (merci) que vous pouvez personnaliser en fonction de vos besoins.
def tree_to_code(tree, feature_names, Y):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
pathto=dict()
global k
k = 0
def recurse(node, depth, parent):
global k
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
s= "{} <= {} ".format( name, threshold, node )
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+' & ' +s
recurse(tree_.children_left[node], depth + 1, node)
s="{} > {}".format( name, threshold)
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+' & ' +s
recurse(tree_.children_right[node], depth + 1, node)
else:
k=k+1
print(k,')',pathto[parent], tree_.value[node])
recurse(0, 1, 0)
Voici une fonction d'impression des règles d'un arbre de décision scikit-learn sous python 3 et avec des décalages pour les blocs conditionnels afin de rendre la structure plus lisible:
def print_decision_tree(tree, feature_names=None, offset_unit=' '):
'''Plots textual representation of rules of a decision tree
tree: scikit-learn representation of tree
feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
offset_unit: a string of offset of the conditional block'''
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = ['f%d'%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0):
offset = offset_unit*depth
if (threshold[node] != -2):
print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
if left[node] != -1:
recurse (left, right, threshold, features,left[node],depth+1)
print(offset+"} else {")
if right[node] != -1:
recurse (left, right, threshold, features,right[node],depth+1)
print(offset+"}")
else:
print(offset+"return " + str(value[node]))
recurse(left, right, threshold, features, 0,0)
Apparemment, il y a longtemps, quelqu'un a déjà décidé d'essayer d'ajouter la fonction suivante aux fonctions d'exportation de l'arbre officiel de scikit (qui ne prend en charge que export_graphviz)
def export_dict(tree, feature_names=None, max_depth=None) :
"""Export a decision tree in dict format.
Voici son engagement complet:
Pas tout à fait sûr de ce qui est arrivé à ce commentaire. Mais vous pouvez aussi essayer d'utiliser cette fonction.
Je pense que cela justifie une demande sérieuse de documentation auprès des personnes compétentes de scikit-learn afin de documenter correctement l'API sklearn.tree.Tree
, qui est la structure arborescente sous-jacente que DecisionTreeClassifier
expose sous son attribut tree_
.
Vous pouvez également le rendre plus informatif en le distinguant à quelle classe il appartient ou même en mentionnant sa valeur de sortie.
def print_decision_tree(tree, feature_names, offset_unit=' '):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = ['f%d'%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0):
offset = offset_unit*depth
if (threshold[node] != -2):
print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
if left[node] != -1:
recurse (left, right, threshold, features,left[node],depth+1)
print(offset+"} else {")
if right[node] != -1:
recurse (left, right, threshold, features,right[node],depth+1)
print(offset+"}")
else:
#print(offset,value[node])
#To remove values from node
temp=str(value[node])
mid=len(temp)//2
tempx=[]
tempy=[]
cnt=0
for i in temp:
if cnt<=mid:
tempx.append(i)
cnt+=1
else:
tempy.append(i)
cnt+=1
val_yes=[]
val_no=[]
res=[]
for j in tempx:
if j=="[" or j=="]" or j=="." or j==" ":
res.append(j)
else:
val_no.append(j)
for j in tempy:
if j=="[" or j=="]" or j=="." or j==" ":
res.append(j)
else:
val_yes.append(j)
val_yes = int("".join(map(str, val_yes)))
val_no = int("".join(map(str, val_no)))
if val_yes>val_no:
print(offset,'\033[1m',"YES")
print('\033[0m')
Elif val_no>val_yes:
print(offset,'\033[1m',"NO")
print('\033[0m')
else:
print(offset,'\033[1m',"Tie")
print('\033[0m')
recurse(left, right, threshold, features, 0,0)
Code modifié de Zelazny7 pour extraire le code SQL de l’arbre de décision.
# SQL from decision tree
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
le='<='
g ='>'
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = 'l'
else:
parent = np.where(right == child)[0].item()
split = 'r'
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
print 'case '
for j,child in enumerate(idx):
clause=' when '
for node in recurse(left, right, child):
if len(str(node))<3:
continue
i=node
if i[1]=='l': sign=le
else: sign=g
clause=clause+i[3]+sign+str(i[2])+' and '
clause=clause[:-4]+' then '+str(j)
print clause
print 'else 99 end as clusters'
Voici un moyen de traduire tout l'arbre en une seule expression python (pas nécessairement trop lisible par l'homme) à l'aide de la bibliothèque SKompiler :
from skompiler import skompile
skompile(dtree.predict).to('python/code')
Utilisez juste la fonction de sklearn.tree comme ceci
from sklearn.tree import export_graphviz
export_graphviz(tree,
out_file = "tree.dot",
feature_names = tree.columns) //or just ["petal length", "petal width"]
Ensuite, recherchez dans le dossier de votre projet le fichier tree.dot, copiez-le TOUT le contenu et collez-le ici http://www.webgraphviz.com/ et générez votre graphique :)