web-dev-qa-db-fra.com

TypeError: le tenseur n'est pas une image de torche

En travaillant sur le cours d'IA à Udacity, je suis tombé sur cette erreur lors de la section Transfert d'apprentissage. Voici le code qui semble être à l'origine du problème:

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

data_dir = 'filename'

# TODO: Define transforms for the training data and testing data
train_transforms= transforms.Compose([transforms.Resize((224,224)), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.ToTensor()])
test_transforms= transforms.Compose([transforms.Resize((224,224)), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.ToTensor()])

# Pass transforms in here, then run the next cell to see how the transforms look
train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
8
Tristan Newman

Le problème est avec l'ordre des transformations. La transformation ToTensor doit précéder la transformation Normalize, car cette dernière attend un tenseur, mais la transformation Resize renvoie une image. Code correct avec les lignes défectueuses modifiées:

train_transforms = transforms.Compose([
    transforms.Resize((224,224)), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
test_transforms = transforms.Compose([
    transforms.Resize((224,224)), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
18
Agost Biro

Une autre solution, moins élégante (en supposant que l'image a été chargée avec opencv et est donc BGR):

t_ = transforms.Compose([transforms.ToPILImage(),
                         transforms.Resize((224,224)),
                         transforms.ToTensor()])

norm_ = transforms.Normalize([103.939, 116.779, 123.68],[1,1,1])
img = 255*t_(img)
img = norm_(img)
1
Alex