#python #numpy #machine-learning #pytorch
#python #numpy #машинное обучение #pytorch
Вопрос:
я пытаюсь разделить обучающие данные CIFAR10, чтобы для проверки использовались последние 5000 обучающих наборов. мой код
size = len(CIFAR10_training)
dataset_indices = list(range(size))
val_index = int(np.floor(0.9 * size))
train_idx, val_idx = dataset_indices[:val_index], dataset_indices[val_index:]
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
train_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
batch_size=config['batch_size'],
shuffle=False, sampler = train_sampler)
valid_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
batch_size=config['batch_size'],
shuffle=False, sampler = val_sampler)
print(len(train_dataloader.dataset),len(valid_dataloader.dataset),
но последний оператор печати печатает 50000 и 10000. не должно ли это быть 45000 и 5000
когда я печатаю train_idx и val_idx, он выводит правильные значения([0:44999],[45000:49999]
что-то не так с моим кодом
Ответ №1:
Я не могу воспроизвести ваши результаты, когда я выполняю ваш код, операторы печати выводят в два раза одно и то же число: количество элементов в train_CIFAR10
. Итак, я предполагаю, что вы допустили ошибку при копировании своего кода и valid_dataloader
фактически задается CIFAR10_test
(или что-то в этом роде) в качестве параметра. В дальнейшем я собираюсь предположить, что это так, и что ваши выходные данные для печати (50000, 50000)
, которые являются размером обучающей части набора данных Pytorch CIFAR10.
Тогда это полностью ожидаемо, и нет, он не должен выводиться (45000, 5000). Вы запрашиваете длину train_dataloader.dataset
и valid_dataloader.dataset
, т.Е. Длину базовых наборов данных. Для обоих ваших загрузчиков этот набор CIFAR10_training
данных . Таким образом, вы получите вдвое больший размер этого набора данных (т.Е. 50000).
Вы не можете запрашивать len(train_dataloader)
ни то, ни другое, потому что это дало бы количество пакетов в вашем наборе данных (приблизительно 45000/batch_size
).
Если вам нужно знать размер ваших разделений, тогда вам нужно вычислить длину ваших сэмплеров:
print(len(train_dataloader.sampler), len(valid_dataloader.sampler))
Кроме того, ваш код в порядке, вы правильно разделяете свои данные.
Комментарии:
1. хорошо, я понимаю, что вы имеете в виду, и что я делаю неправильно, спасибо @trialNerror
2. Здравствуйте, хотя я ценю ваше мнение, пожалуйста, найдите время, чтобы принять ответ, если он вас удовлетворяет, или задайте любой дополнительный вопрос, который, по вашему мнению, остался без ответа