Après avoir lu les docs , j’ai sauvegardé un modèle dans TensorFlow
, voici mon code de démonstration:
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)
mais après cela, j'ai trouvé il y a 3 fichiers
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
Et je ne peux pas restaurer le modèle en restaurant le fichier model.ckpt
, car ce type de fichier n'existe pas. Voici mon code
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
Alors, pourquoi il y a 3 fichiers?
Essaye ça:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
saver.restore(sess, "/tmp/model.ckpt")
La méthode de sauvegarde TensorFlow enregistre trois types de fichiers car elle stocke les structure graphique séparément de la valeurs variables. Le fichier .meta
décrit la structure du graphe enregistrée. Vous devez donc l'importer avant de restaurer le point de contrôle (sinon, il ne sait pas à quelles variables correspondent les valeurs de point de contrôle enregistrées).
Alternativement, vous pouvez faire ceci:
# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Now load the checkpoint variable values
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, "/tmp/model.ckpt")
Même s'il n'y a pas de fichier nommé model.ckpt
, vous vous référez quand même au point de contrôle enregistré sous ce nom lors de la restauration. À partir du saver.py
code source :
Les utilisateurs doivent uniquement interagir avec le préfixe spécifié par l'utilisateur ... à la place de tout chemin physique.
meta file : décrit la structure de graphe enregistrée, comprend GraphDef, SaverDef, etc. puis appliquez tf.train.import_meta_graph('/tmp/model.ckpt.meta')
, restaurera Saver
et Graph
.
fichier index : c'est une table chaîne-chaîne immuable (tensorflow :: table :: Table). Chaque clé est le nom d'un tenseur et sa valeur est un BundleEntryProto sérialisé. Chaque BundleEntryProto décrit les métadonnées d'un tenseur: lequel des fichiers "données" contient le contenu d'un tenseur, le décalage dans ce fichier, la somme de contrôle, certaines données auxiliaires, etc.
fichier de données : il s'agit de la collection TensorBundle, enregistrez les valeurs de toutes les variables.
Je suis en train de restaurer des implémentations Word formées à partir de Word2Vec tutoriel.
Si vous avez créé plusieurs points de contrôle:
par exemple. les fichiers créés ressemblent à ceci
model.ckpt-55695.data-00000-of-00001
model.ckpt-55695.index
model.ckpt-55695.meta
essaye ça
def restore_session(self, session):
saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
saver.restore(session, './tmp/model.ckpt-55695')
en appelant restore_session ():
def test_Word2vec():
opts = Options()
with tf.Graph().as_default(), tf.Session() as session:
with tf.device("/cpu:0"):
model = Word2Vec(opts, session)
model.restore_session(session)
model.get_embedding("assistance")
Si vous avez formé un CNN avec des abandons, par exemple, vous pourriez faire ceci:
def predict(image, model_name):
"""
image -> single image, (width, height, channels)
model_name -> model file that was saved without any extensions
"""
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./' + model_name + '.meta')
saver.restore(sess, './' + model_name)
# Substitute 'logits' with your model
prediction = tf.argmax(logits, 1)
# 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})