J'ai une question de base sur la façon d'indexer dans TensorFlow.
En numpy:
x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
#numpy
print x * e[x]
Je peux obtenir
[1 0 3 3 0 5 0 7 1 3]
Comment puis-je faire cela dans TensorFlow?
x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
x_t = tf.constant(x)
e_t = tf.constant(e)
with tf.Session():
????
Merci!
Heureusement, le cas exact dont vous parlez est pris en charge dans TensorFlow par tf.gather()
:
result = x_t * tf.gather(e_t, x_t)
with tf.Session() as sess:
print sess.run(result) # ==> 'array([1, 0, 3, 3, 0, 5, 0, 7, 1, 3])'
L'opération tf.gather()
est moins puissante que indexation avancée de NumPy : elle ne prend en charge que l'extraction de tranches complètes d'un tenseur sur sa dimension 0. La prise en charge d'une indexation plus générale a été demandée et est suivie dans ce problème GitHub .