web-dev-qa-db-fra.com

Comment charger uniquement des poids spécifiques sur Keras

J'ai un modèle formé que j'ai exporté les poids et que je souhaite charger partiellement dans un autre modèle. Mon modèle est construit en Keras en utilisant TensorFlow comme backend.

En ce moment, je fais comme suit:

model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=input_shape, trainable=False))
model.add(Activation('relu', trainable=False))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3), trainable=False))
model.add(Activation('relu', trainable=False))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3), trainable=True))
model.add(Activation('relu', trainable=True))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])


model.load_weights("image_500.h5")
model.pop()
model.pop()
model.pop()
model.pop()
model.pop()
model.pop()


model.add(Conv2D(1, (6, 6),strides=(1, 1), trainable=True))
model.add(Activation('relu', trainable=True))

model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

Je suis sûr que c'est une façon terrible de le faire, même si cela fonctionne.

Comment charger uniquement les 9 premières couches?

16
BernardoGO

Si vos 9 premiers calques sont nommés de manière cohérente entre votre modèle formé d'origine et le nouveau modèle, vous pouvez utiliser model.load_weights() avec by_name=True. Cela mettra à jour les poids uniquement dans les couches de votre nouveau modèle qui ont une couche de même nom trouvée dans le modèle formé d'origine.

Le nom du calque peut être spécifié avec le mot clé name, par exemple:

model.add(Dense(8, activation='relu',name='dens_1'))
25
dhinckley

Cet appel:

weights_list = model.get_weights()

renverra une liste de tous les tenseurs de poids du modèle, sous forme de tableaux Numpy.

Tout ce que vous avez à faire ensuite est de parcourir cette liste et d'appliquer:

for i, weights in enumerate(weights_list[0:9]):
    model.layers[i].set_weights(weights)

model.layers est une liste aplatie des couches composant le modèle. Dans ce cas, vous rechargez les poids des 9 premières couches.

Plus d'informations sont disponibles ici:

https://keras.io/layers/about-keras-layers/

https://keras.io/models/about-keras-models/

18
Philippe Remy