def train():
# Model
model = Model()
# Loss, Optimizer
global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step')
loss_fn = model.loss()
optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step)
# Summaries
summary_op = summaries(model, loss_fn)
with tf.Session(config=TrainConfig.session_conf) as sess:
# Initialized, Load state
sess.run(tf.global_variables_initializer())
model.load_state(sess, TrainConfig.CKPT_PATH)
writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)
# Input source
data = Data(TrainConfig.DATA_PATH)
loss = Diff()
for step in xrange(global_step.eval(), TrainConfig.FINAL_STEP):
mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step)
mixed_spec = to_spectrogram(mixed_wav)
mixed_mag = get_magnitude(mixed_spec)
src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)
src1_batch, _ = model.spec_to_batch(src1_mag)
src2_batch, _ = model.spec_to_batch(src2_mag)
mixed_batch, _ = model.spec_to_batch(mixed_mag)
# Initializae our callback.
#early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5)
l, _, summary = sess.run([loss_fn, optimizer, summary_op],
feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch,
model.y_src2: src2_batch})
loss.update(l)
print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value))
writer.add_summary(summary, global_step=step)
# Save state
if step % TrainConfig.CKPT_STEP == 0:
tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step)
writer.close()
J'ai ce code de réseau de neurones qui sépare la musique d'une voix dans un fichier .wav . Comment puis-je introduire un algorithme d'arrêt précoce pour arrêter la section de train? Je vois un projet qui parle d'un ValidationMonitor. Est-ce que quelqu'un peut m'aider?
ValidationMonitor est marqué comme obsolète. ce n'est pas recommandé. mais vous pouvez toujours l’utiliser ... Voici un exemple de création:
validation_monitor = monitors.ValidationMonitor(
input_fn=functools.partial(input_fn, subset="evaluation"),
eval_steps=128,
every_n_steps=88,
early_stopping_metric="accuracy",
early_stopping_rounds = 1000
)
et vous pouvez mettre en œuvre par vous-même, voici ma ma mise en œuvre:
if (loss_value < self.best_loss):
self.stopping_step = 0
self.best_loss = loss_value
else:
self.stopping_step += 1
if self.stopping_step >= FLAGS.early_stopping_step:
self.should_stop = True
print("Early stopping is trigger at step: {} loss:{}".format(global_step,loss_value))
run_context.request_stop()