Que sont y_true et y_pred lors de la création d'une métrique personnalisée dans Keras?
Je souhaite implémenter ma métrique personnalisée dans Keras. Selon la documentation, ma métrique personnalisée doit être définie comme une fonction qui prend en entrée deux tenseurs, y_pred
et y_true
, et renvoie une valeur de tenseur unique.
Cependant, je ne comprends pas exactement ce que contiendront ces tenseurs y_pred
et y_true
lorsque l'optimisation est en cours d'exécution. Est-ce juste un point de données? Est-ce le lot entier? Toute l'époque (probablement pas)? Existe-t-il un moyen d'obtenir les formes de ces tenseurs?
Quelqu'un peut-il indiquer un endroit de confiance où je peux obtenir cette information? Toute aide serait appréciée. Pas sûr que ce soit pertinent, mais j'utilise le backend TensorFlow.
Ce que j’ai essayé jusqu’à présent pour répondre à cette question:
- Vérification de la documentation des métriques Keras (aucune explication sur la nature de ces tenseurs).
- Vérifier le code source pour les métriques de Keras et essayer de comprendre ces tenseurs en examinant l'implémentation de Keras pour d'autres métriques (Cela semble suggérer que
y_true
ety_pred
possèdent les étiquettes d'un lot entier, mais je ne suis pas sûr ). - Lecture de ces questions de stackoverflow: 1 , 2 , 3 et d’autres (aucune ne répond à ma question, la plupart sont centrées sur le PO ne comprenant pas clairement la différence entre un tenseur et les valeurs calculées avec ce tenseur pendant la séance).
- Imprimer les valeurs de
y_true
ety_pred
lors de l'optimisation, en définissant une métrique comme celle-ci:
def test_metric(y_true, y_pred):
y_true = K.print_tensor(y_true)
y_pred = K.print_tensor(y_pred)
return y_true - y_pred
(Malheureusement, ils n'impriment rien pendant l'optimisation).
y_true et y_pred
Le tenseur y_true
correspond aux vraies données (ou cible, vérité du terrain) que vous transmettez à la méthode d'ajustement.
C'est une conversion du numpy array y_train
en un tenseur.
Le tenseur y_pred
correspond aux données prévues (calculées, sorties) par votre modèle.
y_true
et y_pred
ont toujours la même forme, toujours.
La forme de y_true
Il contient un lot entier. Sa première dimension est toujours la taille du lot et doit exister, même si le lot ne contient qu'un seul élément.
Il existe deux manières très simples de trouver la forme de y_true
:
- vérifiez vos données vraies/cibles:
print(Y_train.shape)
- vérifier votre
model.summary()
et voir la dernière sortie
Mais sa première dimension sera la taille du lot.
Ainsi, si votre dernier calque génère (None, 1)
, la forme de y_true
est (batch, 1)
. Si la dernière couche génère (None, 200,200, 3)
, alors y_true
sera (batch, 200,200,3)
.
Mesures personnalisées et fonctions de perte
Malheureusement, l’impression de métriques personnalisées ne révélera pas leur contenu .. Vous pouvez voir leurs formes avec print(K.int_shape(y_pred))
, par exemple.
Rappelez-vous que ces bibliothèques "compilent d'abord un graphique", puis "l'exécutent avec des données". Lorsque vous définissez votre perte, vous êtes dans la phase de compilation et demander des données nécessite l'exécution du modèle.
Mais même si le résultat de votre métrique est multidimensionnel, keras trouvera automatiquement le moyen de générer un seul scalaire pour cette métrique. (Vous ne savez pas quelle est l'opération, mais très probablement une K.mean()
cachée sous la table).
Sources. Une fois que vous vous êtes habitué aux keras, cette compréhension devient naturelle en lisant simplement cette partie:
y_true: True labels. Tenseur Theano/TensorFlow.
y_pred: Prédictions. Tenseur Theano/TensorFlow de la même forme que y_true.
Les étiquettes vraies signifient les données vraies/cibles. Les étiquettes sont un mot mal choisi ici, ce ne sont vraiment que des "étiquettes" dans les modèles de classification.
Les prévisions désignent les résultats de votre modèle.
y_true
est la valeur vraie (étiquettes). et y_pred
sont les valeurs prédites par votre modèle NN.
La taille (forme) des tenseurs est déterminée par la taille des lots (nb_batches).