Autant que je sache, Variable
est l'opération par défaut pour la création d'une variable, et get_variable
est principalement utilisé pour le partage de poids.
D'un côté, certaines personnes suggèrent d'utiliser get_variable
au lieu de l'opération primitive Variable
chaque fois que vous avez besoin d'une variable. D'autre part, je ne vois qu'une utilisation de get_variable
dans les documents officiels et les démos de TensorFlow.
Je souhaite donc connaître quelques règles de base sur la manière d’utiliser correctement ces deux mécanismes. Existe-t-il des principes "standard"?
Je recommanderais de toujours utiliser tf.get_variable(...)
- il sera plus facile de refactoriser votre code si vous devez partager des variables à tout moment, par exemple dans un paramètre multi-gpu (voir l'exemple CIFAR multi-gpu). Il n'y a pas d'inconvénient à cela.
tf.Variable
pur est de niveau inférieur; tf.get_variable()
n’existait pas à un moment donné, de sorte que certains codes utilisent toujours la méthode de bas niveau.
tf.Variable est une classe et il existe plusieurs façons de créer tf.Variable, y compris tf.Variable .__ init__ et tf.get_variable.
tf.Variable .__ init__: crée une nouvelle variable avec initial_value.
W = tf.Variable(<initial-value>, name=<optional-name>)
tf.get_variable: obtient une variable existante avec ces paramètres ou en crée une nouvelle. Vous pouvez également utiliser l'initialiseur.
W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
regularizer=None, trainable=True, collections=None)
Il est très utile d'utiliser des initialiseurs tels que xavier_initializer:
W = tf.get_variable("W", shape=[784, 256],
initializer=tf.contrib.layers.xavier_initializer())
Plus d'informations sur https://www.tensorflow.org/versions/r0.8/api_docs/python/state_ops.html#Variable .
Je peux trouver deux différences principales entre l'un et l'autre:
La première est que tf.Variable
créera toujours une nouvelle variable, si tf.get_variable
obtient du graphe une variable existante avec ces paramètres, et si elle n'existe pas, elle en crée une nouvelle.
tf.Variable
nécessite qu'une valeur initiale soit spécifiée.
Il est important de préciser que la fonction tf.get_variable
préfixe le nom avec la portée de la variable actuelle pour effectuer des contrôles de réutilisation. Par exemple:
with tf.variable_scope("one"):
a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
c = tf.get_variable("v", [1]) #c.name == "one/v:0"
with tf.variable_scope("two"):
d = tf.get_variable("v", [1]) #d.name == "two/v:0"
e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"
assert(a is c) #Assertion is true, they refer to the same object.
assert(a is d) #AssertionError: they are different objects
assert(d is e) #AssertionError: they are different objects
La dernière erreur d'assertion est intéressante: deux variables du même nom sous la même portée sont supposées être la même variable. Mais si vous testez les noms des variables d
et e
, vous réaliserez que Tensorflow a modifié le nom de la variable e
:
d.name #d.name == "two/v:0"
e.name #e.name == "two/v_1:0"
Une autre différence réside dans le fait que l’un est dans la collection ('variable_store',)
mais que l’autre ne l’est pas.
S'il vous plaît voir la source code :
def _get_default_variable_store():
store = ops.get_collection(_VARSTORE_KEY)
if store:
return store[0]
store = _VariableStore()
ops.add_to_collection(_VARSTORE_KEY, store)
return store
Laissez-moi illustrer cela:
import tensorflow as tf
from tensorflow.python.framework import ops
embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="Word_embeddings_1", dtype=tf.float32)
embedding_2 = tf.get_variable("Word_embeddings_2", shape=[30522, 1024])
graph = tf.get_default_graph()
collections = graph.collections
for c in collections:
stores = ops.get_collection(c)
print('collection %s: ' % str(c))
for k, store in enumerate(stores):
try:
print('\t%d: %s' % (k, str(store._vars)))
except:
print('\t%d: %s' % (k, str(store)))
print('')
Le résultat:
collection ('__variable_store',): 0: {'Word_embeddings_2': <tf.Variable 'Word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}