web-dev-qa-db-fra.com

Sauver meilleur modèle à keras

J'utilise le code suivant lors de la formation d'un modèle dans keras

from keras.callbacks import EarlyStopping

model = Sequential()
model.add(Dense(100, activation='relu', input_shape = input_shape))
model.add(Dense(1))

model_2.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])


model.fit(X, y, epochs=15, validation_split=0.4, callbacks=[early_stopping_monitor], verbose=False)

model.predict(X_test)

mais récemment, je souhaitais enregistrer le modèle le mieux entraîné, car les données sur lesquelles je m'entraîne donnent de nombreux pics dans le graphique "high val_loss vs epochs", et je souhaite utiliser le meilleur modèle possible à partir du modèle.

Y at-il une méthode ou une fonction pour aider avec ça?

17
dJOKER_dUMMY

EarlyStopping et ModelCheckpoint est ce dont vous avez besoin dans la documentation de Keras.

Vous devriez définir save_best_only=True dans ModelCheckpoint. Si d'autres ajustements sont nécessaires, ils sont triviaux.

Juste pour vous aider davantage, vous pouvez voir un usage ici sur Kaggle .


Ajouter le code ici au cas où le lien exemple Kaggle ci-dessus ne serait pas disponible:

model = getModel()
model.summary()

batch_size = 32

earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
mcp_save = ModelCheckpoint('.mdl_wts.hdf5', save_best_only=True, monitor='val_loss', mode='min')
reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=7, verbose=1, epsilon=1e-4, mode='min')

model.fit(Xtr_more, Ytr_more, batch_size=batch_size, epochs=50, verbose=0, callbacks=[earlyStopping, mcp_save, reduce_lr_loss], validation_split=0.25)
25
Shridhar R Kulkarni

J'imagine model_2.compile était une faute de frappe. Cela devrait aider si vous voulez enregistrer le meilleur modèle avec les val_losses -

checkpoint = ModelCheckpoint('model-{Epoch:03d}-{acc:03f}-{val_acc:03f}.h5', verbose=1, monitor='val_loss',save_best_only=True, mode='auto')  

model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])

model.fit(X, y, epochs=15, validation_split=0.4, callbacks=[checkpoint], verbose=False)
5
Vivek