J'ai du mal à récupérer un tenseur par son nom, je ne sais même pas si c'est possible.
J'ai une fonction qui crée mon graphique:
def create_structure(tf, x, input_size,dropout):
with tf.variable_scope("scale_1") as scope:
W_S1_conv1 = deep_dive.weight_variable_scaling([7,7,3,64], name='W_S1_conv1')
b_S1_conv1 = deep_dive.bias_variable([64])
S1_conv1 = tf.nn.relu(deep_dive.conv2d(x_image, W_S1_conv1,strides=[1, 2, 2, 1], padding='SAME') + b_S1_conv1, name="Scale1_first_relu")
.
.
.
return S3_conv1,regularizer
Je veux accéder à la variable S1_conv1 en dehors de cette fonction. J'ai essayé:
with tf.variable_scope('scale_1') as scope_conv:
tf.get_variable_scope().reuse_variables()
ft=tf.get_variable('Scale1_first_relu')
Mais cela me donne une erreur:
ValueError: Sous-partage: La variable scale_1/Scale1_first_relu n'existe pas, non autorisée. Voulez-vous dire que reuse = Aucun dans VarScope?
Mais cela fonctionne:
with tf.variable_scope('scale_1') as scope_conv:
tf.get_variable_scope().reuse_variables()
ft=tf.get_variable('W_S1_conv1')
Je peux contourner ça avec
return S3_conv1,regularizer, S1_conv1
mais je ne veux pas faire ça.
Je pense que mon problème est que S1_conv1 n'est pas vraiment une variable, c'est juste un tenseur. Y a-t-il un moyen de faire ce que je veux?
Il existe une fonction tf.Graph.get_tensor_by_name (). Par exemple:
import tensorflow as tf
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')
with tf.Session() as sess:
test = sess.run(e)
print e.name #example:0
test = tf.get_default_graph().get_tensor_by_name("example:0")
print test #Tensor("example:0", shape=(2, 2), dtype=float32)
Tous les tenseurs ont des noms de chaîne que vous pouvez voir comme suit
[tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
Une fois que vous connaissez le nom, vous pouvez récupérer le Tenseur à l’aide de <name>:0
(0 correspond au point final qui est quelque peu redondant)
Par exemple si vous faites cela
tf.constant(1)+tf.constant(2)
Vous avez les noms de tenseur suivants
[u'Const', u'Const_1', u'add']
Donc, vous pouvez aller chercher la sortie de l'addition
sess.run('add:0')
Notez que cela ne fait pas partie de l'API publique. Les noms de tenseur de chaînes générés automatiquement constituent un détail de la mise en œuvre et peuvent changer.
Tout ce que vous devez faire dans ce cas est:
ft=tf.get_variable('scale1/Scale1_first_relu:0')