Je souhaite voir les variables enregistrées dans un point de contrôle TensorFlow avec leurs valeurs. Comment trouver les noms de variable enregistrés dans un point de contrôle TensorFlow?
J'ai utilisé tf.train.NewCheckpointReader
qui est expliqué ici . Mais, cela n’est pas indiqué dans la documentation de TensorFlow. Y-a t'il une autre possibilité?
Vous pouvez utiliser l'outil inspect_checkpoint.py
.
Ainsi, par exemple, si vous avez stocké le point de contrôle dans le répertoire actuel, vous pouvez imprimer les variables et leurs valeurs comme suit:
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
Exemple d'utilisation:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')
# List contents of v0 tensor.
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')
# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')
Update: L'argument all_tensors
a été ajouté à print_tensors_in_checkpoint_file
puisque Tensorflow 0.12.0-rc0 vous devrez peut-être ajouter all_tensors=False
ou all_tensors=True
si nécessaire.
Méthode alternative:
from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key)) # Remove this is you want to print only variable names
J'espère que ça aide.
Quelques détails supplémentaires.
Si votre modèle est enregistré au format V2, par exemple, si nous avons les fichiers suivants dans le répertoire /my/dir/
model-10000.data-00000-of-00001
model-10000.index
model-10000.meta
alors le paramètre file_name
ne devrait être que le préfixe, c’est-à-dire
print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)
Voir https://github.com/tensorflow/tensorflow/issues/7696 pour une discussion.
Ajout de plus de détails sur les paramètres à print_tensors_in_checkpoint_file
file_name
: pas un fichier physique, juste le préfixe des noms de fichiers
Si aucun tensor_name
n'est fourni, affiche les noms et les formes du tenseur dans le fichier de point de contrôle. Si tensor_name
est fourni, affiche le contenu du tenseur. ( inspect_checkpoint.py )
Si all_tensor_names
est True
, imprime tous les noms de tenseur
Si all_tensor
est 'True`, imprime tous les noms de tenseur et le contenu correspondant.
N.B.all_tensor
et all_tensor_names
remplacera tensor_name