web-dev-qa-db-fra.com

différence entre StratifiedKFold et StratifiedShuffleSplit dans sklearn

À partir du titre, je me demande quelle est la différence entre

StratifiedKFold avec le paramètre shuffle = True

StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

et

StratifiedShuffleSplit

StratifiedShuffleSplit(n_splits=10, test_size=’default’, train_size=None, random_state=0)

et quel est l'avantage d'utiliser StratifiedShuffleSplit

32
gabboshow

Dans KFolds, chaque ensemble de tests ne doit pas se chevaucher, même en mode aléatoire. Avec KFolds et shuffle, les données sont mélangées une fois au début, puis divisées en nombre de divisions souhaitées. Les données de test sont toujours l'une des divisions, les données de train sont les autres.

Dans ShuffleSplit, les données sont brassées à chaque fois, puis divisées. Cela signifie que les ensembles de tests peuvent se chevaucher entre les divisions.

Voir ce bloc pour un exemple de la différence. Notez le chevauchement des éléments dans les ensembles de tests pour ShuffleSplit.

splits = 5

tx = range(10)
ty = [0] * 5 + [1] * 5

from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold
from sklearn import datasets

kfold = StratifiedKFold(n_splits=splits, shuffle=True, random_state=42)
shufflesplit = StratifiedShuffleSplit(n_splits=splits, random_state=42, test_size=2)

print("KFold")
for train_index, test_index in kfold.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

print("Shuffle Split")
for train_index, test_index in shufflesplit.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

Sortie:

KFold
TRAIN: [0 2 3 4 5 6 7 9] TEST: [1 8]
TRAIN: [0 1 2 3 5 7 8 9] TEST: [4 6]
TRAIN: [0 1 3 4 5 6 8 9] TEST: [2 7]
TRAIN: [1 2 3 4 6 7 8 9] TEST: [0 5]
TRAIN: [0 1 2 4 5 6 7 8] TEST: [3 9]
Shuffle Split
TRAIN: [8 4 1 0 6 5 7 2] TEST: [3 9]
TRAIN: [7 0 3 9 4 5 1 6] TEST: [8 2]
TRAIN: [1 2 5 6 4 8 9 0] TEST: [3 7]
TRAIN: [4 6 7 8 3 5 1 2] TEST: [9 0]
TRAIN: [7 2 6 5 4 3 0 9] TEST: [1 8]

Pour ce qui est de savoir quand les utiliser, j’ai tendance à utiliser KFold pour toute validation croisée et j’utilise ShuffleSplit avec une division de 2 pour les divisions de mon train/jeu d’essais. Mais je suis sûr qu'il existe d'autres cas d'utilisation pour les deux.

43
Ken Syme

@ Ken Syme a déjà une très bonne réponse. Je veux juste ajouter quelque chose.

  • StratifiedKFold est une variante de KFold. Tout d'abord, StratifiedKFold mélange vos données, puis les divise en n_splits parties et Fait. Maintenant, il utilisera chaque partie comme un ensemble de test. Notez que uniquement et mélange toujours les données une fois avant de les séparer.

Avec shuffle = True, les données sont mélangées par votre random_state. Sinon, les données sont mélangées par np.random (par défaut). Par exemple, avec n_splits = 4, et vos données ont 3 classes (label) pour y (variable dépendante). 4 ensembles de test couvrent toutes les données sans aucun chevauchement.

enter image description here

  • D'autre part, StratifiedShuffleSplit est une variante de ShuffleSplit. Tout d'abord, StratifiedShuffleSplit mélange vos données, puis les divise en n_splits les pièces. Cependant, ce n'est pas encore fait. Après cette étape, StratifiedShuffleSplit choisit une partie à utiliser comme jeu de test. Ensuite, il répète le même processus n_splits - 1 autres fois, pour obtenir n_splits - 1 autres ensembles de test. Regardez l'image ci-dessous, avec les mêmes données, mais cette fois, les 4 jeux de tests ne couvrent pas toutes les données, c'est-à-dire qu'il y a des chevauchements entre les jeux de tests.

enter image description here

Donc, la différence est que StratifiedKFold mélange et scinde une seule fois. Par conséquent, les ensembles de tests ne se chevauchent pas , tandis que StratifiedShuffleSplit mélange à chaque fois avant de se séparer, et il se sépare n_splits _ fois, les ensembles de tests peuvent se chevaucher .

  • Remarque : les deux méthodes utilisent le "pli stratifié" (pourquoi "stratifié" apparaît dans les deux noms). Cela signifie que chaque partie conserve le même pourcentage d'échantillons de chaque classe (étiquette) que les données d'origine. Vous pouvez en lire plus à cross_validation documents
26
Catbuilts