Est-il possible d'obtenir le nombre total d'enregistrements à partir d'un .tfrecords
fichier ? À cet égard, comment peut-on généralement suivre le nombre d'époques qui se sont écoulées lors de la formation des modèles? S'il nous est possible de spécifier le batch_size
et num_of_epochs
, Je ne sais pas s'il est simple d'obtenir des valeurs telles que current Epoch
, nombre de lots par époque, etc. - pour que je puisse mieux contrôler la progression de la formation. Actuellement, j'utilise simplement un hack sale pour calculer cela car je sais à l'avance combien d'enregistrements il y a dans mon fichier .tfrecords et la taille de mes minibatches. Appréciez toute aide ..
Pour compter le nombre d'enregistrements, vous devriez pouvoir utiliser tf.python_io.tf_record_iterator
.
c = 0
for fn in tf_records_filenames:
for record in tf.python_io.tf_record_iterator(fn):
c += 1
Pour garder une trace de la formation du modèle, tensorboard est très pratique.
Non ce n'est pas possible. TFRecord ne stocke aucune métadonnée sur les données stockées à l'intérieur. Ce fichier
représente une séquence de chaînes (binaires). Le format n'est pas un accès aléatoire, il convient donc pour le streaming de grandes quantités de données mais ne convient pas si un partage rapide ou un autre accès non séquentiel est souhaité.
Si vous le souhaitez, vous pouvez stocker ces métadonnées manuellement ou utiliser un record_iterator pour obtenir le nombre (vous devrez parcourir tous les enregistrements que vous avez:
sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
Si vous voulez connaître l'époque actuelle, vous pouvez le faire à partir du tensorboard ou en imprimant le numéro à partir de la boucle.
Selon l'avertissement de dépréciation du tf_record_iterator , nous pouvons également utiliser une exécution rapide pour compter les enregistrements.
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import sys
assert len(sys.argv) == 2, \
"USAGE: {} <file_glob>".format(sys.argv[0])
tf.enable_eager_execution()
input_pattern = sys.argv[1]
# This is where we get the count of records
records_n = sum(1 for record in tf.data.TFRecordDataset(tf.gfile.Glob(input_pattern)))
print("records_n = {}".format(records_n))