J'ai quelques difficultés à utiliser la fonction accuracy
de tf.metrics
pour un problème de classification multiple avec des logits en entrée.
Ma sortie de modèle ressemble à:
logits = [[0.1, 0.5, 0.4],
[0.8, 0.1, 0.1],
[0.6, 0.3, 0.2]]
Et mes étiquettes sont des vecteurs encodés à chaud:
labels = [[0, 1, 0],
[1, 0, 0],
[0, 0, 1]]
Lorsque j'essaie de faire quelque chose comme tf.metrics.accuracy(labels, logits)
, cela ne donne jamais le résultat correct. Je fais évidemment quelque chose de mal, mais je ne peux pas comprendre ce que c'est.
TL; DR
La fonction de précision tf.metrics.accuracy calcule la fréquence à laquelle les prédictions correspondent aux étiquettes en fonction de deux variables locales créées: total
et count
, utilisées pour calculer la fréquence avec laquelle logits
correspond à labels
.
acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
predictions=tf.argmax(logits,1))
print(sess.run([acc, acc_op]))
print(sess.run([acc]))
# Output
#[0.0, 0.66666669]
#[0.66666669]
total
et count
, ne met pas à jour les métriques.Pour comprendre pourquoi acc renvoie 0.0
, consultez les détails ci-dessous.
Détails à l'aide d'un exemple simple:
logits = tf.placeholder(tf.int64, [2,3])
labels = tf.Variable([[0, 1, 0], [1, 0, 1]])
acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
predictions=tf.argmax(logits,1))
Initialise les variables:
Puisque metrics.accuracy
crée deux variables locales total
et count
, nous devons appeler local_variables_initializer()
pour les initialiser.
sess = tf.Session()
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
stream_vars = [i for i in tf.local_variables()]
print(stream_vars)
#[<tf.Variable 'accuracy/total:0' shape=() dtype=float32_ref>,
# <tf.Variable 'accuracy/count:0' shape=() dtype=float32_ref>]
Comprendre les opérations de mise à jour et le calcul de la précision:
print('acc:',sess.run(acc, {logits:[[0,1,0],[1,0,1]]}))
#acc: 0.0
print('[total, count]:',sess.run(stream_vars))
#[total, count]: [0.0, 0.0]
Les valeurs ci-dessus renvoient 0,0 pour l'exactitude, car total
et count
sont des zéros, malgré le fait de donner des entrées correspondantes.
print('ops:', sess.run(acc_op, {logits:[[0,1,0],[1,0,1]]}))
#ops: 1.0
print('[total, count]:',sess.run(stream_vars))
#[total, count]: [2.0, 2.0]
Avec les nouvelles entrées, la précision est calculée lorsque l’opération de mise à jour est appelée. Remarque: étant donné que tous les logits et étiquettes correspondent, nous obtenons une précision de 1,0 et les variables locales total
et count
donnent en fait total correctly predicted
et le total comparisons made
.
Nous appelons maintenant accuracy
avec les nouvelles entrées (pas les opérations de mise à jour):
print('acc:', sess.run(acc,{logits:[[1,0,0],[0,1,0]]}))
#acc: 1.0
L'appel d'exactitude ne met pas à jour les métriques avec les nouvelles entrées, il renvoie simplement la valeur en utilisant les deux variables locales. Remarque: les logits et les étiquettes ne correspondent pas dans ce cas. Maintenant, appelez à nouveau les opérations de mise à jour:
print('op:',sess.run(acc_op,{logits:[[0,1,0],[0,1,0]]}))
#op: 0.75
print('[total, count]:',sess.run(stream_vars))
#[total, count]: [3.0, 4.0]
Les métriques sont mises à jour pour les nouvelles entrées
Pour plus d'informations sur l'utilisation des métriques lors de la formation et sur leur réinitialisation lors de la validation, reportez-vous à ici .
Appliqué sur un CNN, vous pouvez écrire:
x_len=24*24
y_len=2
x = tf.placeholder(tf.float32, shape=[None, x_len], name='input')
fc1 = ... # cnn's fully connected layer
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
layer_fc_dropout = tf.nn.dropout(fc1, keep_prob, name='dropout')
y_pred = tf.nn.softmax(fc1, name='output')
logits = tf.argmax(y_pred, axis=1)
y_true = tf.placeholder(tf.float32, shape=[None, y_len], name='y_true')
acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(y_true, axis=1), predictions=tf.argmax(y_pred, 1))
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
def print_accuracy(x_data, y_data, dropout=1.0):
accuracy = sess.run(acc_op, feed_dict = {y_true: y_data, x: x_data, keep_prob: dropout})
print('Accuracy: ', accuracy)