web-dev-qa-db-fra.com

PyTorch - Comment obtenir le taux d'apprentissage pendant la formation?

Pendant la formation, j'aimerais connaître la valeur de learning_rate. Que devrais-je faire?

C'est mon code, comme ceci:

my_optimizer = torch.optim.SGD(my_model.parameters(), 
                               lr=0.001, 
                               momentum=0.99, 
                               weight_decay=2e-3)

Je vous remercie.

Pour un seul groupe de paramètres comme dans l'exemple que vous avez donné, vous pouvez utiliser cette fonction et l'appeler pendant la formation pour obtenir le taux d'apprentissage actuel:

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
12
blue-phoenox