web-dev-qa-db-fra.com

Comment afficher une seule image dans PyTorch?

Je souhaite afficher une seule image. Il a été chargé à l'aide d'un ImageLoader et est stocké dans un PyTorch Tensor.

Lorsque j'essaie de l'afficher via plt.imshow(image), j'obtiens:

TypeError: Invalid dimensions for image data

Le .shape du tenseur est:

torch.Size([3, 244, 244])

Comment afficher l'image contenue dans ce tenseur PyTorch?

11
Tom Hale

Étant donné un Tensor représentant l'image, utilisez .permute() :

plt.imshow(  tensor_image.permute(1, 2, 0)  )

Remarque: permute ne copie ni n'alloue pas de mémoire , et from_numpy() non plus.

23
Tom Hale

Comme vous pouvez le voir, matplotlib fonctionne bien même sans conversion en tableau numpy. Mais les tenseurs PyTorch ("tenseurs d'image") sont les canaux en premier, donc pour les utiliser avec matplotlib vous devez le remodeler:

Code:

from scipy.misc import face
import matplotlib.pyplot as plt
import torch

np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# reshape to channel first:
tensor_image = tensor_image.view(tensor_image.shape[2], tensor_image.shape[0], tensor_image.shape[1])
print(type(tensor_image), tensor_image.shape)

# If you try to plot image with shape (C, H, W)
# You will get TypeError:
# plt.imshow(tensor_image)

# So we need to reshape it to (H, W, C):
tensor_image = tensor_image.view(tensor_image.shape[1], tensor_image.shape[2], tensor_image.shape[0])
print(type(tensor_image), tensor_image.shape)

plt.imshow(tensor_image)
plt.show()

Production:

<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
8
trsvchn

Un exemple complet donné un chemin d'accès d'image img_path:

from PIL import Image
image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")

Notez que transforms.* retourne une fonction, c'est pourquoi le bracketing funky.

1
Tom Hale

Étant donné que l'image est chargée comme décrit et stockée dans la variable image:

plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")

Le matplotlib tutoriel image dit:

L'interpolation bicubique est souvent utilisée lors de l'explosion de photos - les gens ont tendance à préférer le flou aux pixels.


Ou comme Soumith a suggéré :

%matplotlib inline
def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

Ou, pour ouvrir l'image dans une fenêtre contextuelle:

 transforms.ToPILImage()(image).show()
0
Tom Hale