web-dev-qa-db-fra.com

pytorch comment définir .requires_grad False

Je veux congeler une partie de mon modèle. Après les documents officiels:

with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.eval()
    print(linear.weight.requires_grad)

Mais il imprime True au lieu de False. Que dois-je faire si je veux définir le modèle en mode eval?

16
Qian Wang

require_grad = False

Si vous souhaitez geler une partie de votre modèle et entraîner le reste, vous pouvez définir requires_grad Des paramètres à geler sur False.

Par exemple, si vous souhaitez uniquement conserver la partie convolutionnelle de VGG16 fixe:

model = torchvision.models.vgg16(pretrained=True)
for param in model.features.parameters():
    param.requires_grad = False

En commutant les indicateurs requires_grad Sur False, aucun tampon intermédiaire ne sera enregistré, jusqu'à ce que le calcul atteigne un point où l'une des entrées de l'opération nécessite le dégradé.

torch.no_grad ()

Utiliser le gestionnaire de contexte torch.no_grad Est un moyen différent d’atteindre cet objectif: dans le contexte no_grad, Tous les résultats des calculs auront requires_grad=False, Même si les entrées ont requires_grad=True. Notez que vous ne pourrez pas propager le dégradé en arrière-plan avant le no_grad. Par exemple:

x = torch.randn(2, 2)
x.requires_grad = True

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)
x1 = lin0(x)
with torch.no_grad():    
    x2 = lin1(x1)
x3 = lin2(x2)
x3.sum().backward()
print(lin0.weight.grad, lin1.weight.grad, lin2.weight.grad)

les sorties:

(None, None, tensor([[-1.4481, -1.1789],
         [-1.4481, -1.1789]]))

Ici lin1.weight.requires_grad Était Vrai, mais le dégradé n'a pas été calculé car l'optimisation a été effectuée dans le contexte no_grad.

model.eval ()

Si votre objectif n'est pas d'affiner le réglage, mais de définir votre modèle en mode inférence, le moyen le plus pratique consiste à utiliser le gestionnaire de contexte torch.no_grad. Dans ce cas, vous devez également définir votre modèle sur le mode evaluation. Pour ce faire, appelez eval() sur le nn.Module, Par exemple:

model = torchvision.models.vgg16(pretrained=True)
model.eval()

Cette opération définit l'attribut self.training Des couches sur False. En pratique, cela modifiera le comportement d'opérations telles que Dropout ou BatchNorm qui doivent se comporter différemment à temps de formation et de test.

20
iacolippo

Voici le chemin

linear = nn.Linear(1,1)

for param in linear.parameters():
    param.requires_grad = False

with torch.no_grad():
    linear.eval()
    print(linear.weight.requires_grad)

SORTIE: Faux

5
Salih Karagoz

Pour compléter la réponse de @ Salih_Karagoz, vous avez également le contexte torch.set_grad_enabled() (documentation supplémentaire ici ), qui peut facilement être utilisé pour passer d'un mode train à l'autre. :

linear = nn.Linear(1,1)

is_train = False
with torch.set_grad_enabled(is_train):
    linear.eval()
    print(linear.weight.requires_grad)
3
benjaminplanche

Agréable. L'astuce consiste à vérifier que lorsque vous définissez une couche linéaire, les paramètres auront par défaut requires_grad=True, Car nous aimerions apprendre, n'est-ce pas?

l = nn.Linear(1, 1)
p = l.parameters()
for _ in p:
    print (_)

# Parameter containing:
# tensor([[-0.3258]], requires_grad=True)
# Parameter containing:
# tensor([0.6040], requires_grad=True)    

L'autre construction,

with torch.no_grad():

Cela signifie que vous ne pouvez pas apprendre ici.

Ainsi, votre code montre simplement que vous êtes capable d'apprendre, même si vous êtes dans torch.no_grad() où l'apprentissage est interdit.

with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.eval()
    print(linear.weight.requires_grad) #true

Si vous avez vraiment l'intention de désactiver requires_grad Pour le paramètre weight, vous pouvez le faire aussi avec:

linear.weight.requires_grad_(False)

ou

linear.weight.requires_grad = False

Donc, votre code peut devenir comme ceci:

with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.weight.requires_grad_(False)
    linear.eval()
    print(linear.weight.requires_grad)

Si vous envisagez de passer à require_grad pour tous les paramètres d'un module:

l = nn.Linear(1, 1)
for _ in l.parameters():
    _.requires_grad_(False)
    print(_)
0
prosti

Ceci tutoriel peut aider.

En quelques mots, je pense qu'un bon moyen pour cette question pourrait être:

linear = nn.Linear(1,1)

for param in linear.parameters():
    param.requires_grad = False

linear.eval()
print(linear.weight.requires_grad)

0
Meiqi