J'essaie de transférer l'apprentissage d'un modèle Inception-resnet v2 pré-formé sur imagenet, en utilisant mon propre jeu de données et mes propres classes. Mon code d'origine était une modification d'un tf.slim
échantillon que je ne trouve plus et maintenant j'essaye de réécrire le même code en utilisant le tf.estimator.*
cadre.
Je me heurte cependant au problème de charger uniquement certains des poids du point de contrôle pré-formé, en initialisant les couches restantes avec leurs initialiseurs par défaut.
En recherchant le problème, j'ai trouvé ce problème GitHub et cette question , mentionnant tous deux la nécessité d'utiliser tf.train.init_from_checkpoint
dans mon model_fn
. J'ai essayé, mais étant donné le manque d'exemples dans les deux, je suppose que je me suis trompé.
Ceci est mon exemple minimal:
import sys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
import numpy as np
import inception_resnet_v2
NUM_CLASSES = 900
IMAGE_SIZE = 299
def input_fn(mode, num_classes, batch_size=1):
# some code that loads images, reshapes them to 299x299x3 and batches them
return tf.constant(np.zeros([batch_size, 299, 299, 3], np.float32)), tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES)
def model_fn(images, labels, num_classes, mode):
with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
logits, end_points = inception_resnet_v2.inception_resnet_v2(images,
num_classes,
is_training=(mode==tf.estimator.ModeKeys.TRAIN))
predictions = {
'classes': tf.argmax(input=logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
scopes = { os.path.dirname(v.name) for v in variables_to_restore }
tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
{s+'/':s+'/' for s in scopes})
tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
total_loss = tf.losses.get_total_loss() #obtain the regularization losses as well
# Configure the training op
if mode == tf.estimator.ModeKeys.TRAIN:
global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.AdamOptimizer(learning_rate=0.00002)
train_op = optimizer.minimize(total_loss, global_step)
else:
train_op = None
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=total_loss,
train_op=train_op)
def main(unused_argv):
# Create the Estimator
classifier = tf.estimator.Estimator(
model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode),
model_dir='model/MCVE')
# Train the model
classifier.train(
input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, batch_size=1),
steps=1000)
# Evaluate the model and print results
eval_results = classifier.evaluate(
input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, batch_size=1))
print()
print('Evaluation results:\n %s' % eval_results)
if __name__ == '__main__':
tf.app.run(main=main, argv=[sys.argv[0]])
où inception_resnet_v2
est l'implémentation du modèle dans le référentiel de modèles de Tensorflow .
Si j'exécute ce script, j'obtiens un tas de journaux d'informations de init_from_checkpoint
, mais au moment de la création de la session, il semble qu'il tente de charger les poids Logits
à partir du point de contrôle et échoue en raison de formes incompatibles. Voici la trace complète:
Traceback (most recent call last):
File "<ipython-input-6-06fadd69ae8f>", line 1, in <module>
runfile('C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py', wdir='C:/Users/1/Desktop/transfer_learning_tutorial-master')
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
execfile(filename, namespace)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 77, in <module>
tf.app.run(main=main, argv=[sys.argv[0]])
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 68, in main
steps=1000)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 780, in _train_model
log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 368, in MonitoredTrainingSession
stop_grace_period_secs=stop_grace_period_secs)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 673, in __init__
stop_grace_period_secs=stop_grace_period_secs)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 493, in __init__
self._sess = _RecoverableSession(self._coordinated_creator)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 851, in __init__
_WrappedSession.__init__(self, self._create_session())
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 856, in _create_session
return self._sess_creator.create_session()
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 554, in create_session
self.tf_sess = self._session_creator.create_session()
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 428, in create_session
init_fn=self._scaffold.init_fn)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\session_manager.py", line 279, in prepare_session
sess.run(init_op, feed_dict=init_feed_dict)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
run_metadata_ptr)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
options, run_metadata)
File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
raise type(e)(node_def, op, message)
InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [900] rhs shape= [1001] [[Node: Assign_1145 = Assign[T=DT_FLOAT,
_class=["loc:@InceptionResnetV2/Logits/Logits/biases"], use_locking=true, validate_shape=true,
_device="/job:localhost/replica:0/task:0/device:CPU:0"](InceptionResnetV2/Logits/Logits/biases, checkpoint_initializer_1145)]]
Que fais-je de mal en utilisant init_from_checkpoint
? Comment sommes-nous censés "l'utiliser" dans notre model_fn
? Et pourquoi l'estimateur essaie-t-il de charger les poids Logits
'à partir du point de contrôle alors que je lui dis explicitement de ne pas le faire?
Après la suggestion dans les commentaires, j'ai essayé d'autres moyens d'appeler tf.train.init_from_checkpoint
.
{v.name: v.name}
Si, comme suggéré dans le commentaire, je remplace l'appel par {v.name:v.name for v in variables_to_restore}
, J'obtiens cette erreur:
ValueError: Assignment map with scope only name InceptionResnetV2/Conv2d_2a_3x3 should map
to scope only InceptionResnetV2/Conv2d_2a_3x3/weights:0. Should be 'scope/': 'other_scope/'.
{v.name: v}
Si, à la place, j'essaie d'utiliser le name:variable
mapping, j'obtiens l'erreur suivante:
ValueError: Tensor InceptionResnetV2/Conv2d_2a_3x3/weights:0 is not found in
inception_resnet_v2_2016_08_30.ckpt checkpoint
{'InceptionResnetV2/Repeat_2/block8_4/Branch_1/Conv2d_0c_3x1/BatchNorm/moving_mean': [256],
'InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta': [32], ...
L'erreur continue de répertorier ce que je pense être tous les noms de variables dans le point de contrôle (ou pourrait-il être les étendues à la place?).
Après avoir inspecté la dernière erreur ci-dessus, je constate que InceptionResnetV2/Conv2d_2a_3x3/weights
est dans la liste des variables de point de contrôle. Le problème est que :0
à la fin! Je vais maintenant vérifier si cela résout effectivement le problème et poster une réponse si c'est le cas.
Grâce au commentaire de @ KathyWu, je me suis mis sur la bonne voie et j'ai trouvé le problème.
En effet, la façon dont je calculais le scopes
inclurait le InceptionResnetV2/
scope, qui déclencherait le chargement de toutes les variables "sous" le scope (c'est-à-dire toutes les variables du réseau). Cependant, le remplacer par le bon dictionnaire n'était pas anodin.
Des modes d'étendue possibles init_from_checkpoint
accepte , celui que j'ai dû utiliser était le 'scope_variable_name': variable
un, mais sans utiliser le _ variable.name
attribut .
Le variable.name
ressemble à: 'some_scope/variable_name:0'
. Que :0
n'est pas dans le nom de la variable point de contrôle et donc en utilisant scopes = {v.name:v.name for v in variables_to_restore}
générera une erreur "Variable non trouvée".
L'astuce pour le faire fonctionner était de retirer l'index du tenseur du nom :
tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
{v.name.split(':')[0]: v for v in variables_to_restore})
Je découvre {s+'/':s+'/' for s in scopes}
n'a pas fonctionné, simplement parce que le variables_to_restore
inclure quelque chose comme "global_step"
, les étendues incluent donc les étendues globales qui peuvent tout inclure. Vous devez imprimer variables_to_restore
, trouver "global_step"
chose, et le mettre dans "exclude"
.