web-dev-qa-db-fra.com

Paramètre "stratify" de la méthode "train_test_split" (scikit Learn)

J'essaie d'utiliser train_test_split du paquet scikit Learn, mais je rencontre des problèmes avec le paramètre stratify. Voici le code:

from sklearn import cross_validation, datasets 

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

Cependant, je continue à avoir le problème suivant:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

Est-ce que quelqu'un a une idée de ce qui se passe? Vous trouverez ci-dessous la documentation de la fonction.

[...]

stratify: semblable à un tableau ou None (la valeur par défaut est None)

Sinon, les données sont divisées de manière stratifiée, en utilisant ceci comme tableau d'étiquettes.

Nouveauté de la version 0.17: fractionnement par stratification

[...]

58
Daneel Olivaw

Scikit-Learn vous dit simplement qu'il ne reconnaît pas l'argument "stratifier", pas que vous ne l'utilisez pas correctement. En effet, le paramètre a été ajouté à la version 0.17, comme indiqué dans la documentation que vous avez citée.

Il suffit donc de mettre à jour Scikit-Learn.

44
Borja

Ce paramètre stratify effectue une scission de sorte que la proportion de valeurs dans l'échantillon produit soit identique à la proportion de valeurs fournie au paramètre stratify.

Par exemple, si la variable y est une variable catégorique binaire avec les valeurs 0 et 1 et qu'il existe 25% de zéros et 75% de valeurs, stratify=y s'assurera que votre répartition aléatoire contient 25% de 0 'et 75% de 1'.

189
Fazzolini

Pour mon futur moi qui vient ici via Google:

train_test_split est maintenant dans _model_selection_, d'où:

_from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)
_

est la façon de l'utiliser. Le réglage de _random_state_ est souhaitable pour la reproductibilité.

36
Martin Thoma

Dans ce contexte, la stratification signifie que la méthode train_test_split renvoie des sous-ensembles d'apprentissage et de test présentant les mêmes proportions d'étiquettes de classe que l'ensemble de données en entrée.

8
X. Wang

Essayez d’exécuter ce code, cela "fonctionne":

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])
3
Sergey Bushmanov