#python #neural-network #pytorch
#python #нейронная сеть #pytorch
Вопрос:
Я пытаюсь запустить свою модель на выбранных точках данных в пакете данных, но модель генерирует совершенно другие результаты, чем при непосредственном запуске в пакете.
Вот пример:
у меня есть пакет данных ( firstData
) с формой torch.Size([128, 1, 28, 28])
, что означает, что он содержит 128 точек данных.
firstData
извлекается:
for _, (data, _) in enumerate(train_loader):
firstData = data
break
train_loader
является torch.utils.data.DataLoader
ли набор данных в MNIST
Затем я запускаю:
with torch.no_grad():
print(cnn(firstData.cuda())[0:2]) # run on batch and get first 2 outputs
и он выводит:
tensor([[-0.5045, -0.8611, -1.0237, -2.2146, 1.6829, -0.1202, 8.2230, -2.9030,
-0.6736, -1.6260],
[-1.6367, -0.0683, 2.3553, -2.8480, 8.0057, -2.7570, -0.7046, -1.3720,
-1.1558, 0.2683]], device='cuda:0')
Когда я запускаю:
with torch.no_grad():
print(cnn(firstData.cuda()[0:2])) # select first 2 data points and run the model
и он выводит:
tensor([[ 0.8256, -1.2180, -2.2164, -0.3754, -1.9170, 2.4913, 4.2119, -2.6611,
1.1260, -0.2956],
[-1.1582, 0.8964, 3.2793, -0.4405, 2.4863, -2.9732, -3.5990, 1.8273,
-0.6883, 0.3878]], device='cuda:0')
Кроме того, cnn
это структура, подобная resnet18, и она уже обучена с точностью до 0,99 для набора данных MNIST. Я также сделал cnn
детерминированный, и он всегда возвращает одни и те же результаты при одинаковых входных данных.
Я новичок в Pytorch, и я предполагаю, что это потому, что я неправильно обращаюсь к пакетным данным. Кто-нибудь может дать несколько предложений по решению этой проблемы?
Обновить
Спасибо @kHarshit! Вызов cnn.eval()
решил мою проблему.
В руководстве по PyTorch здесь:
«Помните, что вы должны вызвать model.eval()
, чтобы перевести уровни отсева и пакетной нормализации в режим оценки, прежде чем запускать вывод. Невыполнение этого требования приведет к противоречивым результатам вывода. Если вы хотите возобновить обучение, вызовите model.train()
, чтобы убедиться, что эти слои находятся в режиме обучения «.
Это имеет смысл, потому что my cnn
содержит BatchNorm2d
слои.
Официальный API можно найти здесь
Комментарии:
1. Вы установили
cnn.eval()
?2. Нет, я его не устанавливал
3. установлен ли shuffle в загрузчике данных в true?
4. вам необходимо перевести модель в режим оценки.