web-dev-qa-db-fra.com

Comment charger des images dans Pytorch DataLoader?

Le didacticiel de pytorch pour le chargement et le traitement des données est assez spécifique à un exemple. Quelqu'un pourrait-il m'aider à comprendre à quoi la fonction devrait ressembler pour un chargement simple plus générique d'images?

Tutoriel: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

Mes données:

J'ai le jeu de données MINST en jpg dans la structure de dossiers suivante. (Je sais que je peux simplement utiliser la classe dataset, mais ceci est simplement pour voir comment charger des images simples dans pytorch sans csv ni fonctionnalités complexes).

Le nom du dossier est l'étiquette et les images sont en png 28x28 en niveaux de gris, aucune transformation requise.

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...
12
Terry

Voici ce que j'ai fait pour pytorch 4.1

def load_dataset():
    data_path = 'data/train/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader

for batch_idx, (data, target) in enumerate(load_dataset()):
    #train network
16
Duane

Si vous utilisez mnist, il existe déjà un préréglage dans pytorch via torchvision.
Vous pourriez faire

import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd

transform = transforms.Compose(
[transforms.ToTensor(),
 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                    download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                      shuffle=True, num_workers=2)

Si vous souhaitez généraliser à un répertoire d'images (mêmes importations que ci-dessus), vous pouvez faire

class mnistmTrainingDataset(torch.utils.data.Dataset):

    def __init__(self,text_file,root_dir,transform=transformMnistm):
        """
        Args:
            text_file(string): path to text file
            root_dir(string): directory with all train images
        """
        self.name_frame = pd.read_csv(text_file,sep=" ",usecols=range(1))
        self.label_frame = pd.read_csv(text_file,sep=" ",usecols=range(1,2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.name_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
        image = Image.open(img_name)
        image = self.transform(image)
        labels = self.label_frame.iloc[idx, 0]
        #labels = labels.reshape(-1, 2)
        sample = {'image': image, 'labels': labels}

        return sample


mnistmTrainSet = mnistmTrainingDataset(text_file ='Downloads/mnist_m/mnist_m_train_labels.txt',
                                   root_dir = 'Downloads/mnist_m/mnist_m_train')

mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet,batch_size=16,shuffle=True, num_workers=2)

Vous pouvez alors le parcourir comme ceci:

for i_batch,sample_batched in enumerate(mnistmTrainLoader,0):
    print("training sample for mnist-m")
    print(i_batch,sample_batched['image'],sample_batched['labels'])

Il y a beaucoup de façons de généraliser pytorch pour le chargement de jeu de données d'image, la méthode que je connais est le sous-classement torch.utils.data.dataset

9
Ari K