Как я могу использовать torch Dataloader для получения изображения с тем же классом?

#python #pytorch #torch #torchvision #pytorch-dataloader

#python #pytorch #факел #torchvision #pytorch-загрузчик данных

Вопрос:

в моем наборе данных есть 6 классов и 23 изображения на класс
Раньше я torchvision.dataset делал ImageFolder , и это работало хорошо.

 dataset = vision_dataset.ImageFolder(root = DATA_ROOT,
                                     transform = vision_trans.Compose([
                                                    vision_trans.Resize(256),
                                                    vision_trans.CenterCrop(256),
                                                    vision_trans.ToTensor()
                                     ]))

dataloader = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,
                                         shuffle = False, num_workers = 2, )
 

но я хочу получать пакетные изображения с тем же классом.

 ...
tensor([2, 2, 2, 2, 2])
tensor([2, 2])
tensor([3, 3, 3, 3, 3])
...
 

Это то, что я хотел, форма label (класс пакетных данных)
но на самом деле загрузчик данных будет работать так

 ...
tensor([2, 2, 2, 2, 2])
tensor([2, 2, 3, 3, 3])
tensor([3, 3, 3, 3, 3])
...
 

как я могу получить пакетные данные для каждой метки?

Ответ №1:

Вы не можете удобно сделать это с ImageFolder помощью . Вы должны создать один набор данных для каждого класса и загружать свои пакеты из нужного вам набора данных.

Более конкретно, предполагая, что ваша структура папок соответствует структуре, требуемой ImageFolder, вам нужно создать один небольшой класс dataset :

 class ImageSubFolder(torch.utils.data.Dataset):
    def __init__(self, root_dir, label):
        # Path toward the label-sorted subfolders of your dataset
        # Assuming images are named smthg like /path/to/label/xxxx.npy
        self._path = root_dir   label  "{:04d}"

    def __len__(self):
        return count_files_in_directory(self._path)

    def __getitem__(self, index):
        return (np.load(self._path.format(index), label)
 

Это просто для того, чтобы показать логику класса, я полагаю, у вас все равно будет несколько функций для реализации (вы можете следовать этому руководству). «Остальные функции для реализации оставлены в качестве упражнения для читателя». В любом случае с этим классом вам просто нужно создать 6 его экземпляров (по одному на класс) :

 loaders = {}
for label in ("dog", "cat", "plane", "tree", "mug", "car"):
    dataset = SubFolderDataset(DATA_ROOT, label)
    loaders[label] = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,shuffle = False, num_workers = 2, )
 

Теперь у вас есть dict, который содержит загрузчики данных, которые загружают только образцы данного класса.