Il s'agit des deux méthodes de création d'un modèle de kéros, mais le output shapes
des résultats sommaires des deux méthodes sont différents. De toute évidence, le premier imprime plus d'informations et facilite la vérification de l'exactitude du réseau.
import tensorflow as tf
from tensorflow.keras import Input, layers, Model
class subclass(Model):
def __init__(self):
super(subclass, self).__init__()
self.conv = layers.Conv2D(28, 3, strides=1)
def call(self, x):
return self.conv(x)
def func_api():
x = Input(shape=(24, 24, 3))
y = layers.Conv2D(28, 3, strides=1)(x)
return Model(inputs=[x], outputs=[y])
if __name__ == '__main__':
func = func_api()
func.summary()
sub = subclass()
sub.build(input_shape=(None, 24, 24, 3))
sub.summary()
production:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 24, 24, 3) 0
_________________________________________________________________
conv2d (Conv2D) (None, 22, 22, 28) 784
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) multiple 784
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________
Alors, comment dois-je utiliser la méthode de sous-classe pour obtenir le output shape
au résumé ()?
J'ai utilisé cette méthode pour résoudre ce problème, je ne sais pas s'il existe un moyen plus simple.
class subclass(Model):
def __init__(self):
...
def call(self, x):
...
def model(self):
x = Input(shape=(24, 24, 3))
return Model(inputs=[x], outputs=self.call(x))
if __name__ == '__main__':
sub = subclass()
sub.model().summary()