D'après ce que j'ai recueilli jusqu'à présent, il existe différentes manières de transférer un graphique TensorFlow dans un fichier, puis de le charger dans un autre programme, mais je n'ai pas été en mesure de trouver des exemples/informations clairs sur leur fonctionnement. Ce que je sais déjà, c'est ceci:
tf.train.Saver()
et restaurez-les plus tard ( source )tf.train.write_graph()
et tf.import_graph_def()
( source ).as_graph_def()
pour enregistrer le modèle et pour les pondérations/variables, mappez-les en constantes ( source )Cependant, je n'ai pas pu éclaircir plusieurs questions concernant ces différentes méthodes:
tf.train.write_graph()
, les poids/variables sont-ils également enregistrés?tf.import_graph_def()
?as_graph_def()
/. Ckpt/.pb?En bref, ce que je recherche, c’est une méthode pour enregistrer à la fois un graphique (comme dans, les différentes opérations et autres) et ses pondérations/variables dans un fichier, qui peut ensuite être utilisé pour charger le graphique et les pondérations dans un autre programme. , pour utilisation (pas nécessairement formation continue/recyclage).
La documentation sur ce sujet n'étant pas très simple, toute réponse ou information serait grandement appréciée.
Il existe de nombreuses façons d’aborder le problème de la sauvegarde d’un modèle dans TensorFlow, ce qui peut le rendre un peu déroutant. Prenant chacune de vos sous-questions à tour de rôle:
Les fichiers de point de contrôle (produits par exemple en appelant saver.save()
sur un tf.train.Saver
objet) ne contient que les poids et toutes les autres variables définies dans le même programme. Pour les utiliser dans un autre programme, vous devez recréer la structure graphique associée (par exemple, en exécutant du code pour le reconstruire, ou en appelant tf.import_graph_def()
), qui indique à TensorFlow quoi faire avec ces poids. Notez que l'appel de saver.save()
génère également un fichier contenant un MetaGraphDef
, qui contient un graphique et des détails sur la manière d'associer les poids d'un point de contrôle à ce graphique. Voir le tutoriel pour plus de détails.
tf.train.write_graph()
écrit uniquement la structure du graphe; pas les poids.
Bazel n'a aucun lien avec la lecture ou l'écriture de graphiques TensorFlow. (Je comprends peut-être mal votre question: n'hésitez pas à la clarifier dans un commentaire.)
Un graphique figé peut être chargé avec tf.import_graph_def()
. Dans ce cas, les poids sont (généralement) incorporés dans le graphique, vous n'avez donc pas besoin de charger un point de contrôle séparé.
Le principal changement consisterait à mettre à jour les noms du ou des tenseurs introduits dans le modèle, ainsi que les noms du ou des tenseurs extraits du modèle. Dans la démo TensorFlow Android, cela correspondrait aux chaînes inputName
et outputName
qui sont transmises à TensorFlowClassifier.initializeTensorFlow()
.
GraphDef
est la structure du programme, qui ne change généralement pas au cours du processus de formation. Le point de contrôle est un instantané de l'état d'un processus de formation, qui change généralement à chaque étape du processus de formation. En conséquence, TensorFlow utilise différents formats de stockage pour ces types de données, et l'API de bas niveau fournit différentes manières de les enregistrer et de les charger. Bibliothèques de niveau supérieur, telles que les bibliothèques MetaGraphDef
, Keras et skflow s'appuie sur ces mécanismes pour fournir des moyens plus pratiques de sauvegarder et de restaurer un modèle entier.
Vous pouvez essayer le code suivant:
with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)