J'ai essayé d'implémenter soft-max avec le code suivant (out_vec
Est un vecteur numpy
de floats):
numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator
Cependant, j'ai eu une erreur de débordement à cause de np.exp(out_vec)
. Par conséquent, j'ai vérifié (manuellement) quelle est la limite supérieure de np.exp()
et trouve que np.exp(709)
est un nombre, mais np.exp(710)
est considéré comme étant np.inf
. Ainsi, pour éviter l'erreur de débordement, j'ai modifié mon code comme suit:
out_vec[out_vec > 709] = 709 #prevent np.exp overflow
numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator
Maintenant, je reçois une erreur différente:
RuntimeWarning: invalid value encountered in greater out_vec[out_vec > 709] = 709
Quel est le problème avec la ligne que j'ai ajoutée? J'ai recherché cette erreur spécifique et tout ce que j'ai trouvé est le conseil des gens sur la façon de l'ignorer. Ignorer simplement l'erreur ne m'aidera pas, car chaque fois que mon code rencontre cette erreur, il ne donne pas les résultats habituels.
Votre problème est dû aux éléments NaN
ou Inf
de votre out_vec
tableau. Vous pouvez utiliser le code suivant pour éviter ce problème:
if np.isnan(np.sum(out_vec)):
out_vec = out_vec[~numpy.isnan(out_vec)] # just remove nan elements from vector
out_vec[out_vec > 709] = 709
...
ou vous pouvez utiliser le code suivant pour laisser les valeurs NaN
dans votre tableau:
out_vec[ np.array([e > 709 if ~np.isnan(e) else False for e in out_vec], dtype=bool) ] = 709
Dans mon cas, l'avertissement n'apparaissait pas lorsque vous appelez cela avant la comparaison (les valeurs de NaN étaient comparées)
np.warnings.filterwarnings('ignore')
Le meilleur moyen à l’OMI serait d’utiliser une implémentation plus stable numériquement de la somme des exponentielles.
from scipy.misc import logsumexp
out_vec = np.exp(out_vec - logsumexp(out_vec))