J'entraîne un NN et j'aimerais enregistrer les poids du modèle à toutes les N époques pour une phase de prédiction. Je propose ce projet de code, il s'inspire de la réponse de @grovina ici . Pourriez-vous, s'il vous plaît, faire des suggestions? Merci d'avance.
from keras.callbacks import Callback
class WeightsSaver(Callback):
def __init__(self, model, N):
self.model = model
self.N = N
self.Epoch = 0
def on_batch_end(self, Epoch, logs={}):
if self.Epoch % self.N == 0:
name = 'weights%08d.h5' % self.Epoch
self.model.save_weights(name)
self.Epoch += 1
Ajoutez-le ensuite à l'appel d'ajustement: pour enregistrer des poids toutes les 5 époques:
model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])
Vous ne devriez pas avoir besoin de passer un modèle pour le rappel. Il a déjà accès au modèle via son super. Supprimez donc l'argument __init__(..., model, ...)
et self.model = model
. Vous devriez pouvoir accéder au modèle actuel via self.model
indépendamment. Vous l'enregistrez également à chaque fin de lot, ce qui n'est pas ce que vous voulez, vous voulez probablement que ce soit on_Epoch_end
.
Mais dans tous les cas, ce que vous faites peut se faire via naïf modelcheckpoint callback . Vous n'avez pas besoin d'en écrire un personnalisé. Vous pouvez l'utiliser comme suit;
mc = keras.callbacks.ModelCheckpoint('weights{Epoch:08d}.h5',
save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])
Vous devez implémenter sur on_Epoch_end plutôt implémenter on_batch_end. Et aussi passer le modèle comme argument pour __init__
est redondant.
from keras.callbacks import Callback
class WeightsSaver(Callback):
def __init__(self, N):
self.N = N
self.Epoch = 0
def on_Epoch_end(self, Epoch, logs={}):
if self.Epoch % self.N == 0:
name = 'weights%08d.h5' % self.Epoch
self.model.save_weights(name)
self.Epoch += 1