web-dev-qa-db-fra.com

Comment corriger RuntimeError "Objet attendu de type scalaire Float mais obtenu type scalaire Double pour argument"?

J'essaie de former un classificateur via PyTorch. Cependant, je rencontre des problèmes de formation lorsque j'alimente le modèle avec des données de formation. J'obtiens cette erreur sur y_pred = model(X_trainTensor):

RuntimeError: objet attendu de type scalaire Float mais obtenu type scalaire Double pour l'argument # 4 'mat1'

Voici les éléments clés de mon code:

# Hyper-parameters 
D_in = 47  # there are 47 parameters I investigate
H = 33
D_out = 2  # output should be either 1 or 0
# Format and load the data
y = np.array( df['target'] )
X = np.array( df.drop(columns = ['target'], axis = 1) )
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size = 0.8)  # split training/test data

X_trainTensor = torch.from_numpy(X_train) # convert to tensors
y_trainTensor = torch.from_numpy(y_train)
X_testTensor = torch.from_numpy(X_test)
y_testTensor = torch.from_numpy(y_test)
# Define the model
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
    nn.LogSoftmax(dim = 1)
)
# Define the loss function
loss_fn = torch.nn.NLLLoss() 
for i in range(50):
    y_pred = model(X_trainTensor)
    loss = loss_fn(y_pred, y_trainTensor)
    model.zero_grad()
    loss.backward()
    with torch.no_grad():       
        for param in model.parameters():
            param -= learning_rate * param.grad
22
Shawn Zhang

Le problème peut être résolu en définissant le type de données d'entrée sur Double, c'est-à-dire torch.float32

J'espère que le problème est survenu car votre type de données est torch.float16

0
Rohit Choudhary