En quoi train_on_batch () est différent de fit ()? Dans quels cas devrions-nous utiliser train_on_batch ()?
Je crois que vous voulez comparer train_on_batch
avec fit
(et des variations comme fit_generator
), puisque train
n'est pas une fonction API couramment disponible pour Keras.
Pour cette question, il s'agit d'un réponse simple de l'auteur principal :
Avec fit_generator, vous pouvez également utiliser un générateur pour les données de validation. En général, je recommanderais d'utiliser fit_generator, mais utiliser train_on_batch fonctionne aussi très bien. Ces méthodes n'existent que dans un souci de commodité, mais dans des cas d'utilisation différents, il n'y a pas de méthode "correcte".
train_on_batch
vous permet de mettre à jour expressément les poids en fonction d'un ensemble d'échantillons que vous fournissez, sans tenir compte de la taille d'un lot fixe. Vous utiliseriez ceci dans les cas où c'est ce que vous voulez: former sur une collection explicite d'échantillons. Vous pouvez utiliser cette approche pour conserver votre propre itération sur plusieurs lots d’un ensemble d’entraînement traditionnel, mais en autorisant fit
ou fit_generator
itérer des lots pour vous est probablement plus simple.
Un cas où il pourrait être agréable d’utiliser train_on_batch
sert à mettre à jour un modèle pré-formé sur un nouveau lot d'échantillons. Supposons que vous ayez déjà formé et déployé un modèle et qu'un peu plus tard, vous ayez reçu un nouvel ensemble d'échantillons d'apprentissage jamais utilisé auparavant. Vous pouvez utiliser train_on_batch
pour mettre à jour directement le modèle existant uniquement sur ces échantillons. Cela peut aussi être fait par d’autres méthodes, mais il est plutôt explicite d’utiliser train_on_batch
pour ce cas.
En dehors de cas exceptionnels comme celui-ci (que vous ayez une raison pédagogique de conserver votre propre curseur sur différents lots de formation ou que vous utilisiez un type de mise à jour de formation semi-en ligne sur un lot spécial), il est probablement préférable de toujours utiliser fit
(pour les données qui tiennent dans la mémoire) ou fit_generator
(pour la transmission de lots de données en tant que générateur).
train_on_batch()
vous permet de mieux contrôler l'état du LSTM. Par exemple, lorsque vous utilisez un LSTM dynamique et qu'il est nécessaire de contrôler les appels à model.reset_states()
, _. Vous pouvez avoir des données multi-séries et avoir besoin de réinitialiser l’état après chaque série, comme vous pouvez le faire avec train_on_batch()
, mais si vous utilisiez .fit()
, le réseau serait entraîné sur tous les série de données sans réinitialiser l'état. Il n’ya ni bon ni mauvais, cela dépend des données que vous utilisez et de la manière dont vous voulez que le réseau se comporte.
Train_on_batch verra également une augmentation des performances par rapport à Fit and Fit Generator si vous utilisez de grands ensembles de données et que vous n’avez pas de données facilement sérialisables (telles que des tableaux numérotés de haut rang), à écrire dans tfrecords.
Dans ce cas, vous pouvez enregistrer les tableaux sous forme de fichiers numpy et en charger de plus petits sous-ensembles (traina.npy, trainb.npy, etc.) en mémoire, lorsque l'ensemble ne rentre pas dans la mémoire. Vous pouvez ensuite utiliser tf.data.Dataset.from_tensor_slices puis utiliser train_on_batch avec votre sous-jeu de données, puis charger un autre jeu de données et appeler à nouveau train par lot, etc. de votre jeu de données entraîne votre modèle. Vous pouvez ensuite définir vos propres époques, tailles de lot, etc. avec de simples boucles et fonctions à récupérer dans votre jeu de données.