web-dev-qa-db-fra.com

Liste des noms de tenseurs dans le graphique dans Tensorflow

L'objet graphique dans Tensorflow a une méthode appelée "get_tensor_by_name (name)". Est-il possible d'obtenir une liste de noms de tenseurs valides?

Si non, est-ce que quelqu'un connaît les noms valides pour le modèle pré-entraîné inception-v3 à partir d'ici ? Pool_3, par exemple, est un tenseur valide, mais une liste de tous serait Nice. J'ai jeté un œil sur le document mentionné et certaines des couches semblent correspondre aux tailles du tableau 1, mais pas à toutes.

52
John Pertoft

Le papier ne reflète pas exactement le modèle. Si vous téléchargez le code source depuis arxiv, il contient une description de modèle précise, modèle.txt, et les noms qui y figurent sont fortement corrélés avec les noms du modèle publié.

Pour répondre à votre première question, sess.graph.get_operations() vous donne une liste d'opérations. Pour un op, op.name Vous donne le nom et op.values() vous donne une liste des tenseurs qu'il produit (dans le modèle inception-v3, tous les noms de tenseur sont le nom op avec un ": 0 "ajouté à celui-ci, donc pool_3:0 est le tenseur produit par l'opération de mise en commun finale.

55
etarion

Pour voir les opérations dans le graphique (vous en verrez beaucoup, alors pour abréger, je n’ai donné ici que la première chaîne).

sess = tf.Session()
op = sess.graph.get_operations()
[m.values() for m in op][1]

out:
(<tf.Tensor 'conv1/weights:0' shape=(4, 4, 3, 32) dtype=float32_ref>,)
23
Prakash Vanapalli

Les réponses ci-dessus sont correctes. Je suis tombé sur un code facile à comprendre/simple pour la tâche ci-dessus. Donc, le partager ici: -

import tensorflow as tf

def printTensors(pb_file):

    # read pb into graph_def
    with tf.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # import graph_def
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

    # print operations
    for op in graph.get_operations():
        print(op.name)


printTensors("path-to-my-pbfile.pb")
20
Akash Goyal

Il n'est même pas nécessaire de créer une session pour voir les noms de tous les noms d'opération dans le graphique. Pour ce faire, il vous suffit de saisir un graphe par défaut tf.get_default_graph() et d'extraire toutes les opérations: .get_operations . Chaque opération a plusieurs champs , celle dont vous avez besoin est name.

Voici le code:

import tensorflow as tf
a = tf.Variable(5)
b = tf.Variable(6)
c = tf.Variable(7)
d = (a + b) * c

for i in tf.get_default_graph().get_operations():
    print i.name
8
Salvador Dali

En tant que compréhension de liste imbriquée:

tensor_names = [t.name for op in tf.get_default_graph().get_operations() for t in op.values()]

Fonction pour obtenir les noms des tenseurs dans un graphique (par défaut, le graphique par défaut):

def get_names(graph=tf.get_default_graph()):
    return [t.name for op in graph.get_operations() for t in op.values()]

Fonction pour obtenir les tenseurs dans un graphique (par défaut, le graphique par défaut):

def get_tensors(graph=tf.get_default_graph()):
    return [t for op in graph.get_operations() for t in op.values()]
2
eqzx