web-dev-qa-db-fra.com

Transférer l'apprentissage avec tf.estimator.Estimator framework

J'essaie de transférer l'apprentissage d'un modèle Inception-resnet v2 pré-formé sur imagenet, en utilisant mon propre jeu de données et mes propres classes. Mon code d'origine était une modification d'un tf.slim échantillon que je ne trouve plus et maintenant j'essaye de réécrire le même code en utilisant le tf.estimator.* cadre.

Je me heurte cependant au problème de charger uniquement certains des poids du point de contrôle pré-formé, en initialisant les couches restantes avec leurs initialiseurs par défaut.

En recherchant le problème, j'ai trouvé ce problème GitHub et cette question , mentionnant tous deux la nécessité d'utiliser tf.train.init_from_checkpoint dans mon model_fn. J'ai essayé, mais étant donné le manque d'exemples dans les deux, je suppose que je me suis trompé.

Ceci est mon exemple minimal:

import sys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
import numpy as np

import inception_resnet_v2

NUM_CLASSES = 900
IMAGE_SIZE = 299

def input_fn(mode, num_classes, batch_size=1):
  # some code that loads images, reshapes them to 299x299x3 and batches them
  return tf.constant(np.zeros([batch_size, 299, 299, 3], np.float32)), tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES)


def model_fn(images, labels, num_classes, mode):
  with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
    logits, end_points = inception_resnet_v2.inception_resnet_v2(images,
                                             num_classes, 
                                             is_training=(mode==tf.estimator.ModeKeys.TRAIN))
  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
  variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
  scopes = { os.path.dirname(v.name) for v in variables_to_restore }
  tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                                {s+'/':s+'/' for s in scopes})

  tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  total_loss = tf.losses.get_total_loss()    #obtain the regularization losses as well

  # Configure the training op
  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(learning_rate=0.00002)
    train_op = optimizer.minimize(total_loss, global_step)
  else:
    train_op = None

  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op)

def main(unused_argv):
  # Create the Estimator
  classifier = tf.estimator.Estimator(
      model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode),
      model_dir='model/MCVE')

  # Train the model  
  classifier.train(
      input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, batch_size=1),
      steps=1000)

  # Evaluate the model and print results
  eval_results = classifier.evaluate(
      input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, batch_size=1))
  print()
  print('Evaluation results:\n    %s' % eval_results)

if __name__ == '__main__':
  tf.app.run(main=main, argv=[sys.argv[0]])

inception_resnet_v2 est l'implémentation du modèle dans le référentiel de modèles de Tensorflow .

Si j'exécute ce script, j'obtiens un tas de journaux d'informations de init_from_checkpoint, mais au moment de la création de la session, il semble qu'il tente de charger les poids Logits à partir du point de contrôle et échoue en raison de formes incompatibles. Voici la trace complète:

Traceback (most recent call last):

  File "<ipython-input-6-06fadd69ae8f>", line 1, in <module>
    runfile('C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py', wdir='C:/Users/1/Desktop/transfer_learning_tutorial-master')

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
    execfile(filename, namespace)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 77, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]])

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))

  File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 68, in main
    steps=1000)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 780, in _train_model
    log_step_count_steps=self._config.log_step_count_steps) as mon_sess:

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 368, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 673, in __init__
    stop_grace_period_secs=stop_grace_period_secs)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 493, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 851, in __init__
    _WrappedSession.__init__(self, self._create_session())

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 856, in _create_session
    return self._sess_creator.create_session()

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 554, in create_session
    self.tf_sess = self._session_creator.create_session()

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 428, in create_session
    init_fn=self._scaffold.init_fn)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\session_manager.py", line 279, in prepare_session
    sess.run(init_op, feed_dict=init_feed_dict)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
    run_metadata_ptr)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
    options, run_metadata)

  File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)

InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [900] rhs shape= [1001]    [[Node: Assign_1145 = Assign[T=DT_FLOAT,
_class=["loc:@InceptionResnetV2/Logits/Logits/biases"], use_locking=true, validate_shape=true,
_device="/job:localhost/replica:0/task:0/device:CPU:0"](InceptionResnetV2/Logits/Logits/biases, checkpoint_initializer_1145)]]

Que fais-je de mal en utilisant init_from_checkpoint? Comment sommes-nous censés "l'utiliser" dans notre model_fn? Et pourquoi l'estimateur essaie-t-il de charger les poids Logits 'à partir du point de contrôle alors que je lui dis explicitement de ne pas le faire?

Mise à jour:

Après la suggestion dans les commentaires, j'ai essayé d'autres moyens d'appeler tf.train.init_from_checkpoint.

En utilisant {v.name: v.name}

Si, comme suggéré dans le commentaire, je remplace l'appel par {v.name:v.name for v in variables_to_restore}, J'obtiens cette erreur:

ValueError: Assignment map with scope only name InceptionResnetV2/Conv2d_2a_3x3 should map
to scope only InceptionResnetV2/Conv2d_2a_3x3/weights:0. Should be 'scope/': 'other_scope/'.

En utilisant {v.name: v}

Si, à la place, j'essaie d'utiliser le name:variable mapping, j'obtiens l'erreur suivante:

ValueError: Tensor InceptionResnetV2/Conv2d_2a_3x3/weights:0 is not found in
inception_resnet_v2_2016_08_30.ckpt checkpoint
{'InceptionResnetV2/Repeat_2/block8_4/Branch_1/Conv2d_0c_3x1/BatchNorm/moving_mean': [256], 
'InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta': [32], ...

L'erreur continue de répertorier ce que je pense être tous les noms de variables dans le point de contrôle (ou pourrait-il être les étendues à la place?).

Mise à jour (2)

Après avoir inspecté la dernière erreur ci-dessus, je constate que InceptionResnetV2/Conv2d_2a_3x3/weights est dans la liste des variables de point de contrôle. Le problème est que :0 à la fin! Je vais maintenant vérifier si cela résout effectivement le problème et poster une réponse si c'est le cas.

17
GPhilo

Grâce au commentaire de @ KathyWu, je me suis mis sur la bonne voie et j'ai trouvé le problème.

En effet, la façon dont je calculais le scopes inclurait le InceptionResnetV2/ scope, qui déclencherait le chargement de toutes les variables "sous" le scope (c'est-à-dire toutes les variables du réseau). Cependant, le remplacer par le bon dictionnaire n'était pas anodin.

Des modes d'étendue possibles init_from_checkpoint accepte , celui que j'ai dû utiliser était le 'scope_variable_name': variable un, mais sans utiliser le _ variable.name attribut .

Le variable.name ressemble à: 'some_scope/variable_name:0'. Que :0 n'est pas dans le nom de la variable point de contrôle et donc en utilisant scopes = {v.name:v.name for v in variables_to_restore} générera une erreur "Variable non trouvée".

L'astuce pour le faire fonctionner était de retirer l'index du tenseur du nom :

tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt', 
                              {v.name.split(':')[0]: v for v in variables_to_restore})
9
GPhilo

Je découvre {s+'/':s+'/' for s in scopes} n'a pas fonctionné, simplement parce que le variables_to_restore inclure quelque chose comme "global_step", les étendues incluent donc les étendues globales qui peuvent tout inclure. Vous devez imprimer variables_to_restore, trouver "global_step" chose, et le mettre dans "exclude".

1
hai lee