web-dev-qa-db-fra.com

Keras modèle d'apprentissage en profondeur pour Android

Je développe une application de classification d'objets en temps réel pour Android. J'ai d'abord créé un modèle d'apprentissage approfondi en utilisant "keras" et j'ai déjà un modèle formé enregistré sous le fichier "model.h5". J'aimerais savoir comment utiliser ce modèle sous Android pour la classification des images. 

6
Vajira Prabuddhaka

Vous ne pouvez pas exporter Keras directement sur Android, mais vous devez enregistrer le modèle. 

  • Configurez Tensflow en tant que votre backend Keras.

  • Enregistrez les paramètres du modèle en utilisant model.save(filepath) (vous l'avez déjà fait)

Puis chargez-le avec l'une des solutions suivantes:

Solution 1: Modèle d'importation dans Tensflow

1- Construire le modèle Tensorflow

2- Construire une application Android et appeler tensflow. Cochez cette tutoriel et cette démo officielle de google pour apprendre à le faire.

Solution 2: Modèle d'importation en Java 
1- deeplearning4j une librairie Java permet d’importer le modèle keras: lien tutoriel
2- Utiliser deeplearning4j dans Android: c'est facile puisque vous êtes dans le monde Java. vérifier ce tutoriel

7
Alex

Vous devez d’abord exporter le modèle Keras vers un modèle Tensorflow:

def export_model_for_mobile(model_name, input_node_names, output_node_name):
    tf.train.write_graph(K.get_session().graph_def, 'out', \
        model_name + '_graph.pbtxt')

    tf.train.Saver().save(K.get_session(), 'out/' + model_name + '.chkp')

    freeze_graph.freeze_graph('out/' + model_name + '_graph.pbtxt', None, \
        False, 'out/' + model_name + '.chkp', output_node_name, \
        "save/restore_all", "save/Const:0", \
        'out/frozen_' + model_name + '.pb', True, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open('out/frozen_' + model_name + '.pb', "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def, input_node_names, [output_node_name],
            tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile('out/tensorflow_lite_' + model_name + '.pb', "wb") as f:
        f.write(output_graph_def.SerializeToString())

Vous avez juste besoin de connaître le input_nodes_names et le output_node_names de votre graphique. Cela créera un nouveau dossier avec plusieurs fichiers. Parmi eux, on commence par tensorflow_lite_. C'est le fichier que vous allez déplacer sur votre appareil Android.

Importez ensuite la bibliothèque Tensorflow sur Android et utilisez TensorFlowInferenceInterface pour exécuter votre modèle.

implementation 'org.tensorflow:tensorflow-Android:1.5.0'

Vous pouvez vérifier mon exemple XOR simple sur Github:

https://github.com/OmarAflak/Keras-Android-XOR

1
Omar Aflak