Je forme un CNN assez similaire à celui de l'exemple this , pour la segmentation d'image. Les images sont 1500x1500x1 et les étiquettes sont de la même taille.
Après avoir défini la structure CNN et lancé la session comme dans cet exemple de code: (conv_net_test.py
)
with tf.Session() as sess:
sess.run(init)
summ = tf.train.SummaryWriter('/tmp/logdir/', sess.graph_def)
step = 1
print ("import data, read from read_data_sets()...")
#Data defined by me, returns a DataSet object with testing and training images and labels for segmentation problem.
data = import_data_test.read_data_sets('Dataset')
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_x, batch_y = data.train.next_batch(batch_size)
print ("running backprop for step %d" % step)
batch_x = batch_x.reshape(batch_size, n_input, n_input, n_channels)
batch_y = batch_y.reshape(batch_size, n_input, n_input, n_channels)
batch_y = np.int64(batch_y)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y, keep_prob: dropout})
if step % display_step == 0:
# Calculate batch loss and accuracy
#pdb.set_trace()
loss, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
step += 1
print "Optimization Finished"
J'ai rencontré la TypeError suivante (stacktrace ci-dessous):
conv_net_test.py in <module>()
178 #pdb.set_trace()
--> 179 loss, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
180 step += 1
181 print "Optimization Finished!"
tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
370 try:
371 result = self._run(None, fetches, feed_dict, options_ptr,
--> 372 run_metadata_ptr)
373 if run_metadata:
374 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
582
583 # Validate and process fetches.
--> 584 processed_fetches = self._process_fetches(fetches)
585 unique_fetches = processed_fetches[0]
586 target_list = processed_fetches[1]
tensorflow/python/client/session.pyc in _process_fetches(self, fetches)
538 raise TypeError('Fetch argument %r of %r has invalid type %r, '
539 'must be a string or Tensor. (%s)'
--> 540 % (subfetch, fetch, type(subfetch), str(e)))
TypeError: Fetch argument 1.4415792e+2 of 1.4415792e+2 has invalid type <type 'numpy.float32'>, must be a string or Tensor. (Can not convert a float32 into a Tensor or Operation.)
Je suis perplexe à ce stade. C'est peut-être un cas simple de conversion du type, mais je ne sais pas comment/où. Aussi, pourquoi la perte doit-elle être une chaîne? (En supposant que la même erreur apparaîtra également pour la précision, une fois cela corrigé).
Toute aide appréciée!
Lorsque vous utilisez loss = sess.run(loss)
, vous redéfinissez dans python la variable loss
.
La première fois, il fonctionnera bien. La deuxième fois, vous essaierez de faire:
sess.run(1.4415792e+2)
Parce que loss
est maintenant un flottant.
Vous devez utiliser des noms différents comme:
loss_val, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})