Comment initialiser les poids et les biais (par exemple, avec l'initialisation He ou Xavier) dans un réseau dans PyTorch?
Pour initialiser les poids d'une seule couche, utilisez une fonction de torch.nn.init
. Par exemple:
conv1 = torch.nn.Conv2d(...)
torch.nn.init.xavier_uniform(conv1.weight)
Alternativement, vous pouvez modifier les paramètres en écrivant dans conv1.weight.data
(qui est un torch.Tensor
). Exemple:
conv1.weight.data.fill_(0.01)
Il en va de même pour les biais:
conv1.bias.data.fill_(0.01)
nn.Sequential
ou personnalisé nn.Module
Passez une fonction d’initialisation à torch.nn.Module.apply
. Il initialisera les poids dans le nn.Module
entier de manière récursive.
apply (fn _): Applique
fn
de manière récursive à chaque sous-module (renvoyé par.children()
) ainsi qu'à self. Une utilisation typique inclut l’initialisation des paramètres d’un modèle (voir aussi torch-nn-init).
Exemple:
def init_weights(m):
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform(m.weight)
m.bias.data.fill_(0.01)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
import torch.nn as nn
# a simple network
Rand_net = nn.Sequential(nn.Linear(in_features, h_size),
nn.BatchNorm1d(h_size),
nn.ReLU(),
nn.Linear(h_size, h_size),
nn.BatchNorm1d(h_size),
nn.ReLU(),
nn.Linear(h_size, 1),
nn.ReLU())
# initialization function, first checks the module type,
# then applies the desired changes to the weights
def init_normal(m):
if type(m) == nn.Linear:
nn.init.uniform_(m.weight)
# use the modules apply function to recursively apply the initialization
Rand_net.apply(init_normal)
Désolé d'être si tard, j'espère que ma réponse aidera.
Pour initialiser des poids avec un normal distribution
, utilisez:
torch.nn.init.normal_(tensor, mean=0, std=1)
Ou pour utiliser un constant distribution
écrire:
torch.nn.init.constant_(tensor, value)
Ou pour utiliser un uniform distribution
écrire:
torch.nn.init.uniform_(tensor, a=0, b=1) # a: lower_bound, b: upper_bound
Vous pouvez vérifier d’autres méthodes d’initialisation des tenseurs ici