web-dev-qa-db-fra.com

Comment corriger: "google.protobuf.message.decodeerror: erreur d'analyse d'erreur" lors de la création de résumé du texte TENSORFLOW

J'essaie d'exécuter un script pour obtenir un résumé du texte d'un modèle Tensorflow .pb telle que ceci:

    OPS counts:
    Squeeze : 1
    Softmax : 1
    BiasAdd : 1
    Placeholder : 1
    AvgPool : 1
    Reshape : 2
    ConcatV2 : 9
    MaxPool : 13
    Sub : 57
    Rsqrt : 57
    Relu : 57
    Conv2D : 58
    Add : 114
    Mul : 114
    Identity : 231
    Const : 298

J'essaie globalement de convertir un modèle .pb vers un .coremlmodel et je suis à la suite de cet article:

https://hakernoon.com/intgrating-tensorflow-model-in-an-ios-app-cecf30b9068d

Obtenir un résumé du texte du modèle .pb est un pas en avant vers cela. Le code que j'essaie de courir pour créer le résumé du texte suit:

import tensorflow as tf
from tensorflow.core.framework import graph_pb2
import time
import operator
import sys

def inspect(model_pb, output_txt_file):
    graph_def = graph_pb2.GraphDef()
    with open(model_pb, "rb") as f:
    graph_def.ParseFromString(f.read())

    tf.import_graph_def(graph_def)

    sess = tf.Session()
    OPS = sess.graph.get_operations()

    ops_dict = {}

    sys.stdout = open(output_txt_file, 'w')
    for i, op in enumerate(OPS):
            print('---------------------------------------------------------------------------------------------------------------------------------------------')
        print("{}: op name = {}, op type = ( {} ), inputs = {}, outputs = {}".format(i, op.name, op.type, ", ".join([x.name for x in op.inputs]), ", ".join([x.name for x in op.outputs])))
        print('@input shapes:')
        for x in op.inputs:
            print("name = {} : {}".format(x.name, x.get_shape()))
        print('@output shapes:')
        for x in op.outputs:
            print("name = {} : {}".format(x.name, x.get_shape()))
        if op.type in ops_dict:
            ops_dict[op.type] += 1
        else:
            ops_dict[op.type] = 1

    print('---------------------------------------------------------------------------------------------------------------------------------------------')
sorted_ops_count = sorted(ops_dict.items(),     key=operator.itemgetter(1))
    print('OPS counts:')
    for i in sorted_ops_count:
        print("{} : {}".format(i[0], i[1]))

if __name__ == "__main__":
"""
Write a summary of the frozen TF graph to a text file.
Summary includes op name, type, input and output names and shapes. 

Arguments
----------
- path to the frozen .pb graph
- path to the output .txt file where the summary is written

Usage
----------
python inspect_pb.py frozen.pb text_file.txt

"""
if len(sys.argv) != 3:
    raise ValueError("Script expects two arguments. " +
          "Usage: python inspect_pb.py /path/to/the/frozen.pb /path/to/the/output/text/file.txt")
inspect(sys.argv[1], sys.argv[2])

J'ai couru cette commande:

 python inspect_pb.py /Users/nikhil.c/Desktop/tensorflowModel.pb   text_summary.txt

Mais au lieu de recevoir la sortie attendue, je reçois ce message d'erreur:

    Traceback (most recent call last):
      File "inspect_pb.py", line 58, in <module>
        inspect(sys.argv[1], sys.argv[2])
      File "inspect_pb.py", line 10, in inspect
        graph_def.ParseFromString(f.read())
    google.protobuf.message.DecodeError: Error parsing message

et ne sait pas vraiment où commencer. D'autres questions similaires qui semblent recevoir le même message d'erreur n'ont pas trop de sens. Que fais-je?

3
Nikhil Chandra

Tu peux essayer:

with gfile.GFile(frozen_graph, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read(-1)) 
    # the -1 returns the contents of the file as a string

Je l'ai vu dans la documentation: https://www.tensorflow.org/api_docs/python/tf/io/gfile/gfile#Read

0
Pavlos Sakoglou