web-dev-qa-db-fra.com

comment implémenter un arrêt précoce dans tensorflow

def train():
# Model
model = Model()

# Loss, Optimizer
global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step')
loss_fn = model.loss()
optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step)

# Summaries
summary_op = summaries(model, loss_fn)

with tf.Session(config=TrainConfig.session_conf) as sess:

    # Initialized, Load state
    sess.run(tf.global_variables_initializer())
    model.load_state(sess, TrainConfig.CKPT_PATH)

    writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)

    # Input source
    data = Data(TrainConfig.DATA_PATH)

    loss = Diff()
    for step in xrange(global_step.eval(), TrainConfig.FINAL_STEP):

            mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step)

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)

            src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
            src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)

            src1_batch, _ = model.spec_to_batch(src1_mag)
            src2_batch, _ = model.spec_to_batch(src2_mag)
            mixed_batch, _ = model.spec_to_batch(mixed_mag)

            # Initializae our callback.
            #early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5)


            l, _, summary = sess.run([loss_fn, optimizer, summary_op],
                                     feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch,
                                                model.y_src2: src2_batch})

            loss.update(l)
            print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value))

            writer.add_summary(summary, global_step=step)

            # Save state
            if step % TrainConfig.CKPT_STEP == 0:
                tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step)

    writer.close()

J'ai ce code de réseau de neurones qui sépare la musique d'une voix dans un fichier .wav . Comment puis-je introduire un algorithme d'arrêt précoce pour arrêter la section de train? Je vois un projet qui parle d'un ValidationMonitor. Est-ce que quelqu'un peut m'aider?

10
Valerio

ValidationMonitor est marqué comme obsolète. ce n'est pas recommandé. mais vous pouvez toujours l’utiliser ... Voici un exemple de création:

    validation_monitor = monitors.ValidationMonitor(
        input_fn=functools.partial(input_fn, subset="evaluation"),
        eval_steps=128,
        every_n_steps=88,
        early_stopping_metric="accuracy",
        early_stopping_rounds = 1000
    )

et vous pouvez mettre en œuvre par vous-même, voici ma ma mise en œuvre:

          if (loss_value < self.best_loss):
            self.stopping_step = 0
            self.best_loss = loss_value
          else:
            self.stopping_step += 1
          if self.stopping_step >= FLAGS.early_stopping_step:
            self.should_stop = True
            print("Early stopping is trigger at step: {} loss:{}".format(global_step,loss_value))
            run_context.request_stop()
7
scott huang

Depuis la version TensorFlow r1.10, des crochets d'arrêt précoce sont disponibles pour l'API de l'estimateur dans early_stopping.py (voir github ).

Par exemple tf.contrib.estimator.stop_if_no_decrease_hook (voir docs )

0
mdh