web-dev-qa-db-fra.com

TensorFlow: Mémorisez l'état LSTM pour le lot suivant (LSTM avec état)

Étant donné un modèle LSTM formé, je veux effectuer l'inférence pour des pas de temps uniques, c'est-à-dire seq_length = 1 dans l'exemple ci-dessous. Après chaque pas de temps, les états LSTM internes (mémoire et cachés) doivent être mémorisés pour le prochain "lot". Pour le tout début de l'inférence, le LSTM interne indique init_c, init_h sont calculés en fonction de l'entrée. Ceux-ci sont ensuite stockés dans un objet LSTMStateTuple qui est transmis au LSTM. Pendant l'entraînement, cet état est mis à jour à chaque pas de temps. Cependant, pour l'inférence, je veux que le state soit enregistré entre les lots, c'est-à-dire que les états initiaux doivent seulement être calculés au tout début et ensuite que les états LSTM doivent être enregistrés après chaque "lot" (n = 1 ).

J'ai trouvé cette question liée à StackOverflow: Tensorflow, meilleure façon de sauvegarder l'état dans les RNN? . Cependant, cela ne fonctionne que si state_is_Tuple=False, mais ce comportement sera bientôt déconseillé par TensorFlow (voir rnn_cell.py ). Keras semble avoir un joli wrapper pour rendre LSTM avec état possible mais je ne connais pas la meilleure façon d'y parvenir dans TensorFlow. Ce problème sur le TensorFlow GitHub est également lié à ma question: https://github.com/tensorflow/tensorflow/issues/2838

Quelqu'un a-t-il de bonnes suggestions pour construire un modèle LSTM dynamique?

inputs  = tf.placeholder(tf.float32, shape=[None, seq_length, 84, 84], name="inputs")
targets = tf.placeholder(tf.float32, shape=[None, seq_length], name="targets")

num_lstm_layers = 2

with tf.variable_scope("LSTM") as scope:

    lstm_cell  = tf.nn.rnn_cell.LSTMCell(512, initializer=initializer, state_is_Tuple=True)
    self.lstm  = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_lstm_layers, state_is_Tuple=True)

    init_c = # compute initial LSTM memory state using contents in placeholder 'inputs'
    init_h = # compute initial LSTM hidden state using contents in placeholder 'inputs'
    self.state = [tf.nn.rnn_cell.LSTMStateTuple(init_c, init_h)] * num_lstm_layers

    outputs = []

    for step in range(seq_length):

        if step != 0:
            scope.reuse_variables()

        # CNN features, as input for LSTM
        x_t = # ... 

        # LSTM step through time
        output, self.state = self.lstm(x_t, self.state)
        outputs.append(output)
23
verified.human

J'ai découvert qu'il était plus facile d'enregistrer l'état entier pour toutes les couches dans un espace réservé.

init_state = np.zeros((num_layers, 2, batch_size, state_size))

...

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])

Décompressez-le ensuite et créez un tuple de LSTMStateTuples avant d'utiliser le tensorflow natif RNN Api.

l = tf.unpack(state_placeholder, axis=0)
rnn_Tuple_state = Tuple(
[tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
 for idx in range(num_layers)]
)

RNN passe dans l'API:

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_Tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell]*num_layers, state_is_Tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_Tuple_state)

La variable state - sera ensuite transmise au lot suivant en tant qu'espace réservé.

21
user1506145

Tensorflow, le meilleur moyen de sauvegarder l'état des RNN? était en fait ma question initiale. Le code ci-dessous est la façon dont j'utilise les tuples d'état.

with tf.variable_scope('decoder') as scope:
    rnn_cell = tf.nn.rnn_cell.MultiRNNCell \
    ([
        tf.nn.rnn_cell.LSTMCell(512, num_proj = 256, state_is_Tuple = True),
        tf.nn.rnn_cell.LSTMCell(512, num_proj = Word_VEC_SIZE, state_is_Tuple = True)
    ], state_is_Tuple = True)

    state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer] for sz_outer in rnn_cell.state_size]

    for t in range(TIME_STEPS):
        if t:
            last = y_[t - 1] if TRAINING else y[t - 1]
        else:
            last = tf.zeros((BATCH_SIZE, Word_VEC_SIZE))

        y[t] = tf.concat(1, (y[t], last))
        y[t], state = rnn_cell(y[t], state)

        scope.reuse_variables()

Plutôt que d'utiliser tf.nn.rnn_cell.LSTMStateTuple Je viens de créer une liste de listes qui fonctionne bien. Dans cet exemple, je ne sauvegarde pas l'état. Cependant, vous auriez pu facilement créer un état à partir de variables et simplement utiliser assign pour enregistrer les valeurs.

6
chasep255