web-dev-qa-db-fra.com

Pytorch ne prend pas en charge le vecteur unique?

Je suis très confus par la façon dont Pytorch traite les vecteurs à un seul point. Dans ce tutoriel , le réseau de neurones générera un vecteur un-chaud comme sortie. Pour autant que je comprends, la structure schématique du réseau de neurones dans le tutoriel devrait être comme:

enter image description here

Cependant, les labels ne sont pas au format vectoriel unique. J'obtiens le size suivant

print(labels.size())
print(outputs.size())

output>>> torch.Size([4]) 
output>>> torch.Size([4, 10])

Par miracle, je passe les outputs et labels à criterion=CrossEntropyLoss(), il n'y a aucune erreur du tout.

loss = criterion(outputs, labels) # How come it has no error?

Mon hypothèse:

Peut-être que pytorch convertit automatiquement le labels en une forme vectorielle unique. Donc, j'essaie de convertir les étiquettes en un vecteur unique avant de le passer à la fonction de perte.

def to_one_hot_vector(num_class, label):
    b = np.zeros((label.shape[0], num_class))
    b[np.arange(label.shape[0]), label] = 1

    return b

labels_one_hot = to_one_hot_vector(10,labels)
labels_one_hot = torch.Tensor(labels_one_hot)
labels_one_hot = labels_one_hot.type(torch.LongTensor)

loss = criterion(outputs, labels_one_hot) # Now it gives me error

Cependant, j'ai eu l'erreur suivante

RuntimeError: multi-cible non prise en charge dans /opt/pytorch/pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15

Ainsi, les vecteurs uniques ne sont pas pris en charge dans Pytorch? Comment Pytorch calcule-t-il le cross entropy Pour les deux tenseurs outputs = [1,0,0],[0,0,1] Et labels = [0,2]? Pour le moment, cela n'a aucun sens pour moi.

9
Raven Cheuk

Je suis confus au sujet de votre confusion. PyTorch indique clairement dans sa documentation pour CrossEntropyLoss que

Ce critère attend un indice de classe (0 à C-1) comme cible pour chaque valeur d'un tenseur 1D de mini-lot de taille

En d'autres termes, il a votre to_one_hot_vector fonction construite conceptuellement dans CEL et n'expose pas l'API one-hot. Notez que les vecteurs à un seul point sont inefficaces en mémoire par rapport au stockage des étiquettes de classe.

Si vous disposez de vecteurs uniques et devez aller au format d'étiquettes de classe (par exemple pour être compatible avec CEL), vous pouvez utiliser argmax comme ci-dessous:

import torch

labels = torch.tensor([1, 2, 3, 5])
one_hot = torch.zeros(4, 6)
one_hot[torch.arange(4), labels] = 1

reverted = torch.argmax(one_hot, dim=1)
assert (labels == reverted).all().item()
11
Jatentaki

Ce code vous aidera à la fois n encodage à chaud et encodage à chaud multiple:

import torch
batch_size=10
n_classes=5
target = torch.randint(high=5, size=(1,10)) # set size (2,10) for MHE
print(target)
y = torch.zeros(batch_size, n_classes)
y[range(y.shape[0]), target]=1
y

La sortie dans OHE

tensor([[4, 3, 2, 2, 4, 1, 1, 1, 4, 2]])

tensor([[0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.]])

La sortie pour MHE quand je mets target = torch.randint(high=5, size=(2,10))

tensor([[3, 2, 4, 4, 2, 4, 0, 4, 4, 1],
        [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]])

tensor([[0., 0., 0., 1., 1.],
        [0., 1., 1., 0., 0.],
        [0., 1., 0., 0., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 1.],
        [1., 0., 0., 0., 1.],
        [0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 1., 0.]])

Si vous avez besoin de plusieurs OHE:

torch.nn.functional.one_hot(target)

tensor([[[0, 0, 0, 1, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0]],

        [[0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0],
         [0, 1, 0, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 1, 0]]])
4
prosti