web-dev-qa-db-fra.com

Comment modifier dynamiquement la taille du lot dans l'ensemble de données Tensorflow 2.0?

Dans TensorFlow 1.X, vous pouvez modifier la taille du lot de manière dynamique à l'aide d'un espace réservé. par exemple

dataset.batch(batch_size=tf.placeholder())
Voir l'exemple complet

Comment faites-vous cela dans TensorFlow 2.0?

J'ai essayé ce qui suit mais cela ne fonctionne pas.

import numpy as np
import tensorflow as tf


def new_gen_function():
    for i in range(100):
        yield np.ones(2).astype(np.float32)


batch_size = tf.Variable(5, trainable=False, dtype=tf.int64)
train_ds = tf.data.Dataset.from_generator(new_gen_function, output_types=(tf.float32)).batch(
    batch_size=batch_size)

for data in train_ds:
    print(data.shape[0])
    batch_size.assign(10)
    print(batch_size)

Production

5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
...
...

J'entraîne un modèle à l'aide d'une boucle d'entraînement personnalisée à l'aide de ruban dégradé. Comment puis-je atteindre cet objectif?

5
Himaprasoon

Évidemment, si vous utilisez .from_generator, vous pouvez regrouper manuellement les éléments, mais cela ne répond pas vraiment à votre question.

Les deux façons les plus simples auxquelles je peux penser sont d'inclure la taille du lot en tant que composant de l'ensemble de données, puis de créer des lots de la taille demandée:

import tensorflow as tf

batch_sizes = tf.data.Dataset.range(4)
ds = batch_sizes.map(lambda n: tf.random.normal(shape=[n,3]))

for item in ds:
  print(item.shape)
  print()
(0,3)
(1,3)
(2,3)
(3,3)

Ou, en vous appuyant sur la solution de @ PG-N, si vous avez besoin d'une version qui s'exécute totalement à l'intérieur d'un tf.function vous pouvez les emballer en utilisant tf.TensorArray:

import tensorflow as tf 

 class Batcher(tf.Module):
   def __init__(self, ds, batch_size=0):   
     self.it = iter(ds) 
     self._batch_size = tf.Variable(batch_size) 

   @property 
   def batch_size(self): 
     return self._batch_size

   @batch_size.setter  
   def batch_size(self, new_size): 
     self._batch_size.assign(new_size) 

   @tf.function 
   def __call__(self): 
     examples =tf.TensorArray(dtype=tf.int64, size=self.batch_size) 
     for i in tf.range(self.batch_size): 
       examples = examples.write(i, next(self.it)) 

     return examples.stack() 
ds = tf.data.Dataset.range(100)
B = Batcher(ds)
B.batch_size = 5
B().numpy()
array([0, 1, 2, 3, 4])
B.batch_size = 10
B().numpy()
array([ 5,  6,  7,  8,  9, 10, 11, 12, 13, 14])
B.batch_size = 3
B().numpy()
array([ 15,  16,  17])

Vous pourriez probablement faire quelque chose au milieu en utilisant tf.nest pour rendre cela général aux ensembles de données avec plus qu'un simple composant tensoriel.

En outre, en fonction de votre cas d'utilisation, des méthodes telles que group_by_window et bucket_by_sequence_length pourrait être utile. Ceux-ci font des lots de plusieurs tailles, ils pourraient être ce que vous recherchez, ou la mise en œuvre pourrait être un indice de votre problème.

1
mdaoust

D'après ce que je sais, vous devez instancier un nouvel itérateur de jeu de données pour que votre modification prenne effet. Cela nécessitera de modifier un peu pour sauter les échantillons déjà vus.

Voici ma solution la plus simple:

import numpy as np
import tensorflow as tf

def get_dataset(batch_size, num_samples_seen):
    return tf.data.Dataset.range(
        100
    ).skip(
        num_samples_seen
    ).batch(
        batch_size=batch_size
    )

def main():
    batch_size = 1
    num_samples_seen = 0

    train_ds = get_dataset(batch_size, num_samples_seen)

    ds_iterator = iter(train_ds)
    while True:
        try:
            data = next(ds_iterator)
        except StopIteration:
            print("End of iteration")
            break

        print(data)
        batch_size *= 2
        num_samples_seen += data.shape[0]
        ds_iterator = iter(get_dataset(batch_size, num_samples_seen))
        print("New batch size:", batch_size)

if __name__ == "__main__":
    main()

Comme vous pouvez le voir ici, vous devez instancier un nouvel ensemble de données (via un appel à get_dataset) et mettez à jour l'itérateur.

Je ne connais pas l'impact sur les performances d'une telle solution. Il existe peut-être une autre solution nécessitant d'instancier "simplement" une étape batch au lieu de l'ensemble de données.

1
AlexisBRENON