Sous les nouvelles modifications de l'API, comment effectuez-vous la multiplication par élément des couches dans Keras? Sous l'ancienne API, j'essaierais quelque chose comme ceci:
merge([dense_all, dense_att], output_shape=10, mode='mul')
J'ai essayé ceci (MWE):
from keras.models import Model
from keras.layers import Input, Dense, Multiply
def sample_model():
model_in = Input(shape=(10,))
dense_all = Dense(10,)(model_in)
dense_att = Dense(10, activation='softmax')(model_in)
att_mull = Multiply([dense_all, dense_att]) #merge([dense_all, dense_att], output_shape=10, mode='mul')
model_out = Dense(10, activation="sigmoid")(att_mull)
return 0
if __== '__main__':
sample_model()
Trace complète:
Using TensorFlow backend.
Traceback (most recent call last):
File "testJan17.py", line 13, in <module>
sample_model()
File "testJan17.py", line 8, in sample_model
att_mull = Multiply([dense_all, dense_att]) #merge([dense_all, dense_att], output_shape=10, mode='mul')
TypeError: __init__() takes exactly 1 argument (2 given)
MODIFIER:
J'ai essayé d'implémenter la fonction de multiplication élémentaire de tensorflow. Bien entendu, le résultat n'est pas une instance Layer()
, donc cela ne fonctionne pas. Voici la tentative, pour la postérité:
def new_multiply(inputs): #assume two only - bad practice, but for illustration...
return tf.multiply(inputs[0], inputs[1])
def sample_model():
model_in = Input(shape=(10,))
dense_all = Dense(10,)(model_in)
dense_att = Dense(10, activation='softmax')(model_in) #which interactions are important?
new_mult = new_multiply([dense_all, dense_att])
model_out = Dense(10, activation="sigmoid")(new_mult)
model = Model(inputs=model_in, outputs=model_out)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
return model
Avec keras
> 2.0:
from keras.layers import multiply
output = multiply([dense_all, dense_att])
Vous devez ajouter une autre parenthèse étroite ouverte à l'avant.
from keras.layers import Multiply
att_mull = Multiply()([dense_all, dense_att])
Sous l'API fonctionnelle, vous utilisez simplement la fonction multiply
, notez le minuscule "m". La classe Multiply est une couche que vous voyez, destinée à être utilisée avec l'API séquentielle.
Plus d'informations dans https://keras.io/layers/merge/#multiply_1