web-dev-qa-db-fra.com

Mode TensorFlow Eager: Comment restaurer un modèle à partir d'un point de contrôle?

J'ai formé un modèle CNN en mode passionné TensorFlow. Maintenant, j'essaie de restaurer le modèle formé à partir d'un fichier de point de contrôle, mais sans succès.

Tous les exemples (illustrés ci-dessous) que j'ai trouvés parlent de la restauration d'un point de contrôle dans une session. Mais ce dont j'ai besoin, c’est de restaurer le modèle en mode rapide, c’est-à-dire sans créer de session. 

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Fondamentalement, ce dont j'ai besoin, c'est quelque chose comme:

tfe.enable_eager_execution()
model = tfe.restore('model.ckpt')
model.predict(...)

et puis je peux utiliser le modèle pour faire des prédictions.

Puis-je avoir une aide s'il vous plait?

Mettre à jour

L’exemple de code est disponible à l’adresse: mnist eager mode demo

J'ai essayé de suivre les étapes de la réponse de @Jay Shah et cela a presque fonctionné, mais le modèle restauré ne contient aucune variable.

tfe.save_network_checkpoint(model,'./test/my_model.ckpt')

Out[58]:
'./test/my_model.ckpt-1720'

model2 = MNISTModel()
tfe.restore_network_checkpoint(model2,'./test/my_model.ckpt-1720')
model2.variables

Out[72]:
[]

Le modèle d'origine contient de nombreuses variables:

model.variables

[<tf.Variable 'mnist_model_1/conv2d/kernel:0' shape=(5, 5, 1, 32) dtype=float32, numpy=
 array([[[[ -8.25184360e-02,   6.77833706e-03,   6.97569922e-02,...
7
Allen

Eager Execution est toujours une nouvelle fonctionnalité de TensorFlow et n'était pas inclus dans la dernière version. Par conséquent, toutes les fonctionnalités ne sont pas prises en charge. Heureusement, le chargement d'un modèle à partir d'un point de contrôle enregistré l'est. 

Vous devrez utiliser la classe tfe.Saver (qui est un wrapper fin sur la classe tf.train.Saver), et votre code devrait ressembler à ceci:

saver = tfe.Saver([x, y])
saver.restore('/tmp/ckpt')

Où [x, y] représente la liste des variables et/ou des modèles que vous souhaitez restaurer. Cela doit correspondre précisément aux variables transmises lors de la création initiale de l'économiseur ayant créé le point de contrôle.

Vous trouverez plus de détails, y compris un exemple de code, ici , ainsi que les détails de l'API de l'économiseur ici .

7
mr_snuffles

J'aime partager TFLearn bibliothèque qui est Deep learning library featuring a higher-level API for TensorFlow. Avec l'aide de cette bibliothèque, vous pouvez facilement save and restore un modèle.

Sauvegarder un modèle

model = tflearn.DNN(net) #Here 'net' is your designed network model. 
#This is a sample example for training the model
model.fit(train_x, train_y, n_Epoch=10, validation_set=(test_x, test_y), batch_size=10, show_metric=True)
model.save("model_name.ckpt")

Restaurer un modèle

model = tflearn.DNN(net)
model.load("model_name.ckpt")

Pour plus d'exemples de tflearn, vous pouvez consulter certains sites tels que ...

1
R.A.Munna

Enregistrement de variables avec tfe.Saver().save():

for Epoch in range(epochs):
    train_and_optimize()
    all_variables = model.variables + optimizer.variables()

    # save the varibles 
    tfe.Saver(all_variables).save(checkpoint_prefix)

Et puis rechargez les variables sauvegardées avec tfe.Saver().restore():

tfe.Saver((model.variables + optimizer.variables())).restore(checkpoint_prefix)

Ensuite, le modèle est chargé avec les variables enregistrées et il n'est pas nécessaire de créer une nouvelle variable, comme dans la réponse de @Stefan Falk.

0
ywfu