Pytorch nn.Параметр обновляется только для первой эпохи

#python-3.x #neural-network #pytorch #recurrent-neural-network

#python-3.x #нейронная сеть #pytorch #рекуррентная нейронная сеть

Вопрос:

Я пытаюсь реализовать пользовательский модуль (это упрощенная версия), и я проверяю переменную self.param во время каждой итерации, значение не меняется после первой итерации, даже если градиент имеет значение. Кто-нибудь, знакомый с pytorch, имеет представление, почему это происходит?

 class Custom_RNN(nn.Module):
    def __init__(self):
        super(Custom_RNN, self).__init__()
        self.hl1 = nn.Linear(2, 1024)
        self.hl2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, 1)
        self.param = nn.Parameter(torch.Tensor([250]), requires_grad=True)

    def forward(self, state_init, input, output):

        combined = torch.cat((state_init, input[0]), 1)
        out = torch.sigmoid(self.hl1(combined))
        out = torch.sigmoid(self.hl2(out))
        out = self.out(out)

        for t in range(input.shape[1]):

            # Predict SoC
            shifted_out = out   input[t]/self.param

            combined = torch.cat((shifted_out, input[t]), 1)
            out = torch.sigmoid(self.hl1(combined))
            out = torch.sigmoid(self.hl2(out))
            out = self.out(out)
            if first:
                loss = torch.pow(out - output[t], 2.0)
            else:
                loss = loss   torch.pow(out - output[t], 2.0)

        return loss

rnn = Custom_RNN()
optimiser = optim.Adam(rnn.parameters())

for epoch in range(epochs):
    optimiser.zero_grad()
    loss = rnn(initial_cond, data_in, data_out)
    rnn_param_before = rnn.param.item()
    loss.backward()
    optimiser.step()
    rnn_param_after = rnn.param.item()
    print(rnn_param_before - rnn_param_after)
    print(rnn.param.item(), rnn.param.grad)
  

В первую эпоху первая печать получает число, не являющееся нулями, затем в каждую другую эпоху оно равно 0.0, а вторая инструкция печати показывает, что значение остается неизменным, а grad всегда не равен нулю.

Комментарии:

1. Является ли это реальным кодом? Возникает ошибка копирования-вставки super(RNNetwork, self) .

2. Нет, это не фактический код, он был значительно сокращен, оригинал довольно большой. Спасибо, что указали на ошибку в примере