J'ai créé une variable trainable dans une portée. Plus tard, j'ai entré la même portée, défini la portée sur reuse_variables
, et utilisé get_variable
pour récupérer la même variable. Cependant, je ne peux pas définir la propriété trainable de la variable sur False
. Ma get_variable
la ligne est comme:
weight_var = tf.get_variable('weights', trainable = False)
Mais la variable 'weights'
est toujours dans la sortie de tf.trainable_variables
.
Puis-je définir l'indicateur trainable
d'une variable partagée sur False
en utilisant get_variable
?
La raison pour laquelle je veux le faire est que j'essaie de réutiliser les filtres de bas niveau pré-formés à partir de VGG net dans mon modèle, et je veux construire le graphique comme avant, récupérer la variable de poids et attribuer des valeurs de filtre VGG à la variable de poids, puis maintenez-les fixes pendant l'étape de formation suivante.
Après avoir regardé la documentation et le code, j'étais pas en mesure de trouver un moyen de supprimer une variable du TRAINABLE_VARIABLES
.
tf.get_variable('weights', trainable=True)
est appelée, la variable est ajoutée à la liste de TRAINABLE_VARIABLES
.tf.get_variable('weights', trainable=False)
, vous obtenez la même variable mais l'argument trainable=False
N'a aucun effet car la variable est déjà présente dans la liste de TRAINABLE_VARIABLES
(Et il y a aucun moyen de le supprimer de là)Lorsque vous appelez la méthode minimize
de l'optimiseur (voir doc. ), vous pouvez passer un var_list=[...]
Comme argument avec les variables que vous souhaitez optimiser.
Par exemple, si vous souhaitez figer toutes les couches de VGG à l'exception des deux dernières, vous pouvez transmettre les poids des deux dernières couches dans var_list
.
Vous pouvez utiliser une tf.train.Saver()
pour enregistrer les variables et les restaurer plus tard (voir ce tutoriel ).
saver.save(sess, "/path/to/dir/model.ckpt")
.saver.restore(sess, "/path/to/dir/model.ckpt")
.Facultativement, vous pouvez décider de ne sauvegarder que certaines des variables dans votre fichier de point de contrôle. Voir doc pour plus d'informations.
Lorsque vous souhaitez former ou optimiser uniquement certaines couches d'un réseau pré-formé, c'est ce que vous devez savoir.
La méthode minimize
de TensorFlow prend un argument facultatif var_list
, une liste de variables à ajuster par rétropropagation.
Si vous ne spécifiez pas var_list
, toute variable TF du graphique peut être ajustée par l'optimiseur. Lorsque vous spécifiez certaines variables dans var_list
, TF maintient toutes les autres variables constantes.
Voici un exemple de script que jonbruner et son collaborateur ont utilisé.
tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)
Cela recherche toutes les variables définies précédemment qui ont "g_" dans le nom de la variable, les place dans une liste et exécute l'optimiseur ADAM sur elles.
Vous pouvez trouver les réponses correspondantes ici sur Quora
Afin de supprimer une variable de la liste des variables entraînables, vous pouvez d'abord accéder à la collection via: trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
Là, trainable_collection
Contient une référence à la collection de variables entraînables. Si vous pop des éléments de cette liste, en faisant par exemple trainable_collection.pop(0)
, vous allez supprimer la variable correspondante des variables entraînables, et donc cette variable ne sera pas entraînée.
Bien que cela fonctionne avec pop
, j'ai encore du mal à trouver un moyen d'utiliser correctement remove
avec l'argument correct, donc nous ne dépendons pas de l'index des variables.
EDIT: Étant donné que vous avez le nom des variables dans le graphique (vous pouvez l'obtenir en inspectant le protobuf du graphique ou, ce qui est plus facile, en utilisant Tensorboard), vous pouvez l'utiliser pour parcourir la liste des variables entraînables, puis supprimez les variables de la collection entraînable. Exemple: disons que je veux que les variables avec les noms "batch_normalization/gamma:0"
Et "batch_normalization/beta:0"
PAS soient entraînées, mais elles sont déjà ajoutées à la collection TRAINABLE_VARIABLES
. Ce que je peux faire, c'est: `
#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
#uses the attribute 'name' of the variable
if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
variables_to_remove.append(vari)
for rem in variables_to_remove:
trainable_collection.remove(rem)
`Cela supprimera avec succès les deux variables de la collection, et elles ne seront plus entraînées.
Vous pouvez utiliser tf.get_collection_ref pour obtenir la référence de la collection plutôt que tf.get_collection