J'essaie d'exécuter le code suivant:
from sklearn.model_selection import StratifiedKFold
X = ["hey", "join now", "hello", "join today", "join us now", "not today", "join this trial", " hey hey", " no", "hola", "bye", "join today", "no","join join"]
y = ["n", "r", "n", "r", "r", "n", "n", "n", "n", "r", "n", "n", "n", "r"]
skf = StratifiedKFold(n_splits=10)
for train, test in skf.split(X,y):
print("%s %s" % (train,test))
Mais j'obtiens l'erreur suivante:
ValueError: n_splits=10 cannot be greater than the number of members in each class.
J'ai regardé ici erreur scikit-learn: la classe la moins peuplée de y n'a qu'un seul membre mais je ne suis toujours pas vraiment sûr de ce qui ne va pas avec mon code.
Mes listes ont toutes deux une longueur de 14 print(len(X))
print(len(y))
.
Une partie de ma confusion est que je ne suis pas sûr de la définition d'un members
et de ce qu'est un class
dans ce contexte.
Questions: Comment puis-je corriger l'erreur? Qu'est-ce qu'un membre? Qu'est-ce qu'une classe? (dans ce contexte)
La stratification signifie garder le rapport de chaque classe dans chaque pli. Donc, si votre jeu de données d'origine a 3 classes dans le rapport de 60%, 20% et 20%, la stratification essaiera de garder ce rapport dans chaque pli.
Dans ton cas,
X = ["hey", "join now", "hello", "join today", "join us now", "not today",
"join this trial", " hey hey", " no", "hola", "bye", "join today",
"no","join join"]
y = ["n", "r", "n", "r", "r", "n", "n", "n", "n", "y", "n", "n", "n", "y"]
Vous avez un total de 14 échantillons (membres) avec la distribution:
class number of members percentage
'n' 9 64
'r' 3 22
'y' 2 14
Donc StratifiedKFold essaiera de garder ce ratio dans chaque pli. Vous avez maintenant spécifié 10 plis (n_splits). Cela signifie donc en un seul pli, pour que la classe "y" maintienne le rapport, au moins 2/10 = 0,2 membres. Mais nous ne pouvons pas donner moins de 1 membre (échantillon), c'est pourquoi cela génère une erreur.
Si au lieu de n_splits=10
, vous avez défini n_splits=2
, alors cela aurait fonctionné, car le nombre de membres pour 'y' sera de 2/2 = 1. Pour n_splits = 10
pour fonctionner correctement, vous devez avoir au moins 10 échantillons pour chacune de vos classes.