J'utilise le code du tutoriel de https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py et le code fonctionne bien jusqu'à ce que j'essaie de faire une prédiction au lieu de simplement l'évaluer. J'ai essayé de créer une autre fonction de prédiction qui ressemble à ceci (en supprimant simplement le paramètre y):
def input_fn_predict(data_file, num_epochs, shuffle):
"""Input builder function."""
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine="python",
skiprows=1)
# remove NaN elements
df_data = df_data.dropna(how="any", axis=0)
labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
return tf.estimator.inputs.pandas_input_fn( #removed paramter y
x=df_data,
batch_size=100,
num_epochs=num_epochs,
shuffle=shuffle,
num_threads=5)
Et pour l'appeler comme ça:
predictions = m.predict(
input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
)
for i, p in enumerate(predictions):
print(i, p)
{'probabilités': tableau ([0.78595656, 0.21404342], dtype = float32), 'logits': array ([- 1.3007226], dtype = float32), 'classes': array (['0'], dtype = object) , 'class_ids': tableau ([0]), 'logistic': tableau ([0.21404341], dtype = float32)}
Comment lire ça?
Vous devez définir shuffle=False
puisque pour prévoir une nouvelle étiquette, vous devez maintenir l'ordre des données.
Voici mon code pour exécuter la prédiction (je l'ai testée). Le fichier d'entrée est comme les données de test (en csv), mais il n'y a pas de colonne d'étiquette.
def predict_input_fn(data_file):
global CSV_COLUMNS
CSV_COLUMNS = CSV_COLUMNS[:-1]
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine='python',
skiprows=1
)
# remove NaN elements
df_data = df_data.dropna(how='any', axis=0)
return tf.estimator.inputs.pandas_input_fn(
x=df_data,
num_epochs=1,
shuffle=False
)
Pour l'appeler:
predict_file_name = 'tutorials/data/adult.predict'
results = m.predict(
input_fn=predict_input_fn(predict_file_name)
)
for result in results:
print 'result: {}'.format(result)
Le résultat de prédiction pour un échantillon est ci-dessous:
{
'probabilities': array([0.78595656, 0.21404342], dtype = float32),
'logits': array([-1.3007226], dtype = float32),
'classes': array(['0'], dtype = object),
'class_ids': array([0]),
'logistic': array([0.21404341], dtype = float32)
}
Que signifie chaque champ