J'essaie actuellement d'exporter un modèle TensorFlow formé en tant que fichier ProtoBuf pour l'utiliser avec l'API TensorFlow C++ sur Android. Par conséquent, j'utilise le script freeze_graph.py
.
J'ai exporté mon modèle en utilisant tf.train.write_graph
:
tf.train.write_graph(graph_def, FLAGS.save_path, out_name, as_text=True)
et j'utilise un point de contrôle enregistré avec tf.train.Saver
.
J'appelle freeze_graph.py
comme décrit en haut du script. Après avoir compilé, je cours
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=<path_to_protobuf_file> \
--input_checkpoint=<model_name>.ckpt-10000 \
--output_graph=<output_protobuf_file_path> \
--output_node_names=dropout/mul_1
Cela me donne le message d'erreur suivant:
TypeError: Cannot interpret feed_dict key as Tensor: The name 'save/Const:0' refers to a Tensor which does not exist. The operation, 'save/Const', does not exist in the graph.
Comme l’erreur indique, je n’ai pas de tenseur save/Const:0
dans mon modèle exporté. Cependant, le code freeze_graph.py
indique que l’on peut spécifier ce nom de tenseur par le drapeau filename_tensor_name
. Malheureusement, je ne trouve aucune information sur ce que ce tenseur devrait être et comment le régler correctement pour mon modèle.
Quelqu'un peut-il me dire comment produire un tenseur save/Const:0
dans mon modèle ProtoBuf exporté ou comment définir le drapeau filename_tensor_name
correctement?
L’indicateur --filename_tensor_name
permet de spécifier le nom d’un tenseur d’espace réservé créé lors de la construction de tf.train.Saver
pour votre modèle. *
Dans votre programme d'origine, vous pouvez imprimer la valeur de saver.saver_def.filename_tensor_name
pour obtenir la valeur que vous devez transmettre pour cet indicateur. Vous pouvez également vouloir imprimer la valeur de saver.saver_def.restore_op_name
pour obtenir une valeur pour le drapeau --restore_op_name
(car je suppose que la valeur par défaut ne sera pas correcte pour votre graphique).
Sinon, le tampon de protocole tf.train.SaverDef
inclut toutes les informations nécessaires à la reconstruction des informations pertinentes pour ces indicateurs. Si vous préférez, vous pouvez écrire saver.saver_def
dans un fichier et transmettre le nom de ce fichier sous la forme de l'indicateur --input_saver
à freeze_graph.py
.
* L'étendue du nom par défaut pour tf.train.Saver
est "save/"
et l'espace réservé est en réalité un tf.constant()
dont le nom par défaut est "Const:0"
, ce qui explique pourquoi l'indicateur par défaut est "save/Const:0"
.
J'ai remarqué qu'une erreur m'était arrivée quand j'avais arrangé le code comme ceci:
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
init = tf.initialize_all_variables()
saver = tf.train.Saver()
sess.run(init)
Cela a fonctionné après avoir changé la disposition du code comme ceci:
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
init = tf.initialize_all_variables()
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
sess.run(init)
Je ne sais pas trop pourquoi. @mrry pourriez-vous expliquer un peu plus?