web-dev-qa-db-fra.com

Que font les traitements par lots, les répétitions et les mélanges avec le jeu de données TensorFlow?

J'apprends actuellement TensorFlow, mais je découvre une confusion dans ce code:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

je sais d'abord que l'ensemble de données contiendra toutes les données, mais que shuffle (), repeat () et batch () font-ils sur l'ensemble de données? donnez-moi s'il vous plaît une explication avec un exemple

4
blue

Imaginez, vous avez un jeu de données: [1, 2, 3, 4, 5, 6], puis:

Comment fonctionne ds.shuffle ()

dataset.shuffle(buffer_size=3) allouera un tampon de taille 3 pour la sélection d'entrées aléatoires. Ce tampon sera connecté au jeu de données source . Nous pourrions l’image comme ceci:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

Supposons que l'entrée 2 a été prise dans le tampon aléatoire. L'espace libre est rempli par l'élément suivant du tampon source, à savoir 4:

2 <= [1,3,4] <= [5,6]

Nous continuons à lire jusqu'à ce qu'il ne reste plus rien:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]      <= []
4 <= []      <= []

Comment fonctionne ds.repeat ()

Dès que toutes les entrées sont lues à partir du jeu de données et que vous essayez de lire l'élément suivant, le jeu de données génère une erreur . C'est là que ds.repeat() entre en jeu. Il va réinitialiser le jeu de données, le rendant à nouveau comme ceci:

[1,2,3] <= [4,5,6]

Qu'est-ce que ds.batch () va produire

La ds.batch() prendra les premières entrées batch_size et en fera un lot. Ainsi, une taille de lot de 3 pour notre exemple de données produira deux enregistrements de lot:

[2,1,5]
[3,6,4]

Comme nous avons un ds.repeat() avant le lot, la génération des données se poursuivra. Mais l'ordre des éléments sera différent, en raison de la ds.random(). Ce qui doit être pris en compte est que 6 ne sera jamais présent dans le premier lot, en raison de la taille du tampon aléatoire.

2
Vlad-HC

Les méthodes suivantes dans tf.Dataset:

  1. repeat( count=0 ) La méthode répète le jeu de données count nombre de fois.
  2. shuffle( buffer_size, seed=None, reshuffle_each_iteration=None) La méthode mélange les échantillons dans le jeu de données. Le buffer_size correspond au nombre d'échantillons randomisés et renvoyés sous la forme tf.Dataset.
  3. batch(batch_size,drop_remainder=False) Crée des lots du jeu de données avec la taille de lot donnée sous la forme batch_size, qui correspond également à la longueur des lots.
0
Shubham Panchal