web-dev-qa-db-fra.com

RuntimeError: objet attendu de type scalaire long mais de type scalaire float pour l'argument # 2 'mat2' Comment le réparer?


import torch.nn as nn 
import torch 
import torch.optim as optim
import itertools

class net1(nn.Module):
    def __init__(self):
        super(net1,self).__init__()

        self.pipe = nn.Sequential(
            nn.Linear(10,10),
            nn.ReLU()
        )

    def forward(self,x):
        return self.pipe(x.long())

class net2(nn.Module):
    def __init__(self):
        super(net2,self).__init__()

        self.pipe = nn.Sequential(
            nn.Linear(10,20),
            nn.ReLU(),
            nn.Linear(20,10)
        )

    def forward(self,x):
        return self.pipe(x.long())



netFIRST = net1()
netSECOND = net2()

learning_rate = 0.001

opt = optim.Adam(itertools.chain(netFIRST.parameters(),netSECOND.parameters()), lr=learning_rate)

epochs = 1000

x = torch.tensor([1,2,3,4,5,6,7,8,9,10],dtype=torch.long)
y = torch.tensor([10,9,8,7,6,5,4,3,2,1],dtype=torch.long)


for Epoch in range(epochs):
    opt.zero_grad()

    prediction = netSECOND(netFIRST(x))
    loss = (y.long() - prediction)**2
    loss.backward()

    print(loss)
    print(prediction)
    opt.step()

erreur:

ligne 49, dans la prédiction = Netsecond (NetFirst (X))

ligne 1371, linéaire; sortie = entrée.matmul (poids.t ())

RuntimeError: objet attendu de type scalaire long mais de type scalaire float pour argument # 2 'mat2'

Je ne vois pas vraiment ce que je fais mal. J'ai essayé de tout transformer dans un Long de toutes les tâches possibles. Je n'ai vraiment pas la voie à taper fonctionne pour Pytorch. La dernière fois que j'ai essayé quelque chose avec une seule couche et cela m'a forcé à utiliser le type int. Quelqu'un pourrait-il expliquer comment la saisie est établie dans Pytorch et comment prévenir et corriger des erreurs comme celle-ci ?? Beaucoup, je veux dire énormément de remerciements à l'avance, ce problème me dérange vraiment et je ne peux pas sembler le réparer, peu importe ce que j'essaie.

4
hal9000

Les poids sont des flotteurs, les intrants sont longs. Ceci n'est pas autorisé. En fait, je ne pense pas que la flambeau soutient autre chose que des flotteurs dans des réseaux de neurones.

Si vous supprimez tous appels vers longtemps et définissez votre contribution en tant que flotteurs, cela fonctionnera (cela fonctionnera-t-il).

(Vous obtiendrez ensuite une autre erreur non liée: vous devez résumer votre perte)

2
Bzazz