Keras 2.x removed a number of useful metrics that I need, so I copied the functions from the old metrics.py file into my code and passed them in as follows.
def precision(y_true, y_pred):  # taken from old Keras source code
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def recall(y_true, y_pred):  # taken from old Keras source code
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

...
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy', precision, recall])
and this results in
ValueError: Unknown metric function:precision
What am I doing wrong? As far as I can tell from the Keras documentation, nothing here is incorrect.
Edit:
Here is the full traceback:
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Python/2.7/site-packages/keras/models.py", line 274, in load_model
    sample_weight_mode=sample_weight_mode)
  File "/Library/Python/2.7/site-packages/keras/models.py", line 824, in compile
    **kwargs)
  File "/Library/Python/2.7/site-packages/keras/engine/training.py", line 934, in compile
    handle_metrics(output_metrics)
  File "/Library/Python/2.7/site-packages/keras/engine/training.py", line 901, in handle_metrics
    metric_fn = metrics_module.get(metric)
  File "/Library/Python/2.7/site-packages/keras/metrics.py", line 75, in get
    return deserialize(str(identifier))
  File "/Library/Python/2.7/site-packages/keras/metrics.py", line 67, in deserialize
    printable_module_name='metric function')
  File "/Library/Python/2.7/site-packages/keras/utils/generic_utils.py", line 164, in deserialize_keras_object
    ':' + function_name)
ValueError: Unknown metric function:precision
<FATAL> : Failed to load Keras model from file: model.h5
***> abort program execution
Traceback (most recent call last):
  File "classification.py", line 84, in <module>
    'H:!V:FilenameModel=model.h5:NumEpochs=20:BatchSize=32') #:VarTransform=D,G
TypeError: none of the 3 overloaded methods succeeded. Full details:
  TMVA::MethodBase* TMVA::Factory::BookMethod(TMVA::DataLoader* loader, TString theMethodName, TString methodTitle, TString theOption = "") =>
    could not convert argument 2
  TMVA::MethodBase* TMVA::Factory::BookMethod(TMVA::DataLoader* loader, TMVA::Types::EMVA theMethod, TString methodTitle, TString theOption = "") =>
    FATAL error (C++ exception of type runtime_error)
  TMVA::MethodBase* TMVA::Factory::BookMethod(TMVA::DataLoader*, TMVA::Types::EMVA, TString, TString, TMVA::Types::EMVA, TString) =>
    takes at least 6 arguments (4 given)
From the traceback, it looks like the problem occurs when you try to load the saved model:
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Python/2.7/site-packages/keras/models.py", line 274, in load_model
    sample_weight_mode=sample_weight_mode)
...
ValueError: Unknown metric function:precision
<FATAL> : Failed to load Keras model from file: model.h5
Take a look at this issue: https://github.com/keras-team/keras/issues/10104
You need to pass your custom objects when loading the model. For example:
dependencies = {
    'auc_roc': auc_roc
}
model = keras.models.load_model(self.output_directory + 'best_model.hdf5', custom_objects=dependencies)
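To see why `custom_objects` helps, it can be useful to picture the lookup that happens at load time: the saved file stores only the metric's name, and the loader has to map that name back to a function. Below is a toy sketch of that idea in plain Python. It is not Keras's actual implementation; `BUILTIN_METRICS`, `resolve_metric` and the stand-in `precision` are hypothetical names made up for illustration.

```python
# Toy model of name-based metric lookup (NOT Keras's real code).
# All names here are hypothetical and exist only for illustration.

BUILTIN_METRICS = {
    "accuracy": lambda y_true, y_pred: 0.0,  # stand-in for a built-in metric
}


def resolve_metric(name, custom_objects=None):
    """Map a stored metric name back to a callable."""
    custom_objects = custom_objects or {}
    if name in custom_objects:
        return custom_objects[name]
    if name in BUILTIN_METRICS:
        return BUILTIN_METRICS[name]
    raise ValueError("Unknown metric function:" + name)


def precision(y_true, y_pred):
    """Stand-in for the user-defined metric from the question."""
    return 1.0


# Without custom objects, the stored name cannot be resolved.
try:
    resolve_metric("precision")
    failed = False
except ValueError as err:
    failed = True
    message = str(err)

# With the mapping supplied, the same name resolves to our function.
resolved = resolve_metric("precision", custom_objects={"precision": precision})
```

Applied to the question, and assuming you control the `load_model` call yourself, this suggests passing both functions when loading, along the lines of `keras.models.load_model('model.h5', custom_objects={'precision': precision, 'recall': recall})`.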
My suggestion would be to implement your metrics in a Keras callback, because it can achieve the same thing as metrics and it can also give you a strategy for saving the model.
class Checkpoint(keras.callbacks.Callback):

    def __init__(self, test_data, filename):
        super(Checkpoint, self).__init__()
        self.test_data = test_data
        self.filename = filename

    def on_train_begin(self, logs=None):
        self.pre = [0.]
        self.rec = [0.]
        print('Test on %s begins' % self.filename)

    def on_train_end(self, logs=None):
        print('Best Precision: %s' % max(self.pre))
        print('Best Recall: %s' % max(self.rec))

    def on_epoch_end(self, epoch, logs=None):
        x, y = self.test_data
        # Assumes precision/recall here accept NumPy arrays and return floats
        # (unlike the backend-tensor versions in the question).
        y_pred = self.model.predict(x)
        pre = precision(y, y_pred)
        rec = recall(y, y_pred)
        # print your precision or recall as you want
        print(...)
        # Save the model when a better-trained one is found
        if pre > max(self.pre):
            self.model.save(self.filename, overwrite=True)
            print('Higher precision found. Save as %s' % self.filename)
        self.pre.append(pre)
        self.rec.append(rec)
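After that, you can pass your callback when training (callbacks go to fit, not compile):
checkpoint = Checkpoint((x_test, y_test), 'precison.h5')
model.compile(loss='categorical_crossentropy', optimizer='adam')
# x_train, y_train are placeholder names for your training data
model.fit(x_train, y_train, callbacks=[checkpoint])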
I tested your code with Python 3.6.5, TensorFlow==1.9 and Keras==2.2.2, and it worked. I think the error could be due to using Python 2.
import numpy as np
import tensorflow as tf
import keras
import keras.backend as K
from keras.layers import Dense
from keras.models import Sequential, Input, Model
from sklearn import datasets
print(f"TF version: {tf.__version__}, Keras version: {keras.__version__}\n")
# dummy dataset
iris = datasets.load_iris()
x, y_ = iris.data, iris.target
def one_hot(v): return np.eye(len(np.unique(v)))[v]
y = one_hot(y_)
# model
inp = Input(shape=(4,))
dense = Dense(8, activation='relu')(inp)
dense = Dense(16, activation='relu')(dense)
dense = Dense(3, activation='softmax')(dense)
model = Model(inputs=inp, outputs=dense)
# custom metrics
def precision(y_true, y_pred):  # taken from old Keras source code
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def recall(y_true, y_pred):  # taken from old Keras source code
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall
# training
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy', precision, recall])
model.fit(x=x, y=y, batch_size=8, epochs=15)
Output:
TF version: 1.9.0, Keras version: 2.2.2
Epoch 1/15
150/150 [==============================] - 0s 2ms/step - loss: 1.2098 - acc: 0.2600 - precision: 0.0000e+00 - recall: 0.0000e+00
Epoch 2/15
150/150 [==============================] - 0s 135us/step - loss: 1.1036 - acc: 0.4267 - precision: 0.0000e+00 - recall: 0.0000e+00
Epoch 3/15
150/150 [==============================] - 0s 132us/step - loss: 1.0391 - acc: 0.5733 - precision: 0.0000e+00 - recall: 0.0000e+00
Epoch 4/15
150/150 [==============================] - 0s 133us/step - loss: 0.9924 - acc: 0.6533 - precision: 0.0000e+00 - recall: 0.0000e+00
Epoch 5/15
150/150 [==============================] - 0s 108us/step - loss: 0.9379 - acc: 0.6667 - precision: 0.0000e+00 - recall: 0.0000e+00
Epoch 6/15
150/150 [==============================] - 0s 134us/step - loss: 0.8802 - acc: 0.6667 - precision: 0.0533 - recall: 0.0067
Epoch 7/15
150/150 [==============================] - 0s 167us/step - loss: 0.8297 - acc: 0.7867 - precision: 0.4133 - recall: 0.0800
Epoch 8/15
150/150 [==============================] - 0s 138us/step - loss: 0.7743 - acc: 0.8200 - precision: 0.9467 - recall: 0.3667
Epoch 9/15
150/150 [==============================] - 0s 161us/step - loss: 0.7232 - acc: 0.7467 - precision: 1.0000 - recall: 0.5667
Epoch 10/15
150/150 [==============================] - 0s 134us/step - loss: 0.6751 - acc: 0.8000 - precision: 0.9733 - recall: 0.6333
Epoch 11/15
150/150 [==============================] - 0s 134us/step - loss: 0.6310 - acc: 0.8867 - precision: 0.9924 - recall: 0.6400
Epoch 12/15
150/150 [==============================] - 0s 131us/step - loss: 0.5844 - acc: 0.8867 - precision: 0.9759 - recall: 0.6600
Epoch 13/15
150/150 [==============================] - 0s 111us/step - loss: 0.5511 - acc: 0.9133 - precision: 0.9759 - recall: 0.6533
Epoch 14/15
150/150 [==============================] - 0s 134us/step - loss: 0.5176 - acc: 0.9000 - precision: 0.9403 - recall: 0.6733
Epoch 15/15
150/150 [==============================] - 0s 134us/step - loss: 0.4899 - acc: 0.8667 - precision: 0.8877 - recall: 0.6733
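The zeros for precision and recall in the first epochs are consistent with how these two metrics round their inputs: a softmax output only counts as a predicted positive once it rises above 0.5. The plain-Python sketch below mirrors the same arithmetic on two hand-made samples. The helper names and the sample values are made up for illustration, and `EPSILON` assumes the Keras default `K.epsilon()` of 1e-7.

```python
# Plain-Python mirror of the round/clip arithmetic used by the metrics above.
# Helper names and sample values are hypothetical; EPSILON assumes Keras's
# default K.epsilon() of 1e-7.

EPSILON = 1e-7


def _count(values):
    """Sum of values after clipping to [0, 1] and rounding to 0 or 1."""
    return sum(round(min(max(v, 0.0), 1.0)) for v in values)


def precision_recall(y_true, y_pred):
    """y_true, y_pred: lists of equal-length rows (one-hot / probabilities)."""
    flat_true = [t for row in y_true for t in row]
    flat_pred = [p for row in y_pred for p in row]
    true_pos = _count(t * p for t, p in zip(flat_true, flat_pred))
    pred_pos = _count(flat_pred)
    poss_pos = _count(flat_true)
    return true_pos / (pred_pos + EPSILON), true_pos / (poss_pos + EPSILON)


y_true = [[1, 0, 0], [0, 1, 0]]

# Early in training: no class probability exceeds 0.5, so nothing rounds to 1.
unsure = [[0.4, 0.3, 0.3], [0.3, 0.4, 0.3]]
p_unsure, r_unsure = precision_recall(y_true, unsure)

# Later: the first sample is predicted confidently, the second still is not.
mixed = [[0.8, 0.1, 0.1], [0.3, 0.4, 0.3]]
p_mixed, r_mixed = precision_recall(y_true, mixed)
```

In the first case both values come out as zero; in the second, precision is close to 1 while recall is close to 0.5. That pattern resembles epochs 8 and 9 in the log above, where precision is already high while recall lags behind.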