web-dev-qa-db-fra.com

Prédire dans l'estimateur Tensorflow en utilisant l'entrée fn

J'utilise le code du tutoriel de https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py et le code fonctionne bien jusqu'à ce que j'essaie de faire une prédiction au lieu de simplement l'évaluer. J'ai essayé de créer une autre fonction de prédiction qui ressemble à ceci (en supprimant simplement le paramètre y):

def input_fn_predict(data_file, num_epochs, shuffle):
  """Input builder function."""
  df_data = pd.read_csv(
      tf.gfile.Open(data_file),
      names=CSV_COLUMNS,
      skipinitialspace=True,
      engine="python",
      skiprows=1)
  # remove NaN elements
  df_data = df_data.dropna(how="any", axis=0)
  labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
  return tf.estimator.inputs.pandas_input_fn( #removed paramter y
      x=df_data,
      batch_size=100,
      num_epochs=num_epochs,
      shuffle=shuffle,
      num_threads=5)

Et pour l'appeler comme ça:

predictions = m.predict(
      input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
  )
  for i, p in enumerate(predictions):
      print(i, p)
  • Suis-je en train de bien faire?
  • Pourquoi est-ce que j'obtiens la prédiction 81404 au lieu de 16282 (nombre de lignes dans le fichier de test)?
  • Chaque ligne contient quelque chose comme ceci:

{'probabilités': tableau ([0.78595656, 0.21404342], dtype = float32), 'logits': array ([- 1.3007226], dtype = float32), 'classes': array (['0'], dtype = object) , 'class_ids': tableau ([0]), 'logistic': tableau ([0.21404341], dtype = float32)}

Comment lire ça?

6
Gregorius Edwadr

Vous devez définir shuffle=False puisque pour prévoir une nouvelle étiquette, vous devez maintenir l'ordre des données.

Voici mon code pour exécuter la prédiction (je l'ai testée). Le fichier d'entrée est comme les données de test (en csv), mais il n'y a pas de colonne d'étiquette.



    def predict_input_fn(data_file):
        global CSV_COLUMNS
        CSV_COLUMNS = CSV_COLUMNS[:-1]
        df_data = pd.read_csv(
            tf.gfile.Open(data_file),
            names=CSV_COLUMNS,
            skipinitialspace=True,
            engine='python',
            skiprows=1
        )

        # remove NaN elements
        df_data = df_data.dropna(how='any', axis=0)

        return tf.estimator.inputs.pandas_input_fn(
            x=df_data,
            num_epochs=1,
           shuffle=False
        )

Pour l'appeler:



    predict_file_name = 'tutorials/data/adult.predict'
    results = m.predict(
        input_fn=predict_input_fn(predict_file_name)
    )
    for result in results:
        print 'result: {}'.format(result)

Le résultat de prédiction pour un échantillon est ci-dessous:



    {
        'probabilities': array([0.78595656, 0.21404342], dtype = float32),
        'logits': array([-1.3007226], dtype = float32),
        'classes': array(['0'], dtype = object),
        'class_ids': array([0]),
        'logistic': array([0.21404341], dtype = float32)
    }

Que signifie chaque champ

  • 'probabilités': tableau ([0.78595656, 0.21404342], dtype = float32).
    Il prédit que l'étiquette de sortie est de classe 0 (dans ce cas <= 50K) avec confiance 0,78595656
  • 'logits': tableau ([- 1.3007226], dtype = float32)
    La valeur de z dans l'équation 1/(1 + e ^ (- z)) est -1,3.
  • 'classes': tableau (['0'], dtype = objet)
    Le libellé de la classe est 0
16
impulse