Le didacticiel de pytorch pour le chargement et le traitement des données est assez spécifique à un exemple. Quelqu'un pourrait-il m'aider à comprendre à quoi la fonction devrait ressembler pour un chargement simple plus générique d'images?
Tutoriel: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
Mes données:
J'ai le jeu de données MINST en jpg dans la structure de dossiers suivante. (Je sais que je peux simplement utiliser la classe dataset, mais ceci est simplement pour voir comment charger des images simples dans pytorch sans csv ni fonctionnalités complexes).
Le nom du dossier est l'étiquette et les images sont en png 28x28 en niveaux de gris, aucune transformation requise.
data
train
0
3.png
5.png
13.png
23.png
...
1
3.png
10.png
11.png
...
2
4.png
13.png
...
3
8.png
...
4
...
5
...
6
...
7
...
8
...
9
...
Voici ce que j'ai fait pour pytorch 4.1
def load_dataset():
data_path = 'data/train/'
train_dataset = torchvision.datasets.ImageFolder(
root=data_path,
transform=torchvision.transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=64,
num_workers=0,
shuffle=True
)
return train_loader
for batch_idx, (data, target) in enumerate(load_dataset()):
#train network
Si vous utilisez mnist, il existe déjà un préréglage dans pytorch via torchvision.
Vous pourriez faire
import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
shuffle=True, num_workers=2)
Si vous souhaitez généraliser à un répertoire d'images (mêmes importations que ci-dessus), vous pouvez faire
class mnistmTrainingDataset(torch.utils.data.Dataset):
def __init__(self,text_file,root_dir,transform=transformMnistm):
"""
Args:
text_file(string): path to text file
root_dir(string): directory with all train images
"""
self.name_frame = pd.read_csv(text_file,sep=" ",usecols=range(1))
self.label_frame = pd.read_csv(text_file,sep=" ",usecols=range(1,2))
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.name_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
image = Image.open(img_name)
image = self.transform(image)
labels = self.label_frame.iloc[idx, 0]
#labels = labels.reshape(-1, 2)
sample = {'image': image, 'labels': labels}
return sample
mnistmTrainSet = mnistmTrainingDataset(text_file ='Downloads/mnist_m/mnist_m_train_labels.txt',
root_dir = 'Downloads/mnist_m/mnist_m_train')
mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet,batch_size=16,shuffle=True, num_workers=2)
Vous pouvez alors le parcourir comme ceci:
for i_batch,sample_batched in enumerate(mnistmTrainLoader,0):
print("training sample for mnist-m")
print(i_batch,sample_batched['image'],sample_batched['labels'])
Il y a beaucoup de façons de généraliser pytorch pour le chargement de jeu de données d'image, la méthode que je connais est le sous-classement torch.utils.data.dataset