Je lis des exemples de codes dans Tensorflow, j'ai trouvé le code suivant
_flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size. '
'Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
'for unit testing.')
_
en _tensorflow/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py
_
Mais je ne trouve aucune documentation sur cette utilisation de _tf.app.flags
_.
Et j’ai trouvé que la mise en oeuvre de ces drapeaux se trouve dans tensorflow/tensorflow/python/platform/default/_flags.py
Évidemment, ce _tf.app.flags
_ est en quelque sorte utilisé pour configurer un réseau, alors pourquoi n'est-il pas dans la documentation de l'API? Quelqu'un peut-il expliquer ce qui se passe ici?
Le module _tf.app.flags
_ est actuellement un wrapper mince autour de python-gflags, donc la documentation de ce projet est la meilleure ressource pour son utilisation argparse
, qui implémente un sous-ensemble de la fonctionnalité dans python-gflags
.
Notez que ce module est actuellement fourni pour faciliter l'écriture d'applications de démonstration et ne fait pas techniquement partie de l'API publique. Il peut donc être modifié à l'avenir.
Nous vous recommandons de mettre en œuvre votre propre analyse des indicateurs à l'aide de argparse
ou de la bibliothèque de votre choix.
EDIT: Le module _tf.app.flags
_ n'est pas implémenté à l'aide de _python-gflags
_, mais il utilise une API similaire.
Le module tf.app.flags
est une fonctionnalité fournie par Tensorflow pour implémenter des indicateurs de ligne de commande pour votre programme Tensorflow. Par exemple, le code que vous avez rencontré ferait ce qui suit:
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
Le premier paramètre définit le nom de l'indicateur, tandis que le second définit la valeur par défaut si l'indicateur n'est pas spécifié lors de l'exécution du fichier.
Donc, si vous exécutez ce qui suit:
$ python fully_connected_feed.py --learning_rate 1.00
alors le taux d'apprentissage est fixé à 1,00 et restera à 0,01 si le drapeau n'est pas spécifié.
Comme mentionné dans cet article , les documents ne sont probablement pas présents, car Google pourrait en exiger quelque chose en interne pour que ses développeurs puissent les utiliser.
De plus, comme mentionné dans l'article, l'utilisation des indicateurs Tensorflow par rapport aux fonctionnalités d'indicateur fournies par d'autres packages Python tels que argparse
présente de nombreux avantages, en particulier pour les modèles Tensorflow, le plus important étant que vous peut fournir au code des informations spécifiques à Tensorflow, telles que des informations sur le GPU à utiliser.
Chez Google, ils utilisent des systèmes d'indicateurs pour définir les valeurs par défaut des arguments. C'est semblable à argparse. Ils utilisent leur propre système de drapeau au lieu d’argparse ou de sys.argv.
Source: j'y travaillais auparavant.
Lorsque vous utilisez tf.app.run()
, vous pouvez facilement transférer la variable entre les threads à l'aide de tf.app.flags
. Voir this pour plus d'informations sur tf.app.flags
.
Après avoir essayé plusieurs fois, j’ai trouvé ceci pour imprimer toutes les clés FLAGS ainsi que la valeur réelle -
for key in tf.app.flags.FLAGS.flag_values_dict():
print(key, FLAGS[key].value)