web-dev-qa-db-fra.com

ValueError: n_splits = 10 ne peut pas être supérieur au nombre de membres dans chaque classe

J'essaie d'exécuter le code suivant:

from sklearn.model_selection import StratifiedKFold 
X = ["hey", "join now", "hello", "join today", "join us now", "not today", "join this trial", " hey hey", " no", "hola", "bye", "join today", "no","join join"]
y = ["n", "r", "n", "r", "r", "n", "n", "n", "n", "r", "n", "n", "n", "r"]

skf = StratifiedKFold(n_splits=10)

for train, test in skf.split(X,y):  
    print("%s %s" % (train,test))

Mais j'obtiens l'erreur suivante:

ValueError: n_splits=10 cannot be greater than the number of members in each class.

J'ai regardé ici erreur scikit-learn: la classe la moins peuplée de y n'a qu'un seul membre mais je ne suis toujours pas vraiment sûr de ce qui ne va pas avec mon code.

Mes listes ont toutes deux une longueur de 14 print(len(X))print(len(y)).

Une partie de ma confusion est que je ne suis pas sûr de la définition d'un members et de ce qu'est un class dans ce contexte.

Questions: Comment puis-je corriger l'erreur? Qu'est-ce qu'un membre? Qu'est-ce qu'une classe? (dans ce contexte)

7
SFC

La stratification signifie garder le rapport de chaque classe dans chaque pli. Donc, si votre jeu de données d'origine a 3 classes dans le rapport de 60%, 20% et 20%, la stratification essaiera de garder ce rapport dans chaque pli.

Dans ton cas,

X = ["hey", "join now", "hello", "join today", "join us now", "not today",
     "join this trial", " hey hey", " no", "hola", "bye", "join today", 
     "no","join join"]
y = ["n", "r", "n", "r", "r", "n", "n", "n", "n", "y", "n", "n", "n", "y"]

Vous avez un total de 14 échantillons (membres) avec la distribution:

class    number of members         percentage
 'n'        9                        64
 'r'        3                        22
 'y'        2                        14

Donc StratifiedKFold essaiera de garder ce ratio dans chaque pli. Vous avez maintenant spécifié 10 plis (n_splits). Cela signifie donc en un seul pli, pour que la classe "y" maintienne le rapport, au moins 2/10 = 0,2 membres. Mais nous ne pouvons pas donner moins de 1 membre (échantillon), c'est pourquoi cela génère une erreur.

Si au lieu de n_splits=10, vous avez défini n_splits=2, alors cela aurait fonctionné, car le nombre de membres pour 'y' sera de 2/2 = 1. Pour n_splits = 10 pour fonctionner correctement, vous devez avoir au moins 10 échantillons pour chacune de vos classes.

11
Vivek Kumar