J'essaie d'importer des poids enregistrés à partir d'un modèle Tensorflow vers PyTorch. Jusqu'à présent, les résultats ont été très similaires. Je suis tombé sur un hic quand le modèle appelle conv2d
avec stride=2
.
Pour vérifier le décalage, j'ai mis en place une comparaison très simple entre TF et PyTorch. Tout d'abord, je compare conv2d
avec stride=1
.
import tensorflow as tf
import numpy as np
import torch
import torch.nn.functional as F
np.random.seed(0)
sess = tf.Session()
# Create random weights and input
weights = torch.empty(3, 3, 3, 8)
torch.nn.init.constant_(weights, 5e-2)
x = np.random.randn(1, 3, 10, 10)
weights_tf = tf.convert_to_tensor(weights.numpy(), dtype=tf.float32)
# PyTorch adopts [outputC, inputC, kH, kW]
weights_torch = torch.Tensor(weights.permute((3, 2, 0, 1)))
# Tensorflow defaults to NHWC
x_tf = tf.convert_to_tensor(x.transpose((0, 2, 3, 1)), dtype=tf.float32)
x_torch = torch.Tensor(x)
# TF Conv2D
tf_conv2d = tf.nn.conv2d(x_tf,
weights_tf,
strides=[1, 1, 1, 1],
padding="SAME")
# PyTorch Conv2D
torch_conv2d = F.conv2d(x_torch, weights_torch, padding=1, stride=1)
sess.run(tf.global_variables_initializer())
tf_result = sess.run(tf_conv2d)
diff = np.mean(np.abs(tf_result.transpose((0, 3, 1, 2)) - torch_conv2d.detach().numpy()))
print('Mean of Abs Diff: {0}'.format(diff))
Le résultat de cette exécution est:
Mean of Abs Diff: 2.0443112092038973e-08
Lorsque je change stride
en 2, les résultats commencent à varier.
# TF Conv2D
tf_conv2d = tf.nn.conv2d(x_tf,
weights_tf,
strides=[1, 2, 2, 1],
padding="SAME")
# PyTorch Conv2D
torch_conv2d = F.conv2d(x_torch, weights_torch, padding=1, stride=2)
Le résultat de cette exécution est:
Mean of Abs Diff: 0.2104552686214447
Selon la documentation de PyTorch, conv2d
utilise un remplissage nul défini par l'argument padding
. Ainsi, des zéros sont ajoutés à gauche, en haut, à droite et en bas de l'entrée dans mon exemple.
Si PyTorch ajoute simplement un rembourrage des deux côtés en fonction du paramètre d'entrée, il devrait être facile de répliquer dans Tensorflow.
# Manually add padding - consistent with PyTorch
paddings = tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]])
x_tf = tf.convert_to_tensor(x.transpose((0, 2, 3, 1)), dtype=tf.float32)
x_tf = tf.pad(x_tf, paddings, "CONSTANT")
# TF Conv2D
tf_conv2d = tf.nn.conv2d(x_tf,
weights_tf,
strides=[1, 2, 2, 1],
padding="VALID")
Le résultat de cette comparaison est:
Mean of Abs Diff: 1.6035047067930464e-08
Ce que cela me dit, c'est que si je suis en mesure de reproduire le comportement de remplissage par défaut de Tensorflow dans PyTorch, mes résultats seront similaires.
Cette question a inspecté le comportement du remplissage dans Tensorflow. La documentation TF explique comment le remplissage est ajouté pour les "MÊMES" circonvolutions. J'ai découvert ces liens en écrivant cette question.
Maintenant que je connais la stratégie de remplissage de Tensorflow, je peux l'implémenter dans PyTorch.
Pour reproduire le comportement, les tailles de remplissage sont calculées comme décrit dans la documentation Tensorflow. Ici, je teste le comportement de remplissage en définissant stride=2
et le remplissage de l'entrée PyTorch.
import tensorflow as tf
import numpy as np
import torch
import torch.nn.functional as F
np.random.seed(0)
sess = tf.Session()
# Create random weights and input
weights = torch.empty(3, 3, 3, 8)
torch.nn.init.constant_(weights, 5e-2)
x = np.random.randn(1, 3, 10, 10)
weights_tf = tf.convert_to_tensor(weights.numpy(), dtype=tf.float32)
weights_torch = torch.Tensor(weights.permute((3, 2, 0, 1)))
# Tensorflow padding behavior. Assuming that kH == kW to keep this simple.
stride = 2
if x.shape[2] % stride == 0:
pad = max(weights.shape[0] - stride, 0)
else:
pad = max(weights.shape[0] - (x.shape[2] % stride), 0)
if pad % 2 == 0:
pad_val = pad // 2
padding = (pad_val, pad_val, pad_val, pad_val)
else:
pad_val_start = pad // 2
pad_val_end = pad - pad_val_start
padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end)
x_tf = tf.convert_to_tensor(x.transpose((0, 2, 3, 1)), dtype=tf.float32)
x_torch = torch.Tensor(x)
x_torch = F.pad(x_torch, padding, "constant", 0)
# TF Conv2D
tf_conv2d = tf.nn.conv2d(x_tf,
weights_tf,
strides=[1, stride, stride, 1],
padding="SAME")
# PyTorch Conv2D
torch_conv2d = F.conv2d(x_torch, weights_torch, padding=0, stride=stride)
sess.run(tf.global_variables_initializer())
tf_result = sess.run(tf_conv2d)
diff = np.mean(np.abs(tf_result.transpose((0, 3, 1, 2)) - torch_conv2d.detach().numpy()))
print('Mean of Abs Diff: {0}'.format(diff))
La sortie est:
Mean of Abs Diff: 2.2477470551507395e-08
Je ne savais pas trop pourquoi cela se produisait quand j'ai commencé à écrire cette question, mais un peu de lecture l'a clarifié très rapidement. J'espère que cet exemple pourra aider les autres.