#python #pytorch #torch #torchvision #pytorch-dataloader
#python #pytorch #факел #torchvision #pytorch-загрузчик данных
Вопрос:
в моем наборе данных есть 6 классов и 23 изображения на класс
Раньше я torchvision.dataset
делал ImageFolder
, и это работало хорошо.
dataset = vision_dataset.ImageFolder(root = DATA_ROOT,
transform = vision_trans.Compose([
vision_trans.Resize(256),
vision_trans.CenterCrop(256),
vision_trans.ToTensor()
]))
dataloader = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,
shuffle = False, num_workers = 2, )
но я хочу получать пакетные изображения с тем же классом.
...
tensor([2, 2, 2, 2, 2])
tensor([2, 2])
tensor([3, 3, 3, 3, 3])
...
Это то, что я хотел, форма label (класс пакетных данных)
но на самом деле загрузчик данных будет работать так
...
tensor([2, 2, 2, 2, 2])
tensor([2, 2, 3, 3, 3])
tensor([3, 3, 3, 3, 3])
...
как я могу получить пакетные данные для каждой метки?
Ответ №1:
Вы не можете удобно сделать это с ImageFolder
помощью . Вы должны создать один набор данных для каждого класса и загружать свои пакеты из нужного вам набора данных.
Более конкретно, предполагая, что ваша структура папок соответствует структуре, требуемой ImageFolder, вам нужно создать один небольшой класс dataset :
class ImageSubFolder(torch.utils.data.Dataset):
def __init__(self, root_dir, label):
# Path toward the label-sorted subfolders of your dataset
# Assuming images are named smthg like /path/to/label/xxxx.npy
self._path = root_dir label "{:04d}"
def __len__(self):
return count_files_in_directory(self._path)
def __getitem__(self, index):
return (np.load(self._path.format(index), label)
Это просто для того, чтобы показать логику класса, я полагаю, у вас все равно будет несколько функций для реализации (вы можете следовать этому руководству). «Остальные функции для реализации оставлены в качестве упражнения для читателя». В любом случае с этим классом вам просто нужно создать 6 его экземпляров (по одному на класс) :
loaders = {}
for label in ("dog", "cat", "plane", "tree", "mug", "car"):
dataset = SubFolderDataset(DATA_ROOT, label)
loaders[label] = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,shuffle = False, num_workers = 2, )
Теперь у вас есть dict, который содержит загрузчики данных, которые загружают только образцы данного класса.