web-dev-qa-db-fra.com

Quel est le but de tf.app.flags dans TensorFlow?

Je lis des exemples de codes dans Tensorflow, j'ai trouvé le code suivant

_flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size.  '
                 'Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
                 'for unit testing.')
_

en _tensorflow/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py_

Mais je ne trouve aucune documentation sur cette utilisation de _tf.app.flags_.

Et j’ai trouvé que la mise en oeuvre de ces drapeaux se trouve dans tensorflow/tensorflow/python/platform/default/_flags.py

Évidemment, ce _tf.app.flags_ est en quelque sorte utilisé pour configurer un réseau, alors pourquoi n'est-il pas dans la documentation de l'API? Quelqu'un peut-il expliquer ce qui se passe ici?

101
flyaway1217

Le module _tf.app.flags_ est actuellement un wrapper mince autour de python-gflags, donc la documentation de ce projet est la meilleure ressource pour son utilisation argparse , qui implémente un sous-ensemble de la fonctionnalité dans python-gflags .

Notez que ce module est actuellement fourni pour faciliter l'écriture d'applications de démonstration et ne fait pas techniquement partie de l'API publique. Il peut donc être modifié à l'avenir.

Nous vous recommandons de mettre en œuvre votre propre analyse des indicateurs à l'aide de argparse ou de la bibliothèque de votre choix.

EDIT: Le module _tf.app.flags_ n'est pas implémenté à l'aide de _python-gflags_, mais il utilise une API similaire.

101
mrry

Le module tf.app.flags est une fonctionnalité fournie par Tensorflow pour implémenter des indicateurs de ligne de commande pour votre programme Tensorflow. Par exemple, le code que vous avez rencontré ferait ce qui suit:

flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

Le premier paramètre définit le nom de l'indicateur, tandis que le second définit la valeur par défaut si l'indicateur n'est pas spécifié lors de l'exécution du fichier.

Donc, si vous exécutez ce qui suit:

$ python fully_connected_feed.py --learning_rate 1.00

alors le taux d'apprentissage est fixé à 1,00 et restera à 0,01 si le drapeau n'est pas spécifié.

Comme mentionné dans cet article , les documents ne sont probablement pas présents, car Google pourrait en exiger quelque chose en interne pour que ses développeurs puissent les utiliser.

De plus, comme mentionné dans l'article, l'utilisation des indicateurs Tensorflow par rapport aux fonctionnalités d'indicateur fournies par d'autres packages Python tels que argparse présente de nombreux avantages, en particulier pour les modèles Tensorflow, le plus important étant que vous peut fournir au code des informations spécifiques à Tensorflow, telles que des informations sur le GPU à utiliser.

24
Vedang Waradpande

Chez Google, ils utilisent des systèmes d'indicateurs pour définir les valeurs par défaut des arguments. C'est semblable à argparse. Ils utilisent leur propre système de drapeau au lieu d’argparse ou de sys.argv.

Source: j'y travaillais auparavant.

5
Ahmed

Lorsque vous utilisez tf.app.run(), vous pouvez facilement transférer la variable entre les threads à l'aide de tf.app.flags. Voir this pour plus d'informations sur tf.app.flags.

4
Cro

Après avoir essayé plusieurs fois, j’ai trouvé ceci pour imprimer toutes les clés FLAGS ainsi que la valeur réelle -

for key in tf.app.flags.FLAGS.flag_values_dict():

  print(key, FLAGS[key].value)
1
Nandani