web-dev-qa-db-fra.com

Le meilleur moyen de sauvegarder un modèle formé dans PyTorch?

Je cherchais d'autres moyens de sauvegarder un modèle formé dans PyTorch. Jusqu'à présent, j'ai trouvé deux alternatives.

  1. torch.save () pour enregistrer un modèle et torch.load () pour charger un modèle.
  2. model.state_dict () pour enregistrer un modèle formé et model.load_state_dict () pour charger le modèle enregistré.

Je suis tombé sur ceci discussion où l'approche 2 est recommandée par rapport à l'approche 1.

Ma question est, pourquoi la deuxième approche est préférée? Est-ce uniquement parce que les modules torch.nn ont ces deux fonctions et nous sommes encouragés à les utiliser?

123
Wasi Ahmad

J'ai trouvé cette page sur leur dépôt github, je vais simplement coller le contenu ici.


Approche recommandée pour la sauvegarde d'un modèle

Il existe deux approches principales pour la sérialisation et la restauration d'un modèle.

Le premier (recommandé) enregistre et charge uniquement les paramètres du modèle:

torch.save(the_model.state_dict(), PATH)

Puis plus tard:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

La seconde enregistre et charge le modèle entier:

torch.save(the_model, PATH)

Puis plus tard:

the_model = torch.load(PATH)

Cependant, dans ce cas, les données sérialisées sont liées aux classes spécifiques et à la structure de répertoires exacte utilisée, de sorte qu'elles peuvent se rompre de différentes manières lorsqu'elles sont utilisées dans d'autres projets ou après certains refactors sérieux.

143
dontloo

Cela dépend de ce que vous voulez faire.

Cas n ° 1: enregistrez le modèle pour l'utiliser vous-même à des fins d'inférence: vous enregistrez le modèle, vous le restaurez, puis vous changez le modèle en mode d'évaluation. Ceci est fait car vous avez généralement les couches BatchNorm et Dropout qui sont par défaut en mode train à la construction:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Cas n ° 2: Enregistrer le modèle pour reprendre l'entraînement ultérieurement: Si vous devez continuer à former le modèle que vous êtes sur le point d'enregistrer, vous devez enregistrer davantage que le modèle. Vous devez également enregistrer l'état de l'optimiseur, des époques, du score, etc. Vous le feriez comme ceci:

state = {
    'Epoch': Epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Pour reprendre l'entraînement, vous feriez les choses suivantes: state = torch.load(filepath), puis, pour restaurer l'état de chaque objet, comme ceci:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Puisque vous reprenez votre entraînement, NE PAS PAS appeler model.eval() une fois que vous avez restauré les états lors du chargement.

Cas n ° 3: modèle à utiliser par quelqu'un d'autre sans accès à votre code: Dans Tensorflow, vous pouvez créer un fichier .pb qui définit à la fois l'architecture et les poids du modèle. Ceci est très pratique, en particulier lorsque vous utilisez Tensorflow serve. La manière équivalente de faire cela à Pytorch serait:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Cette solution n’est toujours pas à l’abri et pytorch subissant encore de nombreux changements, je ne le recommanderais pas.

90
Jadiel de Armas

La bibliothèque pickle Python implémente des protocoles binaires pour la sérialisation et la désérialisation d'un objet Python.

Lorsque vous import torch (ou lorsque vous utilisez PyTorch), il import pickle et vous n'avez pas besoin d'appeler directement pickle.dump() et pickle.load(), qui sont les méthodes permettant de sauvegarder et de charger l'objet.

En fait, torch.save() et torch.load() encapsuleront pickle.dump() et pickle.load() pour vous.

A state_dict l'autre réponse mentionnée mérite quelques notes supplémentaires.

Qu'est-ce que state_dict avons-nous à l'intérieur de PyTorch? Il y a en fait deux state_dicts.

Le modèle PyTorch est torch.nn.Module a l'appel model.parameters() pour obtenir les paramètres pouvant être appris (w et b). Ces paramètres, une fois définis de manière aléatoire, seront mis à jour au fil du temps. Les paramètres pouvant être appris sont le premier state_dict.

Le second state_dict est l'état de l'optimiseur. Vous vous rappelez que l'optimiseur est utilisé pour améliorer nos paramètres pouvant être appris. Mais l'optimiseur state_dict est corrigé. Rien à apprendre là-dedans.

Étant donné que les objets state_dict sont des dictionnaires Python, ils peuvent être facilement enregistrés, mis à jour, modifiés et restaurés, ce qui ajoute beaucoup de modularité aux modèles PyTorch et aux optimiseurs.

Créons un modèle très simple pour expliquer ceci:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Ce code produira les informations suivantes:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Notez qu'il s'agit d'un modèle minimal. Vous pouvez essayer d’ajouter une pile séquentielle

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Notez que seules les couches avec des paramètres pouvant être appris (couches convolutives, couches linéaires, etc.) et les tampons enregistrés (couches batchnorm) ont des entrées dans le state_dict du modèle.

Les objets non apprenables appartiennent à l’objet optimiseur state_dict qui contient des informations sur l’état de l’optimiseur ainsi que sur les hyperparamètres utilisés.

Le reste de l'histoire est la même; dans la phase d'inférence (c'est une phase où nous utilisons le modèle après l'entraînement) pour prédire; nous prédisons en fonction des paramètres que nous avons appris. Donc, pour l'inférence, il suffit de sauvegarder les paramètres model.state_dict().

torch.save(model.state_dict(), filepath)

Et pour utiliser plus tard model.load_state_dict (torch.load (chemin du fichier)) model.eval ()

Remarque: N'oubliez pas la dernière ligne model.eval() c'est essentiel après le chargement du modèle.

Aussi, n'essayez pas de sauvegarder torch.save(model.parameters(), filepath). La model.parameters() n'est que l'objet générateur.

De l'autre côté, torch.save(model, filepath) enregistre l'objet de modèle lui-même, mais gardez à l'esprit que le modèle n'a pas le state_dict de l'optimiseur. Vérifiez l'autre excellente réponse de @Jadiel de Armas pour enregistrer le dict d'état de l'optimiseur.

1
prosti

ne convention PyTorch courante consiste à enregistrer les modèles avec une extension de fichier .pt ou .pth.

Enregistrer/Charger le modèle entier Save:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Charge:

La classe de modèle doit être définie quelque part

model = torch.load(PATH)
model.eval()
0
harsh