web-dev-qa-db-fra.com

Pytorch: Impossible d'appeler numpy () sur une variable qui nécessite un grad. Utilisez plutôt var.detach (). Numpy ()

J'ai une erreur dans mon code qui ne se corrige pas de quelque façon que ce soit.

L'erreur est simple, je renvoie une valeur:

torch.exp(-LL_total/T_total)

et récupérez l'erreur plus tard dans le pipeline:

RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

Des solutions telles que cpu().detach().numpy() donnent la même erreur.

Comment pourrais-je le réparer? Merci.

10
tstseby

Erreur reproduite

import torch

tensor1 = torch.tensor([1.0,2.0],requires_grad=True)

print(tensor1)
print(type(tensor1))

tensor1 = tensor1.numpy()

print(tensor1)
print(type(tensor1))

ce qui conduit exactement à la même erreur pour la ligne tensor1 = tensor1.numpy():

tensor([1., 2.], requires_grad=True)
<class 'torch.Tensor'>
Traceback (most recent call last):
  File "/home/badScript.py", line 8, in <module>
    tensor1 = tensor1.numpy()
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

Process finished with exit code 1

Solution générique

cela vous a été suggéré dans votre message d'erreur, remplacez simplement var par le nom de votre variable

import torch

tensor1 = torch.tensor([1.0,2.0],requires_grad=True)

print(tensor1)
print(type(tensor1))

tensor1 = tensor1.detach().numpy()

print(tensor1)
print(type(tensor1))

qui retourne comme prévu

tensor([1., 2.], requires_grad=True)
<class 'torch.Tensor'>
[1. 2.]
<class 'numpy.ndarray'>

Process finished with exit code 0

Quelques explications

Vous devez convertir votre tenseur en un autre tenseur qui ne nécessite pas de gradient en plus de sa définition de valeur réelle. Cet autre tenseur peut être converti en un tableau numpy. Cf. ce post de discuter.pytorch . (Je pense, plus précisément, qu'il faut le faire pour extraire le tenseur réel de son enveloppe pytorch Variable, cf. cet autre article de discussion.pytorch ).

8
Blupon

J'ai eu le même message d'erreur mais c'était pour dessiner un nuage de points sur matplotlib.

Il y a 2 façons dont je pourrais sortir de ce message d'erreur:

  1. importer le fastai.basics bibliothèque avec: from fastai.basics import *

  2. Si vous utilisez uniquement la bibliothèque torch, n'oubliez pas de retirer la requires_grad avec :

    with torch.no_grad():
        (your code)
    
8
Rickantonais