Разделение сэмплера CIFAR10 dataloader

#python #numpy #machine-learning #pytorch

#python #numpy #машинное обучение #pytorch

Вопрос:

я пытаюсь разделить обучающие данные CIFAR10, чтобы для проверки использовались последние 5000 обучающих наборов. мой код

 size = len(CIFAR10_training)
dataset_indices = list(range(size))
val_index = int(np.floor(0.9 * size))
train_idx, val_idx = dataset_indices[:val_index], dataset_indices[val_index:]
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)

train_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
                                          batch_size=config['batch_size'],
                                          shuffle=False,  sampler = train_sampler)
valid_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
                                           batch_size=config['batch_size'],
                                           shuffle=False,  sampler = val_sampler)
print(len(train_dataloader.dataset),len(valid_dataloader.dataset),
  

но последний оператор печати печатает 50000 и 10000. не должно ли это быть 45000 и 5000
когда я печатаю train_idx и val_idx, он выводит правильные значения([0:44999],[45000:49999]
что-то не так с моим кодом

Ответ №1:

Я не могу воспроизвести ваши результаты, когда я выполняю ваш код, операторы печати выводят в два раза одно и то же число: количество элементов в train_CIFAR10 . Итак, я предполагаю, что вы допустили ошибку при копировании своего кода и valid_dataloader фактически задается CIFAR10_test (или что-то в этом роде) в качестве параметра. В дальнейшем я собираюсь предположить, что это так, и что ваши выходные данные для печати (50000, 50000) , которые являются размером обучающей части набора данных Pytorch CIFAR10.

Тогда это полностью ожидаемо, и нет, он не должен выводиться (45000, 5000). Вы запрашиваете длину train_dataloader.dataset и valid_dataloader.dataset , т.Е. Длину базовых наборов данных. Для обоих ваших загрузчиков этот набор CIFAR10_training данных . Таким образом, вы получите вдвое больший размер этого набора данных (т.Е. 50000).

Вы не можете запрашивать len(train_dataloader) ни то, ни другое, потому что это дало бы количество пакетов в вашем наборе данных (приблизительно 45000/batch_size ).

Если вам нужно знать размер ваших разделений, тогда вам нужно вычислить длину ваших сэмплеров:

 print(len(train_dataloader.sampler), len(valid_dataloader.sampler))
  

Кроме того, ваш код в порядке, вы правильно разделяете свои данные.

Комментарии:

1. хорошо, я понимаю, что вы имеете в виду, и что я делаю неправильно, спасибо @trialNerror

2. Здравствуйте, хотя я ценю ваше мнение, пожалуйста, найдите время, чтобы принять ответ, если он вас удовлетворяет, или задайте любой дополнительный вопрос, который, по вашему мнению, остался без ответа