web-dev-qa-db-fra.com

TensorFlow - Lire tous les exemples d'un TFRecords à la fois?

Comment lisez-vous tous les exemples d'un enregistrement TFR à la fois?

J'ai utilisé tf.parse_single_example Pour lire des exemples individuels en utilisant un code similaire à celui donné dans la méthode read_and_decode Dans exemple de Fully_connected_reader . Toutefois, je souhaite exécuter le réseau simultanément sur l'intégralité de mon ensemble de données de validation. J'aimerais donc les charger entièrement.

Je ne suis pas tout à fait sûr, mais la documentation semble suggérer que je peux utiliser tf.parse_example Au lieu de tf.parse_single_example Pour charger l'intégralité du fichier TFRecords en même temps. Je n'arrive pas à obtenir que cela fonctionne. J'imagine que cela dépend de la façon dont je spécifie les fonctionnalités, mais je ne sais pas comment, dans la spécification de la fonctionnalité, indiquer qu'il existe plusieurs exemples.

En d'autres termes, ma tentative d'utiliser quelque chose de similaire à:

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_example(serialized_example, features={
    'image_raw': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
})

ne fonctionne pas, et je suppose que c'est parce que les fonctionnalités n'attendent pas plusieurs exemples à la fois (mais encore une fois, je ne suis pas sûr). [Ceci entraîne une erreur de ValueError: Shape () must have rank 1]

Est-ce la bonne façon de lire tous les enregistrements à la fois? Et si oui, que dois-je changer pour pouvoir réellement lire les enregistrements? Merci beaucoup!

31
golmschenk

Juste pour la clarté, j'ai quelques milliers d’images dans un seul fichier .tfrecords, ce sont des fichiers png de 720 sur 720. Les étiquettes sont l'une des 0,1,2,3.

J'ai également essayé d'utiliser parse_example et je n'ai pas réussi à le faire fonctionner, mais cette solution fonctionne avec parse_single_example.

L’inconvénient est qu’à présent, je dois savoir combien d’éléments sont présents dans chaque enregistrement .tf, ce qui est plutôt décevant. Si je trouve un meilleur moyen, je mettrai à jour la réponse. Veillez également à ne pas dépasser le nombre d'enregistrements contenus dans le fichier .tfrecords. Il recommencera au premier enregistrement si vous passez au-dessus du dernier.

L'astuce consistait à faire utiliser le coordinateur par le coureur de la file d'attente.

J'ai laissé du code ici pour enregistrer les images au fur et à mesure de leur lecture afin que vous puissiez vérifier que l'image est correcte.

from PIL import Image
import numpy as np
import tensorflow as tf

def read_and_decode(filename_queue):
 reader = tf.TFRecordReader()
 _, serialized_example = reader.read(filename_queue)
 features = tf.parse_single_example(
  serialized_example,
  # Defaults are not specified since both keys are required.
  features={
      'image_raw': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
      'height': tf.FixedLenFeature([], tf.int64),
      'width': tf.FixedLenFeature([], tf.int64),
      'depth': tf.FixedLenFeature([], tf.int64)
  })
 image = tf.decode_raw(features['image_raw'], tf.uint8)
 label = tf.cast(features['label'], tf.int32)
 height = tf.cast(features['height'], tf.int32)
 width = tf.cast(features['width'], tf.int32)
 depth = tf.cast(features['depth'], tf.int32)
 return image, label, height, width, depth


def get_all_records(FILE):
 with tf.Session() as sess:
   filename_queue = tf.train.string_input_producer([ FILE ])
   image, label, height, width, depth = read_and_decode(filename_queue)
   image = tf.reshape(image, tf.pack([height, width, 3]))
   image.set_shape([720,720,3])
   init_op = tf.initialize_all_variables()
   sess.run(init_op)
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(coord=coord)
   for i in range(2053):
     example, l = sess.run([image, label])
     img = Image.fromarray(example, 'RGB')
     img.save( "output/" + str(i) + '-train.png')

     print (example,l)
   coord.request_stop()
   coord.join(threads)

get_all_records('/path/to/train-0.tfrecords')
20
Andrew Pierno

Pour lire toutes les données une seule fois, vous devez passer num_epochs au string_input_producer. Lorsque tous les enregistrements sont lus, le .read méthode de lecteur va jeter une erreur, que vous pouvez attraper. Exemple simplifié:

import tensorflow as tf

def read_and_decode(filename_queue):
 reader = tf.TFRecordReader()
 _, serialized_example = reader.read(filename_queue)
 features = tf.parse_single_example(
  serialized_example,
  features={
      'image_raw': tf.FixedLenFeature([], tf.string)
  })
 image = tf.decode_raw(features['image_raw'], tf.uint8)
 return image


def get_all_records(FILE):
 with tf.Session() as sess:
   filename_queue = tf.train.string_input_producer([FILE], num_epochs=1)
   image = read_and_decode(filename_queue)
   init_op = tf.initialize_all_variables()
   sess.run(init_op)
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(coord=coord)
   try:
     while True:
       example = sess.run([image])
   except tf.errors.OutOfRangeError, e:
     coord.request_stop(e)
   finally:
     coord.request_stop()
     coord.join(threads)

get_all_records('/path/to/train-0.tfrecords')

Et pour utiliser tf.parse_example (qui est plus rapide que tf.parse_single_example) vous devez d’abord grouper les exemples comme celui-ci:

batch = tf.train.batch([serialized_example], num_examples, capacity=num_examples)
parsed_examples = tf.parse_example(batch, feature_spec)

Malheureusement, de cette façon, vous auriez besoin de connaître le nombre d'exemples à l'avance.

12
sygi

Si vous avez besoin de lire toutes les données de TFRecord en même temps, vous pouvez écrire une solution plus simple en quelques lignes de code en utilisant tf_record_iterator :

Un itérateur qui lit les enregistrements d'un fichier TFRecords.

Pour ce faire, il vous suffit de:

  1. créer un exemple
  2. itérer sur les enregistrements de l'itérateur
  3. analyser chaque enregistrement et lire chaque fonctionnalité en fonction de son type

Voici un exemple expliquant comment lire chaque type.

example = tf.train.Example()
for record in tf.python_io.tf_record_iterator(<tfrecord_file>):
    example.ParseFromString(record)
    f = example.features.feature
    v1 = f['int64 feature'].int64_list.value[0]
    v2 = f['float feature'].float_list.value[0]
    v3 = f['bytes feature'].bytes_list.value[0]
    # for bytes you might want to represent them in a different way (based on what they were before saving)
    # something like `np.fromstring(f['img'].bytes_list.value[0], dtype=np.uint8
    # Now do something with your v1/v2/v3
11
Salvador Dali

Vous pouvez également utiliser tf.python_io.tf_record_iterator Pour itérer manuellement tous les exemples dans un TFRecord.

Je teste cela avec un code d'illustration ci-dessous:

import tensorflow as tf

X = [[1, 2],
     [3, 4],
     [5, 6]]


def _int_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def dump_tfrecord(data, out_file):
    writer = tf.python_io.TFRecordWriter(out_file)
    for x in data:
        example = tf.train.Example(
            features=tf.train.Features(feature={
                'x': _int_feature(x)
            })
        )
        writer.write(example.SerializeToString())
    writer.close()


def load_tfrecord(file_name):
    features = {'x': tf.FixedLenFeature([2], tf.int64)}
    data = []
    for s_example in tf.python_io.tf_record_iterator(file_name):
        example = tf.parse_single_example(s_example, features=features)
        data.append(tf.expand_dims(example['x'], 0))
    return tf.concat(0, data)


if __name__ == "__main__":
    dump_tfrecord(X, 'test_tfrecord')
    print('dump ok')
    data = load_tfrecord('test_tfrecord')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        Y = sess.run([data])
        print(Y)

Bien sûr, vous devez utiliser votre propre spécification feature.

L'inconvénient est que je ne sais pas utiliser les multi-threads de cette façon. Cependant, la plupart du temps, nous lisons tous les exemples lorsque nous évaluons un ensemble de données de validation, qui n’est généralement pas très volumineux. Donc, je pense que l'efficacité peut ne pas être un goulot d'étranglement.

Et j'ai un autre problème lorsque je teste ce problème, c'est que je dois spécifier la longueur de la fonctionnalité. Au lieu de tf.FixedLenFeature([], tf.int64), je dois écrire tf.FixedLenFeature([2], tf.int64), sinon un InvalidArgumentError se produirait. Je ne sais pas comment éviter ça.

Python: 3.4
Tensorflow: 0.12.0

8
Shen Fei

Je ne sais pas si c'est toujours un sujet actif. J'aimerais partager la meilleure pratique que je connaisse jusqu'à présent, mais c'est une question il y a un an.

Dans le tensorflow, nous avons une méthode très utile pour résoudre le problème de ce type: lire ou parcourir l'ensemble des données d'entrée et générer une formation permettant de tester l'ensemble de données de manière aléatoire. 'tf.train.shuffle_batch' peut générer la base de données sur le flux d'entrée (comme reader.read ()) lorsque vous vous comportez. Comme par exemple, vous pouvez générer un ensemble de données set 1000 en fournissant une liste d'arguments semblable à celle-ci:

reader = tf.TFRecordReader()
_, serialized = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized,
    features={
        'label': tf.FixedLenFeature([], tf.string),
        'image': tf.FixedLenFeature([], tf.string)
    }
)
record_image = tf.decode_raw(features['image'], tf.uint8)

image = tf.reshape(record_image, [500, 500, 1])
label = tf.cast(features['label'], tf.string)
min_after_dequeue = 10
batch_size = 1000
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch(
    [image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue
)
3
user2538491

En outre, si vous ne pensez pas que "tf.train.shuffle_batch" est la solution dont vous avez besoin. Vous pouvez également essayer de combiner tf.TFRecordReader (). Read_up_to () et tf.parse_example (). Voici l'exemple pour votre référence:

def read_tfrecords(folder_name, bs):
    filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once(glob.glob(folder_name + "/*.tfrecords")))
    reader = tf.TFRecordReader()
    _, serialized = reader.read_up_to(filename_queue, bs)
    features = tf.parse_example(
        serialized,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'image': tf.FixedLenFeature([], tf.string)
        }
    )
    record_image = tf.decode_raw(features['image'], tf.uint8)
    image = tf.reshape(record_image, [-1, 250, 151, 1])
    label = tf.cast(features['label'], tf.string)
    return image, label
1
user2538491