I'm trying to implement a CNN to recognize the digits in the MNIST dataset, and my code raises the error below while loading the data. I don't understand why this happens.
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=20, shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=20, shuffle=False, num_workers=2)
for i, data in enumerate(trainloader, 0):
    inputs, labels = data[0], data[1]
Error:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-6-b37c638b6114> in <module>
2
----> 3 for i, data in enumerate(trainloader, 0):
4 inputs, labels = data[0], data[1]
# ...
IndexError: Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/opt/conda/lib/python3.6/site-packages/torchvision/datasets/mnist.py", line 95, in __getitem__
    img = self.transform(img)
  File "/opt/conda/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "/opt/conda/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 164, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/opt/conda/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 208, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
IndexError: too many indices for tensor of dimension 0
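A likely cause, judging from the last frame of the traceback: in Python, `(0.5)` is just the float `0.5` (the parentheses only group), whereas `(0.5,)` is a one-element tuple. `Normalize` expects one mean and one std per channel as a sequence; in the torchvision version shown above, the bare float turns into a 0-dimensional tensor, and `mean[:, None, None]` then fails with "too many indices for tensor of dimension 0". Writing `transforms.Normalize((0.5,), (0.5,))` should avoid the error. The sketch below reproduces the mechanism with NumPy instead of torch (an assumption so it runs without torchvision); `normalize_like_torchvision` is a hypothetical helper that only mimics the `sub_`/`div_` line from the traceback.

```python
import numpy as np

def normalize_like_torchvision(img, mean, std):
    # Hypothetical helper mimicking the traceback's line:
    #   tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
    mean = np.asarray(mean, dtype=img.dtype)  # (0.5) -> 0-d array, (0.5,) -> shape (1,)
    std = np.asarray(std, dtype=img.dtype)
    # Indexing with [:, None, None] needs at least one real dimension
    return (img - mean[:, None, None]) / std[:, None, None]

# One channel of 28x28 pixels, like an MNIST image after ToTensor()
img = np.full((1, 28, 28), 0.75, dtype=np.float32)

print(type((0.5)).__name__, type((0.5,)).__name__)  # float tuple

try:
    normalize_like_torchvision(img, (0.5), (0.5))  # bare floats: fails like the traceback
except IndexError as e:
    print("IndexError:", e)

out = normalize_like_torchvision(img, (0.5,), (0.5,))  # one-element tuples: works
print(out.shape, out[0, 0, 0])  # (1, 28, 28) 0.5
```

With the tuple form, each pixel becomes (x - 0.5) / 0.5, which maps the [0, 1] output of `ToTensor()` to [-1, 1].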
Check that the trainset is not empty with a simple print, and do the same for trainloader. If it still doesn't work, I would rather load MNIST manually with:
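import gzip
import numpy as np

def load_mnist_labels(fnlabel):
    # IDX label file: 8-byte header (magic number, item count), then one uint8 per label
    with gzip.open(fnlabel, 'rb') as f:
        f.read(8)
        return np.frombuffer(f.read(), dtype=np.uint8)

def load_mnist_images(fnimage):
    # IDX image file: 16-byte header (magic, count, rows, cols), then 28x28 uint8 pixels per image
    with gzip.open(fnimage, 'rb') as f:
        f.read(16)
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(-1, 28, 28)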
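A slightly more defensive variant of the manual loader reads the IDX header instead of skipping it, so passing the wrong file (for example the label file to the image loader) fails loudly rather than returning garbage. This is a sketch assuming the standard gzipped IDX layout used by the MNIST files (big-endian 32-bit header fields, magic number 2049 for labels and 2051 for images); `read_idx_labels` and `read_idx_images` are hypothetical names, and the self-check writes tiny synthetic files, not real MNIST data.

```python
import gzip
import os
import struct
import tempfile

import numpy as np

# IDX magic numbers for the MNIST files (big-endian uint32 at offset 0)
LABEL_MAGIC = 2049  # 0x00000801: unsigned bytes, 1 dimension
IMAGE_MAGIC = 2051  # 0x00000803: unsigned bytes, 3 dimensions

def read_idx_labels(path):
    # Hypothetical name: validates the 8-byte header (magic, count)
    with gzip.open(path, 'rb') as f:
        magic, count = struct.unpack('>II', f.read(8))
        if magic != LABEL_MAGIC:
            raise ValueError(f'{path}: not an IDX label file (magic {magic})')
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if labels.size != count:
        raise ValueError(f'{path}: expected {count} labels, got {labels.size}')
    return labels

def read_idx_images(path):
    # Hypothetical name: validates the 16-byte header (magic, count, rows, cols)
    with gzip.open(path, 'rb') as f:
        magic, count, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != IMAGE_MAGIC:
            raise ValueError(f'{path}: not an IDX image file (magic {magic})')
        data = np.frombuffer(f.read(), dtype=np.uint8)
    # reshape raises ValueError if the pixel count does not match the header
    return data.reshape(count, rows, cols)

# Self-check on tiny synthetic files (fake data, not real MNIST)
with tempfile.TemporaryDirectory() as tmp:
    lbl_path = os.path.join(tmp, 'labels.gz')
    img_path = os.path.join(tmp, 'images.gz')
    with gzip.open(lbl_path, 'wb') as f:
        f.write(struct.pack('>II', LABEL_MAGIC, 3) + bytes([7, 2, 1]))
    with gzip.open(img_path, 'wb') as f:
        f.write(struct.pack('>IIII', IMAGE_MAGIC, 2, 28, 28) + bytes(2 * 28 * 28))
    labels = read_idx_labels(lbl_path)
    images = read_idx_images(img_path)
    print(labels, images.shape)  # [7 2 1] (2, 28, 28)
```

The arrays come back as raw `uint8`, so before training they would still need converting to float and scaling to [0, 1], which is what `ToTensor()` does in the torchvision pipeline.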