web-dev-qa-db-fra.com

L'optimiseur d'Adam se détraque après 200 000 lots, la perte d'entraînement augmente

J'ai vu un comportement très étrange lors de la formation d'un réseau, où après quelques 100 000 itérations (8 à 10 heures) d'apprentissage correct, tout se brise et la perte de formation augmente :

Loss explodes

Les données de formation elles-mêmes sont randomisées et réparties sur de nombreux .tfrecord fichiers contenant 1000 exemples chacun, puis mélangés à nouveau dans l'étape de saisie et regroupés dans 200 exemples.

L'arrière-plan

Je conçois un réseau qui effectue quatre tâches de régression différentes en même temps, par ex. déterminer la probabilité qu'un objet apparaisse dans l'image et déterminer simultanément son orientation. Le réseau commence avec quelques couches convolutives, certaines avec des connexions résiduelles, puis se ramifie en quatre segments entièrement connectés.

Étant donné que la première régression entraîne une probabilité, j'utilise l'entropie croisée pour la perte, tandis que les autres utilisent la distance L2 classique. Cependant, en raison de leur nature, la perte de probabilité est de l'ordre de 0..1, alors que les pertes d'orientation peuvent être beaucoup plus importantes, disons 0..10. J'ai déjà normalisé les valeurs d'entrée et de sortie et j'utilise l'écrêtage

normalized = tf.clip_by_average_norm(inferred.sin_cos, clip_norm=2.)

dans les cas où les choses peuvent devenir vraiment mauvaises.

J'ai (avec succès) utilisé l'optimiseur Adam pour optimiser le tenseur contenant toutes les pertes distinctes (plutôt que reduce_suming them), comme ceci:

reg_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
loss = tf.pack([loss_probability, sin_cos_mse, magnitude_mse, pos_mse, reg_loss])

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                   epsilon=self.params.adam_epsilon)
op_minimize = optimizer.minimize(loss, global_step=global_step)

Afin d'afficher les résultats dans TensorBoard, je fais ensuite

loss_sum = tf.reduce_sum(loss)

pour un résumé scalaire.

Adam est réglé sur le taux d'apprentissage 1e-4 et epsilon 1e-4 (Je vois le même comportement avec la valeur par défaut pour epislon et ça casse encore plus vite quand je garde le taux d'apprentissage sur 1e-3). La régularisation n'a pas non plus d'influence sur celui-ci, elle fait ce genre de cohérence à un moment donné.

Je dois également ajouter que l'arrêt de la formation et le redémarrage à partir du dernier point de contrôle - ce qui implique que les fichiers d'entrée de la formation sont également mélangés à nouveau - entraînent le même comportement. La formation semble toujours se comporter de la même manière à ce stade.

27
sunside

Oui. C'est un problème connu d'Adam.

Les équations pour Adam sont

t <- t + 1
lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)

m_t <- beta1 * m_{t-1} + (1 - beta1) * g
v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g
variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)

m est une moyenne mobile exponentielle du gradient moyen et v est une moyenne mobile exponentielle des carrés des gradients. Le problème est que lorsque vous vous entraînez depuis longtemps et que vous êtes proche de l'optimal, v peut devenir très petit. Si alors tout d'un coup les gradients recommencent à augmenter, ils seront divisés par un très petit nombre et exploseront.

Par défaut beta1=0.9 et beta2=0.999. Ainsi, m change beaucoup plus rapidement que v. Ainsi, m peut recommencer à être grand tandis que v est encore petit et ne peut pas rattraper son retard.

Pour remédier à ce problème, vous pouvez augmenter epsilon qui est 10-8 par défaut. Ainsi, vous arrêtez le problème de diviser presque par 0. Selon votre réseau, une valeur de epsilon dans 0.1, 0.01, ou 0.001 pourrait être bon.

42
patapouf_ai

Oui, cela pourrait être une sorte de cas de nombres/équations instables super compliqués, mais la plupart des certitudes, votre taux d'entraînement est tout simplement trop élevé car votre perte diminue rapidement jusqu'à 25K et oscille ensuite beaucoup au même niveau. Essayez de la diminuer d'un facteur 0,1 et voyez ce qui se passe. Vous devriez pouvoir atteindre une valeur de perte encore plus faible.

Continuez à explorer! :)

3
Simanas