web-dev-qa-db-fra.com

Pytorch softmax: Quelle dimension utiliser?

La fonction torch.nn.functional.softmax prend deux paramètres: input et dim. Selon sa documentation, l'opération softmax est appliquée à toutes les tranches de input le long de la dim spécifiée et les redimensionnera de sorte que les éléments se trouvent dans la plage (0, 1) et totalisent 1. 

Que l'entrée soit:

input = torch.randn((3, 4, 5, 6))

Supposons que je veuille ce qui suit, pour que chaque entrée de ce tableau soit 1:

sum = torch.sum(input, dim = 3) # sum's size is (3, 4, 5, 1)

Comment dois-je appliquer softmax?

softmax(input, dim = 0) # Way Number 0
softmax(input, dim = 1) # Way Number 1
softmax(input, dim = 2) # Way Number 2
softmax(input, dim = 3) # Way Number 3

Mon intuition me dit que c'est le dernier, mais je ne suis pas sûr. L'anglais n'est pas ma langue maternelle et l'utilisation de la Parole along m'a semblé confuse à cause de cela.

Je ne suis pas très clair sur ce que "long" signifie, alors je vais utiliser un exemple qui pourrait clarifier les choses. Supposons que nous ayons un tenseur de taille (s1, s2, s3, s4) et que cela se produise

7
Jadiel de Armas

Le moyen le plus simple que je puisse vous faire comprendre est: supposons qu'un tenseur de forme (s1, s2, s3, s4) soit attribué et que vous souhaitiez que la somme de toutes les entrées du dernier axe soit égale à 1.

sum = torch.sum(input, dim = 3) # input is of shape (s1, s2, s3, s4)

Ensuite, vous devez appeler le softmax comme:

softmax(input, dim = 3)

Pour comprendre facilement, vous pouvez considérer un tenseur 4d de forme (s1, s2, s3, s4) comme un tenseur 2d ou une matrice de forme (s1*s2*s3, s4). Maintenant, si vous voulez que la matrice contienne des valeurs dans chaque ligne (axe = 0) ou colonne (axe = 1) qui totalisent 1, vous pouvez simplement appeler la fonction softmax sur le tenseur 2d comme suit:

softmax(input, dim = 0) # normalizes values along axis 0
softmax(input, dim = 1) # normalizes values along axis 1

Vous pouvez voir l'exemple cité par Steven dans son answer .

5
Wasi Ahmad

 enter image description here

La réponse de Steven ci-dessus n'est pas correcte. Voir l'instantané ci-dessous. C'est en fait l'inverse.

9
sww