web-dev-qa-db-fra.com

Comment ajouter un mécanisme d'attention dans keras?

J'utilise actuellement ce code que je reçois de une discussion sur github Voici le code du mécanisme d'attention:

_input = Input(shape=[max_length], dtype='int32')

# get the embedding layer
embedded = Embedding(
        input_dim=vocab_size,
        output_dim=embedding_size,
        input_length=max_length,
        trainable=False,
        mask_zero=False
    )(_input)

activations = LSTM(units, return_sequences=True)(embedded)

# compute importance for each step
attention = Dense(1, activation='tanh')(activations)
attention = Flatten()(attention)
attention = Activation('softmax')(attention)
attention = RepeatVector(units)(attention)
attention = Permute([2, 1])(attention)


sent_representation = merge([activations, attention], mode='mul')
sent_representation = Lambda(lambda xin: K.sum(xin, axis=-2), output_shape=(units,))(sent_representation)

probabilities = Dense(3, activation='softmax')(sent_representation)

Est-ce la bonne façon de le faire? Je m'attendais en quelque sorte à l'existence d'une couche temporellement répartie puisque le mécanisme d'attention est distribué à chaque pas temporel du RNN. J'ai besoin de quelqu'un pour confirmer que cette implémentation (le code) est une implémentation correcte du mécanisme d'attention. Je vous remercie.

10
Aryo Pradipta Gema

Si vous voulez attirer l'attention sur la dimension temporelle, cette partie de votre code me semble correcte:

activations = LSTM(units, return_sequences=True)(embedded)

# compute importance for each step
attention = Dense(1, activation='tanh')(activations)
attention = Flatten()(attention)
attention = Activation('softmax')(attention)
attention = RepeatVector(units)(attention)
attention = Permute([2, 1])(attention)

sent_representation = merge([activations, attention], mode='mul')

Vous avez défini le vecteur d’attention de la forme (batch_size, max_length):

attention = Activation('softmax')(attention)

Je n'ai jamais vu ce code auparavant, je ne peux donc pas dire si celui-ci est correct ou non:

K.sum(xin, axis=-2)

Lectures supplémentaires (vous pouvez jeter un coup d'œil):

10
Philippe Remy