web-dev-qa-db-fra.com

Comment utiliser Model.fit qui prend en charge les générateurs (après la dépréciation de fit_generator)

J'ai reçu cet avertissement de dépréciation lors de l'utilisation de Model.fit_generator dans tensorflow:

WARNING:tensorflow: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.

Comment puis-je utiliser Model.fit au lieu de Model.fit_generator?

4
Behnam.B

Model.fit_generator est déconseillé à partir de tensorflow 2.1.0 qui se trouve actuellement dans rc1 . Vous pouvez trouver la documentation de tf-2.1.0-rc1 ici: https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/Model#fit

Comme vous pouvez voir le premier argument du Model.fit peut prendre un générateur alors passez-le simplement.

2
Wathek LOUED

Comme mentionné dans la documentation de tensorflow:

x: données d'entrée.

  1. Cela peut être: un tableau Numpy (ou semblable à un tableau), ou une liste de tableaux (dans le cas où le modèle a plusieurs entrées).
    1. Un tenseur TensorFlow ou une liste de tenseurs (dans le cas où le modèle a plusieurs entrées).
    2. Un dict mappant les noms des entrées au tableau/tenseurs correspondant, si le modèle a nommé des entrées.
    3. Un ensemble de données tf.data. Doit renvoyer un tuple de (entrées, cibles) ou (entrées, cibles, sample_weights)
    4. Un générateur ou keras.utils.Sequence retournant (entrées, cibles) ou (entrées, cibles, poids d'échantillon). Une description plus détaillée du comportement de décompression pour les types d'itérateurs (Dataset, générateur, séquence) est donnée ci-dessous.

vous pouvez simplement passer le générateur à Model.fit comme similaire à Model.fit_generator

data_gen_train = ImageDataGenerator(rescale=1/255.)

data_gen_valid = ImageDataGenerator(rescale=1/255.)

train_generator = data_gen_train.flow_from_directory(train_dir, target_size=(128,128), batch_size=128, class_mode="binary")

valid_generator = data_gen_valid.flow_from_directory(validation_dir, target_size=(128,128), batch_size=128, class_mode="binary")

model.fit(train_generator, epochs=2, validation_data=valid_generator).

0
Anurag Mishra