Je veux créer des enregistrements tensorflow pour alimenter mon modèle; Jusqu'à présent, j'utilise le code suivant pour stocker le tableau uint8 numpy au format TFRecord;
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def convert_to_record(name, image, label, map):
filename = os.path.join(params.TRAINING_RECORDS_DATA_DIR, name + '.' + params.DATA_EXT)
writer = tf.python_io.TFRecordWriter(filename)
image_raw = image.tostring()
map_raw = map.tostring()
label_raw = label.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'image_raw': _bytes_feature(image_raw),
'map_raw': _bytes_feature(map_raw),
'label_raw': _bytes_feature(label_raw)
}))
writer.write(example.SerializeToString())
writer.close()
que j'ai lu avec cet exemple de code
features = tf.parse_single_example(example, features={
'image_raw': tf.FixedLenFeature([], tf.string),
'map_raw': tf.FixedLenFeature([], tf.string),
'label_raw': tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape(params.IMAGE_HEIGHT*params.IMAGE_WIDTH*3)
image = tf.reshape(image_, (params.IMAGE_HEIGHT,params.IMAGE_WIDTH,3))
map = tf.decode_raw(features['map_raw'], tf.uint8)
map.set_shape(params.MAP_HEIGHT*params.MAP_WIDTH*params.MAP_DEPTH)
map = tf.reshape(map, (params.MAP_HEIGHT,params.MAP_WIDTH,params.MAP_DEPTH))
label = tf.decode_raw(features['label_raw'], tf.uint8)
label.set_shape(params.NUM_CLASSES)
et ça marche bien. Maintenant, je veux faire la même chose avec mon tableau "map" étant un tableau float numpy, au lieu de uint8, et je ne pouvais pas trouver d’exemples sur la façon de le faire; J’ai essayé la fonction _floats_feature, qui fonctionne si je passe un scalar à cela, mais pas avec des tableaux; avec uint8 la sérialisation peut être faite par la méthode tostring ();
Comment puis-je sérialiser un tableau float numpy et comment le relire?
FloatList
et BytesList
s'attendent à un résultat itératif. Vous devez donc lui transmettre une liste de flotteurs. Supprimez les crochets supplémentaires dans votre _float_feature
, c.-à-d.
def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
numpy_arr = np.ones((3,)).astype(np.float)
example = tf.train.Example(features=tf.train.Features(feature={"bytes": _floats_feature(numpy_arr)}))
print(example)
features {
feature {
key: "bytes"
value {
float_list {
value: 1.0
value: 1.0
value: 1.0
}
}
}
}
Je vais développer la réponse de Yaroslav.
Int64List, BytesList et FloatList attendent un itérateur des éléments sous-jacents (champ répété). Dans votre cas, vous pouvez utiliser une liste en tant qu’itérateur.
Vous avez mentionné: cela fonctionne si je lui passe un scalaire, mais pas avec les tableaux. Et cela est attendu, car lorsque vous passez un scalaire, votre _floats_feature
crée un tableau d'un élément float (exactement comme prévu). Mais lorsque vous passez un tableau, vous créez une liste de tableaux et le transmettez à une fonction qui attend une liste de flottants.
Donc, supprimez simplement la construction du tableau de votre fonction: float_list=tf.train.FloatList(value=value)
L'exemple de Yaroslav a échoué quand un tableau nd était l'entrée:
numpy_arr = np.ones ((3,3)). astype (np.float)
J'ai trouvé que cela fonctionnait lorsque j'ai utilisé numpy_arr.ravel () comme entrée. Mais y a-t-il une meilleure façon de le faire?
Je suis tombé sur cela alors que je travaillais sur un problème similaire. Puisqu'une partie de la question initiale consistait à savoir comment relire la fonctionnalité float32
de tfrecords
, je la laisserai ici au cas où cela aiderait quelqu'un:
Si map.ravel()
a été utilisé pour entrer map
de dimensions [x, y, z]
dans _floats_feature
:
features = {
...
'map': tf.FixedLenFeature([x, y, z], dtype=tf.float32)
...
}
parsed_example = tf.parse_single_example(serialized=serialized, features=features)
map = parsed_example['map']