#pytorch #gated-recurrent-unit
#pytorch #gated-recurrent-unit
Вопрос:
Я работаю над GRU, и когда я пытаюсь сделать прогнозы, я получаю сообщение об ошибке, указывающее, что мне нужно определить h для forward() . Я попробовал несколько вещей, и у меня кончилось терпение после поиска в Google и поиска переполнения стека в течение нескольких часов.
Это класс:
class GRUNet(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, n_layers, drop_prob = 0.2):
super(GRUNet, self).__init__()
self.hidden_dim = hidden_dim
self.n_layers = n_layers
self.gru = nn.GRU(input_dim, hidden_dim, n_layers, batch_first=True, dropout=drop_prob)
self.fc = nn.Linear(hidden_dim, output_dim)
self.relu = nn.ReLU()
def forward(self, x, h):
out, h = self.gru(x,h)
out = self.fc(self.relu(out[:,-1]))
return out, h
def init_hidden(self, batch_size):
weight = next(self.parameters()).data
hidden = weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device)
return hidden
и затем здесь я загружаю модель и пытаюсь сделать прогноз. Оба они находятся в одном скрипте.
inputs = np.load('.//Pred//input_list.npy')
print(inputs.ndim, inputs.shape)
Gmodel = GRUNet(24,256,1,2)
Gmodel = torch.load('.//GRU//GRU_1028_48.pkl')
Gmodel.eval()
pred = Gmodel(inputs)
Без каких-либо других аргументов для Gmodel я получаю следующее:
Traceback (most recent call last):
File ".grunet.py", line 136, in <module>
pred = Gmodel(inputs)
File "C:UsersryangAnaconda-3envstf-gpulibsite-packagestorchnnmodulesmodule.py", line 547, in __call__
result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'h'
Ответ №1:
Вам также необходимо указать скрытое состояние, которое обычно изначально состоит из нулей или просто None
!
То есть вам либо нужно явно указать такой :
hidden_state = torch.zeros(size=(num_layers*direction, batch_size, hidden_dim)).to(device)
pred = Gmodel(inputs, hidden_state)
или просто сделайте :
hidden_state = None
pred = Gmodel(inputs, hidden_state)
Комментарии:
1. Это еще один вопрос, который вам нужно опубликовать отдельно. если этот ответ решает вашу первоначальную проблему, пожалуйста, примите его как ответ, чтобы этот вопрос был выполнен. затем задайте свой новый вопрос отдельно, и мы постараемся ответить на него как можно лучше.