web-dev-qa-db-fra.com

Comment optimiser pour l'inférence un graphique simple et enregistré TensorFlow 1.0.1?

Je ne parviens pas à exécuter le module optimize_for_inference sur un graphique TensorFlow simple et enregistré (package Python 2.7 installé par pip install tensorflow-gpu==1.0.1).

Contexte

Enregistrement du graphique TensorFlow

Voici mon script Python pour générer et enregistrer un graphique simple afin d'ajouter 5 à mon opération d'entrée xplaceholder.

import tensorflow as tf

# make and save a simple graph
G = tf.Graph()
with G.as_default():
    x = tf.placeholder(dtype=tf.float32, shape=(), name="x")
    a = tf.Variable(5.0, name="a")
    y = tf.add(a, x, name="y")
    saver = tf.train.Saver()

with tf.Session(graph=G) as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(fetches=[y], feed_dict={x: 1.0})
    print(out)
    saver.save(sess=sess, save_path="test_model")

Restauration du graphique TensorFlow

J'ai un script de restauration simple qui recrée le graphe enregistré et restaure les paramètres du graphe. Les scripts de sauvegarde/restauration produisent le même résultat.

import tensorflow as tf

# Restore simple graph and test model output
G = tf.Graph()

with tf.Session(graph=G) as sess:
    # recreate saved graph (structure)
    saver = tf.train.import_meta_graph('./test_model.meta')
    # restore net params
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    x = G.get_operation_by_name("x").outputs[0]
    y = G.get_operation_by_name("y").outputs
    out = sess.run(fetches=[y], feed_dict={x: 1.0})
    print(out[0])

Tentative d'optimisation

Mais, bien que je n'attende pas grand-chose en termes d'optimisation, lorsque j'essaie d'optimiser le graphique pour l'inférence, le message d'erreur suivant s'affiche. Le noeud de sortie attendu ne semble pas figurer dans le graphique enregistré.

$ python -m tensorflow.python.tools.optimize_for_inference --input test_model.data-00000-of-00001 --output opt_model --input_names=x --output_names=y  
Traceback (most recent call last):  
  File "/usr/lib/python2.7/runpy.py", line 174, in _run_module_as_main  
    "__main__", fname, loader, pkg_name)  
  File "/usr/lib/python2.7/runpy.py", line 72, in _run_code  
    exec code in run_globals  
  File "/{path}/lib/python2.7/site-packages/tensorflow/python/tools/optimize_for_inference.py", line 141, in <module>  
    app.run(main=main, argv=[sys.argv[0]] + unparsed)  
  File "/{path}/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run  
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/{path}/lib/python2.7/site-packages/tensorflow/python/tools/optimize_for_inference.py", line 90, in main  
    FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)  
  File "/{path}/local/lib/python2.7/site-packages/tensorflow/python/tools/optimize_for_inference_lib.py", line 91, in optimize_for_inference  
    placeholder_type_enum)  
  File "/{path}/local/lib/python2.7/site-packages/tensorflow/python/tools/strip_unused_lib.py", line 71, in strip_unused  
    output_node_names)  
  File "/{path}/local/lib/python2.7/site-packages/tensorflow/python/framework/graph_util_impl.py", line 141, in extract_sub_graph  
    assert d in name_to_node_map, "%s is not in graph" % d  
AssertionError: y is not in graph  

Une enquête plus approfondie m'a conduit à inspecter le point de contrôle du graphe sauvegardé, qui ne montre qu'un tenseur (a, no x et no y).

(tf-1.0.1) $ python -m tensorflow.python.tools.inspect_checkpoint --file_name ./test_model --all_tensors
tensor_name:  a
5.0

Questions spécifiques

  1. Pourquoi ne vois-je pas x et y dans le point de contrôle? Est-ce parce qu'ils sont des opérations et non des tenseurs?
  2. Comme je dois fournir des noms d'entrée et de sortie au module optimize_for_inference, comment puis-je construire le graphique pour pouvoir référencer les nœuds d'entrée et de sortie?
15
tdube

Voici le guide détaillé sur la façon d'optimiser l'inférence:

Le module optimize_for_inference prend un fichier frozen binary GraphDef en entrée et génère le fichier optimized Graph Def que vous pouvez utiliser pour l'inférence. Et pour obtenir le frozen binary GraphDef file, vous devez utiliser le module freeze_graph qui prend un GraphDef proto, un SaverDef proto et un ensemble de variables stockées dans un fichier de point de contrôle. Les étapes pour y parvenir sont indiquées ci-dessous:

1. Sauvegarder le graphe tensorflow

 # make and save a simple graph
 G = tf.Graph()
 with G.as_default():
   x = tf.placeholder(dtype=tf.float32, shape=(), name="x")
   a = tf.Variable(5.0, name="a")
   y = tf.add(a, x, name="y")
   saver = tf.train.Saver()

with tf.Session(graph=G) as sess:
   sess.run(tf.global_variables_initializer())
   out = sess.run(fetches=[y], feed_dict={x: 1.0})

  # Save GraphDef
  tf.train.write_graph(sess.graph_def,'.','graph.pb')
  # Save checkpoint
  saver.save(sess=sess, save_path="test_model")

2. Geler le graphique

python -m tensorflow.python.tools.freeze_graph --input_graph graph.pb --input_checkpoint test_model --output_graph graph_frozen.pb --output_node_names=y

3. Optimiser pour l'inférence

python -m tensorflow.python.tools.optimize_for_inference --input graph_frozen.pb --output graph_optimized.pb --input_names=x --output_names=y

4. Utilisation du graphique optimisé

with tf.gfile.GFile('graph_optimized.pb', 'rb') as f:
   graph_def_optimized = tf.GraphDef()
   graph_def_optimized.ParseFromString(f.read())

G = tf.Graph()

with tf.Session(graph=G) as sess:
    y, = tf.import_graph_def(graph_def_optimized, return_elements=['y:0'])
    print('Operations in Optimized Graph:')
    print([op.name for op in G.get_operations()])
    x = G.get_tensor_by_name('import/x:0')
    out = sess.run(y, feed_dict={x: 1.0})
    print(out)

#Output
#Operations in Optimized Graph:
#['import/x', 'import/a', 'import/y']
#6.0

5. Pour plusieurs noms de sortie

S'il existe plusieurs noeuds de sortie, spécifiez: output_node_names = 'boxes, scores, classes' et importez le graphique avec,

 boxes,scores,classes, = tf.import_graph_def(graph_def_optimized, return_elements=['boxes:0', 'scores:0', 'classes:0'])
33
vijay m
  1. Vous vous trompez: input est un fichier graphdef pour le script pas la partie données du point de contrôle. Vous devez geler le modèle dans un fichier .pb/ou obtenir le prototxt pour le graphique et utiliser le script optimiser pour l'inférence. 

This script takes either a frozen binary GraphDef file (where the weight variables have been converted into constants by the freeze_graph script), or a text GraphDef proto file (the weight variables are stored in a separate checkpoint file), and outputs a new GraphDef with the optimizations applied.

  1. Récupère le fichier proto du graphique avec write_graph
  2. récupère le modèle figé gel graphique
0
Ishant Mrinal