Je souhaite afficher une seule image. Il a été chargé à l'aide d'un ImageLoader
et est stocké dans un PyTorch Tensor
.
Lorsque j'essaie de l'afficher via plt.imshow(image)
, j'obtiens:
TypeError: Invalid dimensions for image data
Le .shape
du tenseur est:
torch.Size([3, 244, 244])
Comment afficher l'image contenue dans ce tenseur PyTorch?
Étant donné un Tensor
représentant l'image, utilisez .permute()
:
plt.imshow( tensor_image.permute(1, 2, 0) )
Remarque: permute
ne copie ni n'alloue pas de mémoire , et from_numpy()
non plus.
Comme vous pouvez le voir, matplotlib
fonctionne bien même sans conversion en tableau numpy
. Mais les tenseurs PyTorch ("tenseurs d'image") sont les canaux en premier, donc pour les utiliser avec matplotlib
vous devez le remodeler:
Code:
from scipy.misc import face
import matplotlib.pyplot as plt
import torch
np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# reshape to channel first:
tensor_image = tensor_image.view(tensor_image.shape[2], tensor_image.shape[0], tensor_image.shape[1])
print(type(tensor_image), tensor_image.shape)
# If you try to plot image with shape (C, H, W)
# You will get TypeError:
# plt.imshow(tensor_image)
# So we need to reshape it to (H, W, C):
tensor_image = tensor_image.view(tensor_image.shape[1], tensor_image.shape[2], tensor_image.shape[0])
print(type(tensor_image), tensor_image.shape)
plt.imshow(tensor_image)
plt.show()
Production:
<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
Un exemple complet donné un chemin d'accès d'image img_path
:
from PIL import Image
image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")
Notez que transforms.*
retourne une fonction, c'est pourquoi le bracketing funky.
Étant donné que l'image est chargée comme décrit et stockée dans la variable image
:
plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
Le matplotlib
tutoriel image dit:
L'interpolation bicubique est souvent utilisée lors de l'explosion de photos - les gens ont tendance à préférer le flou aux pixels.
Ou comme Soumith a suggéré :
%matplotlib inline
def show(img):
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
Ou, pour ouvrir l'image dans une fenêtre contextuelle:
transforms.ToPILImage()(image).show()