Ошибка времени выполнения: Заданные группы=1, вес размера [64, 1, 19], ожидаемый ввод[64, 187, 1] должен иметь 1 канал, но вместо этого получил 187 каналов(классификатор CNN)

#python #tensorflow #conv-neural-network

Вопрос:

Я внедряю классификатор CNN с этой структурой:

 (conv1): Conv1d(cin = 1, cout = 64, kernel size = 19)

(fc1): Linear(in_features=64*187, out_features=32, bias=True)

(relu): ReLU()

(fc2): Linear(in_features=32, out_features=5, bias=True)
 

Мой код таков:
классификатор классов(nn.Модуль):

 def __init__(self):
    super(ConvClassifier, self).__init__()
    self.conv1 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=19)
    self.fc1 = nn.Linear(in_features=64*187, out_features=32, bias=True)
    self.relu = nn.ReLU()
    self.fc2 = nn.Linear(in_features=32, out_features=5, bias=True)
    

def forward(self, x):
    x = x.unsqueeze(-1)
    
    out = self.conv1(x)
    out = self.fc1(out)
    out = self.relu(out)
    out = self.fc2(out)
    return out         

conv_classifier = ConvClassifier()
conv_classifier.to(device)

criterion = nn.CrossEntropyLoss(weight = get_loss_weights(class_samples_num))
criterion.to(device)
optimizer = torch.optim.Adam(conv_classifier.parameters())

data , label = next(iter(train_loader))
output = conv_classifier(data.to(device))
 

Ошибка заключается в:

Трассировка ошибки выполнения (последний последний вызов) в ()

     37 
     38 data , label = next(iter(train_loader))
---> 39 output = conv_classifier(data.to(device))
     40 

4 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    293                             _single(0), self.dilation, self.groups)
    294         return F.conv1d(input, weight, bias, self.stride,
--> 295                         self.padding, self.dilation, self.groups)
    296 
    297     def forward(self, input: Tensor) -> Tensor:
 

Ошибка времени выполнения: Заданные группы=1, вес размера [64, 1, 19], ожидаемый ввод[64, 187, 1] должен иметь 1 канал, но вместо этого получил 187 каналов

Как это исправить?