Réglage
Comme déjà mentionné dans le titre, j'ai eu un problème avec ma fonction de perte personnalisée lors de la tentative de chargement du modèle enregistré. Ma perte se présente comme suit:
def weighted_cross_entropy(weights):
weights = K.variable(weights)
def loss(y_true, y_pred):
y_pred = K.clip(y_pred, K.epsilon(), 1-K.epsilon())
loss = y_true * K.log(y_pred) * weights
loss = -K.sum(loss, -1)
return loss
return loss
weighted_loss = weighted_cross_entropy([0.1,0.9])
J'ai donc utilisé le weighted_loss
fonctionne comme fonction de perte et tout fonctionnait bien. Une fois la formation terminée, j'enregistre le modèle sous .h5
fichier avec la norme model.save
fonction de l'API keras.
Problème
Lorsque j'essaie de charger le modèle via
model = load_model(path,custom_objects={"weighted_loss":weighted_loss})
Je reçois un ValueError
me disant que la perte est inconnue.
Erreur
Le message d'erreur se présente comme suit:
File "...\predict.py", line 29, in my_script
"weighted_loss": weighted_loss})
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 419, in load_model
model = _deserialize_model(f, custom_objects, compile)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 312, in _deserialize_model
sample_weight_mode=sample_weight_mode)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\training.py", line 139, in compile
loss_function = losses.get(loss)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 133, in get
return deserialize(identifier)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 114, in deserialize
printable_module_name='loss function')
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\utils\generic_utils.py", line 165, in deserialize_keras_object
':' + function_name)
ValueError: Unknown loss function:loss
Questions
Comment puis-je résoudre ce problème? Est-il possible que la raison en soit ma définition de perte enveloppée? Donc, keras
ne sait pas comment gérer la variable weights
?
Le nom de votre fonction de perte est loss
(c'est-à-dire def loss(y_true, y_pred):
). Par conséquent, lors du chargement du modèle, vous devez spécifier 'loss'
comme son nom:
model = load_model(path, custom_objects={'loss': weighted_loss})