web-dev-qa-db-fra.com

Comprendre torch.nn.Parameter

Je suis nouveau sur Pytorch et j'ai du mal à comprendre comment fonctionne torch.nn.Parameter().

J'ai parcouru la documentation dans https://pytorch.org/docs/stable/nn.html mais je n'y trouve que très peu de sens.

Puis-je avoir une aide s'il vous plait?

L'extrait de code sur lequel je travaille:

def __init__(self, weight):
    super(Net, self).__init__()
    # initializes the weights of the convolutional layer to be the weights of the 4 defined filters
    k_height, k_width = weight.shape[2:]
    # assumes there are 4 grayscale filters
    self.conv = nn.Conv2d(1, 4, kernel_size=(k_height, k_width), bias=False)
    self.conv.weight = torch.nn.Parameter(weight)
22
Vineet Pandey

Je vais le décomposer pour vous. Comme vous le savez peut-être, les tenseurs sont des matrices multidimensionnelles. Le paramètre, dans sa forme brute, est un tenseur, c'est-à-dire une matrice multidimensionnelle. Il sous-classe la classe Variable.

La différence entre une variable et un paramètre intervient lorsqu'elle est associée à un module. Lorsqu'un paramètre est associé à un module en tant qu'attribut de modèle, il est automatiquement ajouté à la liste des paramètres et est accessible à l'aide de l'itérateur 'parameters'.

Initialement dans Torch, une variable (pouvant par exemple être un état intermédiaire) serait également ajoutée en tant que paramètre du modèle lors de l'affectation. Par la suite, des cas d'utilisation ont été identifiés, indiquant la nécessité de mettre en cache les variables au lieu de les ajouter à la liste de paramètres.

Un de ces cas, comme mentionné dans la documentation, est celui de RNN, dans lequel vous devez enregistrer le dernier état masqué pour ne pas avoir à le passer encore et encore. La nécessité de mettre en cache une variable au lieu de l’enregistrer automatiquement en tant que paramètre du modèle explique pourquoi nous disposons d’un moyen explicite d’enregistrer des paramètres dans notre modèle, à savoir la classe nn.Parameter.

Par exemple, exécutez le code suivant -

import torch
import torch.nn as nn
from torch.optim import Adam

class NN_Network(nn.Module):
    def __init__(self,in_dim,hid,out_dim):
        super(NN_Network, self).__init__()
        self.linear1 = nn.Linear(in_dim,hid)
        self.linear2 = nn.Linear(hid,out_dim)
        self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear1.bias = torch.nn.Parameter(torch.ones(hid))
        self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear2.bias = torch.nn.Parameter(torch.ones(hid))

    def forward(self, input_array):
        h = self.linear1(input_array)
        y_pred = self.linear2(h)
        return y_pred

in_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)

Maintenant, vérifiez la liste des paramètres associés à ce modèle -

for param in net.parameters():
    print(type(param.data), param.size())

""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""

Ou essayez,

list(net.parameters())

Ceci peut facilement être transmis à votre optimiseur -

opt = Adam(net.parameters(), learning_rate=0.001)

Notez également que require_grad est défini par défaut pour les paramètres.

49
Astha Sharma

Les dernières versions de PyTorch n’ont que Tensors, il est apparu que le concept de Variable était obsolète.

Paramètres ne sont que des tenseurs limités au module dans lequel ils ont été définis (dans la méthode du constructeur de module __init__).

Ils apparaîtront à l'intérieur de module.parameters(). Cela est pratique lorsque vous construisez vos modules personnalisés, qui apprennent grâce à ces paramètres de descente de gradient.

Tout ce qui est vrai pour les tenseurs PyTorch est vrai pour les paramètres, car ce sont des tenseurs.

De plus, si le module passe en GPU, les paramètres le sont également. Si le module est enregistré, les paramètres seront également enregistrés.

Il existe un concept similaire pour modéliser des paramètres appelés tampons .

Ceux-ci sont nommés tenseurs à l'intérieur du module, mais ces tenseurs ne sont pas censés apprendre via la descente de gradient. Vous pouvez plutôt penser qu'il s'agit de variables. Vous mettrez à jour vos tampons nommés à l'intérieur du module forward() comme vous le souhaitez.

Pour les tampons, il est également vrai qu’ils iront au GPU avec le module et qu’ils seront sauvegardés avec le module.

enter image description here

3
prosti