web-dev-qa-db-fra.com

Poids de copie Keras Tensorflow d'un modèle à un autre

À l'aide de Keras de Tensorflow 1.4.1, comment copier des poids d'un modèle à un autre?

Comme arrière-plan, j'essaie d'implémenter un réseau deep-q (DQN) pour les jeux Atari suite à la publication DQN par DeepMind. Ma compréhension est que la mise en œuvre utilise deux réseaux, Q et Q '. Les poids de Q sont entraînés en utilisant la descente de gradient, puis les poids sont copiés périodiquement dans Q '.

Voici comment je construis Q et Q ':

ACT_SIZE   = 4
LEARN_RATE = 0.0025
OBS_SIZE   = 128

def buildModel():
  model = tf.keras.models.Sequential()

  model.add(tf.keras.layers.Lambda(lambda x: x / 255.0, input_shape=OBS_SIZE))
  model.add(tf.keras.layers.Dense(128, activation="relu"))
  model.add(tf.keras.layers.Dense(128, activation="relu"))
  model.add(tf.keras.layers.Dense(ACT_SIZE, activation="linear"))
  opt = tf.keras.optimizers.RMSprop(lr=LEARN_RATE)

  model.compile(loss="mean_squared_error", optimizer=opt)

  return model

J'appelle ça deux fois pour obtenir Q et Q '.

J'ai une méthode updateTargetModel ci-dessous qui est ma tentative de copier des poids. Le code fonctionne bien, mais mon implémentation DQN globale échoue. J'essaie vraiment de vérifier s'il s'agit d'un moyen valide de copier des poids d'un réseau à un autre.

def updateTargetModel(model, targetModel):
  modelWeights       = model.trainable_weights
  targetModelWeights = targetModel.trainable_weights

  for i in range(len(targetModelWeights)):
    targetModelWeights[i].assign(modelWeights[i])

Il y a une autre question ici qui traite de l'enregistrement et du chargement des poids vers et depuis le disque ( problème de poids de copie Tensorflow ), mais il n'y a pas de réponse acceptée. Il y a aussi une question sur le chargement des poids à partir de couches individuelles ( Copie des poids d'une couche Conv2D à une autre ), mais je veux copier les poids du modèle entier.

9
avejidah

En fait, ce que vous avez fait est bien plus que la simple copie de poids. Vous avez fait ces deux modèles identiques tout le temps. Chaque fois que vous mettez à jour un modèle - le second est également mis à jour - car les deux modèles ont les mêmes variables weights.

Si vous voulez simplement copier des poids - la manière la plus simple est par cette commande:

target_model.set_weights(model.get_weights()) 
24
Marcin Możejko