web-dev-qa-db-fra.com

Comment trouver les noms de variables et les valeurs enregistrées dans un point de contrôle?

Je souhaite voir les variables enregistrées dans un point de contrôle TensorFlow avec leurs valeurs. Comment trouver les noms de variable enregistrés dans un point de contrôle TensorFlow?

J'ai utilisé tf.train.NewCheckpointReader qui est expliqué ici . Mais, cela n’est pas indiqué dans la documentation de TensorFlow. Y-a t'il une autre possibilité?

23
Tavakoli

Vous pouvez utiliser l'outil inspect_checkpoint.py .

Ainsi, par exemple, si vous avez stocké le point de contrôle dans le répertoire actuel, vous pouvez imprimer les variables et leurs valeurs comme suit:

import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
11
keveman

Exemple d'utilisation:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')

# List contents of v0 tensor.
# Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')

# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')

Update: L'argument all_tensors a été ajouté à print_tensors_in_checkpoint_file puisque Tensorflow 0.12.0-rc0 vous devrez peut-être ajouter all_tensors=False ou all_tensors=True si nécessaire.

Méthode alternative:

from tensorflow.python import pywrap_tensorflow
import os

checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names

J'espère que ça aide.

36
sagunms

Quelques détails supplémentaires.

Si votre modèle est enregistré au format V2, par exemple, si nous avons les fichiers suivants dans le répertoire /my/dir/

model-10000.data-00000-of-00001
model-10000.index
model-10000.meta

alors le paramètre file_name ne devrait être que le préfixe, c’est-à-dire

print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)

Voir https://github.com/tensorflow/tensorflow/issues/7696 pour une discussion.

9
deeplearning

Ajout de plus de détails sur les paramètres à print_tensors_in_checkpoint_file

file_name: pas un fichier physique, juste le préfixe des noms de fichiers

Si aucun tensor_name n'est fourni, affiche les noms et les formes du tenseur dans le fichier de point de contrôle. Si tensor_name est fourni, affiche le contenu du tenseur. ( inspect_checkpoint.py )

Si all_tensor_names est True, imprime tous les noms de tenseur

Si all_tensor est 'True`, imprime tous les noms de tenseur et le contenu correspondant.

N.B.all_tensor et all_tensor_names remplacera tensor_name

0
BugKiller