Как устранить ошибку «фигуры не могут быть умножены»

#pytorch

#pytorch

Вопрос:

Я попробовал код, упомянутый в статье в Google colab.

https://theaisummer.com/spiking-neural-networks/

Я получил ошибку, которая выглядит следующим образом…

 Test loss:  8.86368179321289
Test loss:  5.338221073150635
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-646cb112ccb7> in <module>()
     15         # forward pass
     16         net.train()
---> 17         spk_rec, mem_rec = net(data.view(batch_size, -1))
     18 
     19         # initialize the loss amp; sum over time

4 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
   1846     if has_torch_function_variadic(input, weight, bias):
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 
   1850 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x588 and 784x1000)
 

Я не уверен, как это исправить.

Комментарии:

1. Можете ли вы показать свое определение модели?

Ответ №1:

Я только что запустил их записную книжку Colab и столкнулся с той же ошибкой. Это происходит потому, что на последней итерации нет 128 выборок данных, поскольку общий размер набора данных (60000 и 10000 для обучающего и тестового набора) неравномерно делится на 128. Так что кое-что осталось, и изменим его размер до 128 x… приводит к несоответствию размеров между входными данными и количеством нейронов во входном слое.

Есть два возможных решения.

  1. Просто удалите последнюю партию:
 train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)
 
  1. Не удаляйте последнюю партию. Но сгладьте тензор таким образом, чтобы сохранить исходный batch_size, вместо того, чтобы увеличивать его до 128:

spk_rec, mem_rec = net(data.flatten(1))