web-dev-qa-db-fra.com

AVERTISSEMENT: tensorflow: les modes sample_weight ont été contraints de ... à ['...']

Formation d'un classificateur d'images à l'aide de .fit_generator() ou .fit() et transmission d'un dictionnaire à class_weight= comme argument.

Je n'ai jamais eu d'erreurs dans TF1.x mais en 2.1 j'obtiens la sortie suivante lors du démarrage de la formation:

WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']

Que signifie contraindre quelque chose à partir de ... à ['...']?

La source de cet avertissement sur le référentiel de tensorflow est ici , les commentaires placés sont:

Essayez de contraindre sample_weight_modes à la structure cible. Cela dépend implicitement du fait que le modèle aplatit les sorties pour sa représentation interne.

44
jorijnsmit

Cela ressemble à un faux message. Je reçois le même message d'avertissement après la mise à niveau vers TensorFlow 2.1, mais je n'utilise aucun poids de classe ou exemple de poids. J'utilise un générateur qui renvoie un tuple comme celui-ci:

return inputs, targets

Et maintenant, je viens de le changer comme suit pour faire disparaître l'avertissement:

return inputs, targets, [None]

Je ne sais pas si cela est pertinent, mais mon modèle utilise 3 entrées, donc ma variable inputs est en fait une liste de 3 tableaux numpy. targets n'est qu'un simple tableau numpy.

En tout cas, ce n'est qu'un avertissement. La formation fonctionne bien dans les deux cas.

Modifier pour TensorFlow 2.2:

Ce bogue semble avoir été corrigé dans TensorFlow 2.2, ce qui est génial. Cependant, le correctif ci-dessus échouera dans TF 2.2, car il essaiera d'obtenir la forme des poids d'échantillon, ce qui échouera évidemment avec AttributeError: 'NoneType' object has no attribute 'shape'. Annulez donc le correctif ci-dessus lors de la mise à niveau vers 2.2.

6
jlh

Je pense que c'est un bogue avec tensorflow qui se produit lorsque vous appelez model.compile() avec le paramètre par défaut sample_weight_mode=None Puis appelez model.fit() avec sample_weight Ou class_weight.

Depuis les dépôts tensorflow:

  • fit() appelle éventuellement _process_training_inputs()
  • _process_training_inputs()définitsample_weight_modes = [None] basé sur model.sample_weight_mode = None puis crée un DataAdapter avec sample_weight_modes = [None]
  • DataAdapter appelle broadcast_sample_weight_modes() avec sample_weight_modes = [None] pendant initialisation
  • broadcast_sample_weight_modes()semble attendresample_weight_modes = None mais reçoit [None]
  • il affirme que [None] est une structure différente de sample_weight/class_weight, la remplace par None en s'adaptant à la structure de sample_weight/class_weight Et affiche un avertissement

Attention, cela n'a aucun effet sur fit() car sample_weight_modes Dans DataAdapter est redéfini sur None.

Notez que tensorflow documentation indique que sample_weight Doit être un tableau numpy. Si vous appelez fit() avec sample_weight.tolist() à la place, vous n'obtiendrez pas d'avertissement mais sample_weight Est silencieusement remplacé par None lorsque _process_numpy_inputs() est appelé dans prétraitement et reçoit une entrée de longueur supérieure à un.

10
Max

J'ai pris votre Gist et installé Tensorflow 2.0, au lieu de TFA et cela a fonctionné sans un tel avertissement.

Voici le Gist du code complet. Le code d'installation de Tensorflow est indiqué ci-dessous:

!pip install tensorflow==2.0

Une capture d'écran de l'exécution réussie est présentée ci-dessous:

enter image description here

Mise à jour: Ce bogue est corrigé dans Tensorflow Version 2.2.

4
Tensorflow Support