J'ai un réseau de neurones, à partir d'un générateur de données tf.data
et d'un modèle tf.keras
, comme suit (une version simplifiée, car elle serait trop longue):
dataset = ...
Un objet tf.data.Dataset
qui avec la méthode next_x
appelle le get_next
pour l'itérateur x_train
et pour la méthode next_y
appelle le get_next
pour l'itérateur y_train
. Chaque étiquette est un tableau (1, 67)
au format one-hot.
Couches:
input_tensor = tf.keras.layers.Input(shape=(240, 240, 3)) # dim of x
output = tf.keras.layers.Flatten()(input_tensor)
output= tf.keras.Dense(67, activation='softmax')(output) # 67 is the number of classes
Modèle:
model = tf.keras.models.Model(inputs=input_tensor, outputs=prediction)
model.compile(optimizer=tf.train.AdamOptimizer(), loss=tf.losses.softmax_cross_entropy, metrics=['accuracy'])
model.fit_generator(gen(dataset.next_x(), dataset.next_y()), steps_per_epochs=100)
gen
est défini comme ceci:
def gen(x, y):
while True:
yield(x, y)
Mon problème est que lorsque j'essaye de l'exécuter, j'obtiens une erreur dans la partie model.fit
:
ValueError: Cannot take the length of Shape with unknown rank.
Toutes les idées sont appréciées!
Pourriez-vous poster une trace de pile plus longue? Je pense que votre problème pourrait être lié à ce problème récent de tensorflow:
https://github.com/tensorflow/tensorflow/issues/24520
Il existe également un simple PR qui le corrige (pas encore fusionné). Peut-être l'essayer vous-même?
MODIFIER
Voici le PR: Open tensorflow/python/keras/engine/training_utils.py
remplacez ce qui suit (ligne 232 pour le moment):
if (x.shape is not None
and len(x.shape) == 1
avec ça:
if tensor_util.is_tensor(x):
x_shape_ndims = x.shape.ndims if x.shape is not None else None
else:
x_shape_ndims = len(x.shape)
if (x_shape_ndims == 1
J'ai découvert ce qui n'allait pas. En fait, je dois run
le prochain lot dans un tf.Session
avant de le céder ... Voici comment cela fonctionne (je n'écris pas le reste du code car il reste identique):
model.fit_generator(gen(), steps_per_epochs=100)
def gen():
with tf.Session() as sess:
next_x = dataset.next_x()
next_y = dataset.next_y()
while True:
x_batch = sess.run(next_x)
y_batch = sess.run(next_y)
yield x_batch, y_batch