web-dev-qa-db-fra.com

Comment comprendre tf.get_collection () dans TensorFlow

Je suis confus par tf.get_collection() forme le docs , il dit que

Renvoie une liste de valeurs dans la collection avec le nom donné.

Et un exemple sur Internet est ici

from_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, from_scope)

Est-ce que cela signifie qu'il collecte des variables à partir de tf.GraphKeys.TRAINABLE_VARIABLES à from_scope?

Cependant, comment puis-je utiliser cette fonction si je veux obtenir des variables d'une autre étendue? Je vous remercie!

10
GoingMyWay

Une collection n'est rien d'autre qu'un ensemble nommé de valeurs.

Chaque valeur est un nœud du graphe de calcul.

Chaque nœud a son nom et le nom est composé par la concaténation des étendues, / et des valeurs telles que: preceding/scopes/in/that/way/value

get_collection, sans scope permet de récupérer toutes les valeurs de la collection sans appliquer d'opération de filtrage.

Lorsque le paramètre scope est présent, chaque élément de la collection est filtré et renvoyé uniquement si le nom du nœud commence par le scope spécifié.

7
nessuno

Comme décrit dans le doc de chaîne:

  • TRAINABLE_VARIABLES: le sous-ensemble d'objets Variable qui sera formé par un optimiseur.

et

portée: (Facultatif.) Une chaîne. Si elle est fournie, la liste résultante est filtrée pour inclure uniquement les éléments dont l'attribut name correspond à scope à l'aide de re.match. Les éléments sans attribut name ne sont jamais renvoyés si une étendue est fournie. Le choix de re.match signifie qu'un scope sans jetons spéciaux filtre par préfixe.

Il retournera donc la liste des variables pouvant être entraînées dans la portée donnée.

1
pfm