J'apprends actuellement TensorFlow, mais je découvre une confusion dans ce code:
dataset = dataset.shuffle(buffer_size = 10 * batch_size)
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
je sais d'abord que l'ensemble de données contiendra toutes les données, mais que shuffle (), repeat () et batch () font-ils sur l'ensemble de données? donnez-moi s'il vous plaît une explication avec un exemple
Imaginez, vous avez un jeu de données: [1, 2, 3, 4, 5, 6]
, puis:
Comment fonctionne ds.shuffle ()
dataset.shuffle(buffer_size=3)
allouera un tampon de taille 3 pour la sélection d'entrées aléatoires. Ce tampon sera connecté au jeu de données source . Nous pourrions l’image comme ceci:
Random buffer
|
| Source dataset where all other elements live
| |
↓ ↓
[1,2,3] <= [4,5,6]
Supposons que l'entrée 2
a été prise dans le tampon aléatoire. L'espace libre est rempli par l'élément suivant du tampon source, à savoir 4
:
2 <= [1,3,4] <= [5,6]
Nous continuons à lire jusqu'à ce qu'il ne reste plus rien:
1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
Comment fonctionne ds.repeat ()
Dès que toutes les entrées sont lues à partir du jeu de données et que vous essayez de lire l'élément suivant, le jeu de données génère une erreur . C'est là que ds.repeat()
entre en jeu. Il va réinitialiser le jeu de données, le rendant à nouveau comme ceci:
[1,2,3] <= [4,5,6]
Qu'est-ce que ds.batch () va produire
La ds.batch()
prendra les premières entrées batch_size
et en fera un lot. Ainsi, une taille de lot de 3 pour notre exemple de données produira deux enregistrements de lot:
[2,1,5]
[3,6,4]
Comme nous avons un ds.repeat()
avant le lot, la génération des données se poursuivra. Mais l'ordre des éléments sera différent, en raison de la ds.random()
. Ce qui doit être pris en compte est que 6
ne sera jamais présent dans le premier lot, en raison de la taille du tampon aléatoire.
Les méthodes suivantes dans tf.Dataset:
repeat( count=0 )
La méthode répète le jeu de données count
nombre de fois.shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)
La méthode mélange les échantillons dans le jeu de données. Le buffer_size
correspond au nombre d'échantillons randomisés et renvoyés sous la forme tf.Dataset
.batch(batch_size,drop_remainder=False)
Crée des lots du jeu de données avec la taille de lot donnée sous la forme batch_size
, qui correspond également à la longueur des lots.