web-dev-qa-db-fra.com

Tensorflow: Comment obtenir un tenseur par son nom?

J'ai du mal à récupérer un tenseur par son nom, je ne sais même pas si c'est possible.

J'ai une fonction qui crée mon graphique:

def create_structure(tf, x, input_size,dropout):    
 with tf.variable_scope("scale_1") as scope:
  W_S1_conv1 = deep_dive.weight_variable_scaling([7,7,3,64], name='W_S1_conv1')
  b_S1_conv1 = deep_dive.bias_variable([64])
  S1_conv1 = tf.nn.relu(deep_dive.conv2d(x_image, W_S1_conv1,strides=[1, 2, 2, 1], padding='SAME') + b_S1_conv1, name="Scale1_first_relu")
.
.
.
return S3_conv1,regularizer

Je veux accéder à la variable S1_conv1 en dehors de cette fonction. J'ai essayé:

with tf.variable_scope('scale_1') as scope_conv: 
 tf.get_variable_scope().reuse_variables()
 ft=tf.get_variable('Scale1_first_relu')

Mais cela me donne une erreur:

ValueError: Sous-partage: La variable scale_1/Scale1_first_relu n'existe pas, non autorisée. Voulez-vous dire que reuse = Aucun dans VarScope?

Mais cela fonctionne:

with tf.variable_scope('scale_1') as scope_conv: 
 tf.get_variable_scope().reuse_variables()
 ft=tf.get_variable('W_S1_conv1')

Je peux contourner ça avec

return S3_conv1,regularizer, S1_conv1

mais je ne veux pas faire ça.

Je pense que mon problème est que S1_conv1 n'est pas vraiment une variable, c'est juste un tenseur. Y a-t-il un moyen de faire ce que je veux?

41
protas

Il existe une fonction tf.Graph.get_tensor_by_name (). Par exemple:

import tensorflow as tf

c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')

with tf.Session() as sess:
    test =  sess.run(e)
    print e.name #example:0
    test = tf.get_default_graph().get_tensor_by_name("example:0")
    print test #Tensor("example:0", shape=(2, 2), dtype=float32)
49
apfalz

Tous les tenseurs ont des noms de chaîne que vous pouvez voir comme suit

[tensor.name for tensor in tf.get_default_graph().as_graph_def().node]

Une fois que vous connaissez le nom, vous pouvez récupérer le Tenseur à l’aide de <name>:0 (0 correspond au point final qui est quelque peu redondant)

Par exemple si vous faites cela

tf.constant(1)+tf.constant(2)

Vous avez les noms de tenseur suivants

[u'Const', u'Const_1', u'add']

Donc, vous pouvez aller chercher la sortie de l'addition

sess.run('add:0')

Notez que cela ne fait pas partie de l'API publique. Les noms de tenseur de chaînes générés automatiquement constituent un détail de la mise en œuvre et peuvent changer.

30
Yaroslav Bulatov

Tout ce que vous devez faire dans ce cas est:

ft=tf.get_variable('scale1/Scale1_first_relu:0')
0
Kislay Kunal