web-dev-qa-db-fra.com

tf.reshape vs tf.contrib.layers.flatten

Je lance donc un CNN pour un problème de classification. J'ai 3 couches conv avec 3 couches de mise en commun. P3 est la sortie de la dernière couche de regroupement, dont les dimensions sont: [Batch_size, 4, 12, 48], et je veux aplatir cette matrice en [Batch_size, 2304] matrice de taille, étant 2304 = 4 * 12 * 48 . Je travaillais avec "Option A" (voir ci-dessous) depuis un moment, mais un jour j'ai voulu essayer "Option B", ce qui me donnerait théoriquement le même résultat. Mais ce ne fut pas le cas. J'ai déjà lu le fil suivant

tf.contrib.layers.flatten (x) est-il le même que tf.reshape (x, [n, 1])?

mais cela n'a fait qu'ajouter plus de confusion, car essayer "Option C" (tiré du fil susmentionné) a donné un nouveau résultat différent.

P3 = tf.nn.max_pool(A3, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding='VALID')

P3_shape = P3.get_shape().as_list()

P = tf.contrib.layers.flatten(P3)                             <-----Option A

P = tf.reshape(P3, [-1, P3_shape[1]*P3_shape[2]*P3_shape[3]]) <---- Option B

P = tf.reshape(P3, [tf.shape(P3)[0], -1])                     <---- Option C

Je suis plus enclin à opter pour "Option B" car c'est celle que j'ai vue dans une vidéo de Dandelion Mane ( https://www.youtube.com/watch?v=eBbEDRsCmv4&t=631s ), mais j'aimerais comprendre pourquoi ces 3 options donnent des résultats différents.

8
sdiabr

Les 3 options se reforment de manière identique:

import tensorflow as tf
import numpy as np

p3 = tf.placeholder(tf.float32, [None, 1, 2, 4])

p3_shape = p3.get_shape().as_list()

p_a = tf.contrib.layers.flatten(p3)                                  # <-----Option A

p_b = tf.reshape(p3, [-1, p3_shape[1] * p3_shape[2] * p3_shape[3]])  # <---- Option B

p_c = tf.reshape(p3, [tf.shape(p3)[0], -1])                          # <---- Option C

print(p_a.get_shape())
print(p_b.get_shape())
print(p_c.get_shape())

with tf.Session() as sess:
    i_p3 = np.arange(16, dtype=np.float32).reshape([2, 1, 2, 4])
    print("a", sess.run(p_a, feed_dict={p3: i_p3}))
    print("b", sess.run(p_b, feed_dict={p3: i_p3}))
    print("c", sess.run(p_c, feed_dict={p3: i_p3}))

Ce code donne le même résultat 3 fois. Vos différents résultats sont causés par autre chose et non par le remodelage.

9
BlueSun