web-dev-qa-db-fra.com

Obtention du nombre total d'enregistrements à partir du fichier .tfrecords dans Tensorflow

Est-il possible d'obtenir le nombre total d'enregistrements à partir d'un .tfrecords fichier ? À cet égard, comment peut-on généralement suivre le nombre d'époques qui se sont écoulées lors de la formation des modèles? S'il nous est possible de spécifier le batch_size et num_of_epochs, Je ne sais pas s'il est simple d'obtenir des valeurs telles que current Epoch, nombre de lots par époque, etc. - pour que je puisse mieux contrôler la progression de la formation. Actuellement, j'utilise simplement un hack sale pour calculer cela car je sais à l'avance combien d'enregistrements il y a dans mon fichier .tfrecords et la taille de mes minibatches. Appréciez toute aide ..

24
HuckleberryFinn

Pour compter le nombre d'enregistrements, vous devriez pouvoir utiliser tf.python_io.tf_record_iterator .

c = 0
for fn in tf_records_filenames:
  for record in tf.python_io.tf_record_iterator(fn):
     c += 1

Pour garder une trace de la formation du modèle, tensorboard est très pratique.

30
drpng

Non ce n'est pas possible. TFRecord ne stocke aucune métadonnée sur les données stockées à l'intérieur. Ce fichier

représente une séquence de chaînes (binaires). Le format n'est pas un accès aléatoire, il convient donc pour le streaming de grandes quantités de données mais ne convient pas si un partage rapide ou un autre accès non séquentiel est souhaité.

Si vous le souhaitez, vous pouvez stocker ces métadonnées manuellement ou utiliser un record_iterator pour obtenir le nombre (vous devrez parcourir tous les enregistrements que vous avez:

sum(1 for _ in tf.python_io.tf_record_iterator(file_name))

Si vous voulez connaître l'époque actuelle, vous pouvez le faire à partir du tensorboard ou en imprimant le numéro à partir de la boucle.

17
Salvador Dali

Selon l'avertissement de dépréciation du tf_record_iterator , nous pouvons également utiliser une exécution rapide pour compter les enregistrements.

#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import sys

assert len(sys.argv) == 2, \
    "USAGE: {} <file_glob>".format(sys.argv[0])

tf.enable_eager_execution()

input_pattern = sys.argv[1]

# This is where we get the count of records
records_n = sum(1 for record in tf.data.TFRecordDataset(tf.gfile.Glob(input_pattern)))

print("records_n = {}".format(records_n))
2
Russell