Почему в учебнике PyTorch по автоэнкодеру к выходу слоя embedding применяется view?

#python #pytorch #autoencoder

#python #pytorch #автоэнкодер

Вопрос:

Как показано здесь, в учебнике PyTorch код модели автоэнкодера выглядит следующим образом:

 class EncoderRNN(nn.Module):
    """RNN encoder from the PyTorch tutorial: an embedding layer followed by a GRU.

    Each ``forward`` call embeds the index tensor ``input`` and runs the result
    through the GRU as a single time step with a batch of one.
    """
    def __init__(self, input_size, hidden_size):
        """Create the embedding table and the GRU.

        Args:
            input_size: number of rows in the embedding table (presumably the
                vocabulary size -- TODO confirm against the tutorial).
            hidden_size: embedding width; also used as the GRU's input size
                and hidden size.
        """
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        # Maps each integer index to a vector of length hidden_size.
        self.embedding = nn.Embedding(input_size, hidden_size)
        # The GRU input size equals the embedding width, so embeddings can be
        # fed to it directly (num_layers defaults to 1, batch_first to False).
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Run one GRU step on the embedding of ``input``.

        Args:
            input: tensor of token indices; assumed to hold a single index
                (see the note on ``view`` below).
            hidden: previous hidden state of shape (1, 1, hidden_size), i.e.
                (num_layers, batch, hidden_size), e.g. from ``initHidden``.

        Returns:
            ``(output, hidden)`` from ``self.gru``; for a single index both
            have shape (1, 1, hidden_size).
        """
        # For an index tensor of shape (1,), nn.Embedding returns a 2-D tensor
        # of shape (1, hidden_size). nn.GRU (batch_first=False) expects a 3-D
        # input of shape (seq_len, batch, input_size), so view(1, 1, -1) adds
        # the two leading dimensions: seq_len = 1, batch = 1.
        # Caveat: -1 absorbs every element, so with n > 1 indices the last
        # dimension would become n * hidden_size and no longer match the GRU's
        # input_size -- this line only fits one token per call.
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        """Return an all-zero initial hidden state of shape (1, 1, hidden_size).

        The layout is (num_layers, batch, hidden_size), as nn.GRU expects for
        h_0. ``device`` is not defined in this snippet; presumably it is a
        module-level global in the tutorial (TODO confirm).
        """
        return torch.zeros(1, 1, self.hidden_size, device=device)
 

Мой вопрос: зачем к выходу слоя embedding применяется функция view?

Ответ №1:

Функция view добавляет к тензору дополнительное измерение, чтобы его форма совпала с формой входа, которую ожидает GRU. В функции initHidden скрытое состояние инициализируется тензором формы (1, 1, 256).

 def initHidden(self):
        """Return an all-zero hidden state of shape (1, 1, hidden_size).

        This is the (num_layers, batch, hidden_size) layout nn.GRU expects for
        its initial hidden state. ``device`` is not defined in this excerpt;
        presumably a module-level global in the tutorial (TODO confirm).
        """
        return torch.zeros(1, 1, self.hidden_size, device=device)
 

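Согласно документации, вход GRU должен иметь 3 измерения: input of shape (seq_len, batch, input_size). (В более новых версиях PyTorch допускается и двумерный «небатчевый» вход (seq_len, input_size), но тогда и скрытое состояние должно быть двумерным.)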

https://pytorch.org/docs/stable/generated/torch.nn.GRU.html

Форма self.embedding(input) — (1, 256); пример вывода:

 tensor([[ 0.1421,  0.4135, -1.0619,  0.0149,  0.0673, -0.3770,  0.4231,  2.2803,
         -1.6939, -0.0071,  1.1131, -1.0019,  0.6593,  0.1366,  1.1033, -0.8804,
          1.3676,  0.4115, -0.5671,  0.3314, -0.2599, -0.3082,  1.3644,  0.5788,
         -0.1929, -2.0505,  0.4518,  0.8757, -0.2360, -0.4099, -0.5697, -1.5973,
         -0.6638, -1.1523,  1.4425,  1.3651,  1.9371,  0.5698, -0.3541, -1.3883,
         -0.0195, -1.0757, -1.4324, -1.6226, -2.4267,  0.3874, -0.7529,  1.4938,
         -2.5773, -1.1962,  0.3759, -0.6143, -1.0444, -0.6443, -0.8130, -1.7283,
          1.4167,  1.3945, -1.2695,  0.7289,  0.7777, -0.0094, -1.8108,  0.2126,
         -0.2018, -0.4055, -0.7779, -0.8523,  0.0162,  0.2463,  0.5588, -0.7250,
         -0.0128,  0.6272, -0.7729,  0.4259,  0.7596, -1.9500,  0.5853,  0.3764,
         -0.1112,  0.7274, -2.8535, -0.0445,  0.4225,  1.2179,  0.2219, -0.7064,
         -0.9654,  1.0501,  1.7142,  0.5312, -0.8180, -1.5697,  1.3062, -0.9321,
         -0.1652, -1.5298, -0.3575, -1.2046, -0.6571, -0.7689, -0.7032,  1.0727,
         -1.3259,  0.1200,  1.9357, -0.2519, -0.3717,  0.8054,  0.1180, -0.6921,
          1.0245, -1.5500, -0.5280, -0.7462,  0.7924,  2.2701, -1.5094, -0.1973,
         -1.5919,  0.4869,  0.6739, -0.5242,  0.2559, -0.0149, -0.5332, -1.8313,
          0.3598,  0.0804, -0.0780, -0.2930, -0.2844, -0.4752, -0.9919,  0.1809,
          0.7622, -2.5069, -0.7724, -0.9441,  1.6101,  0.6461, -0.8932,  0.0600,
          0.6911,  0.5191, -0.1719, -0.5829, -0.9168,  1.5282,  1.4399,  0.3264,
         -0.8894,  0.2880, -0.0697,  0.8977, -0.5004,  0.3844,  0.0925,  0.5592,
         -0.1664,  0.8575, -1.0348,  0.7326, -0.2124,  0.7533,  0.6270, -0.9559,
         -1.4159,  0.6788,  0.6163, -0.5951, -0.1403, -1.6088, -0.7731,  0.3876,
          1.0429, -2.0960,  0.1726,  1.7446, -0.3963,  0.0785, -0.4701,  1.0074,
          0.3319, -2.2675, -1.6163, -0.4003, -0.5468,  0.0452, -2.5586,  0.4747,
         -0.0271, -1.2161,  1.2121,  1.8738, -1.2207, -0.9218, -0.1430,  0.2512,
         -0.5236, -0.2544, -0.5868, -0.7086, -1.3328, -0.0243,  0.4759,  1.4125,
          0.4947,  0.5054,  1.6253,  0.4198, -0.9150,  0.6374,  0.4581,  1.1527,
          1.4440, -0.0590, -0.4601,  0.2490, -0.5739,  0.6798, -0.2156, -1.1386,
         -0.5011, -0.7411,  0.2825, -0.2595,  0.8070,  0.5270,  0.2595, -0.1089,
          0.4221, -0.7851,  0.7112, -0.3038,  0.6169, -0.1513, -0.5872,  0.3974,
          0.2431,  0.4934, -0.9406, -0.9372,  1.4525,  0.1376,  0.2558,  0.0661,
          0.3509,  2.1667,  2.8428,  0.9429, -0.6143, -1.0969,  0.0955,  0.0914]],
       device='cuda:0', grad_fn=<EmbeddingBackward>)
 

Форма self.embedding(input).view(1, 1, -1) — (1, 1, 256), то есть (seq_len = 1, batch = 1, input_size = 256); пример вывода:

 tensor([[[ 0.1421,  0.4135, -1.0619,  0.0149,  0.0673, -0.3770,  0.4231,
           2.2803, -1.6939, -0.0071,  1.1131, -1.0019,  0.6593,  0.1366,
           1.1033, -0.8804,  1.3676,  0.4115, -0.5671,  0.3314, -0.2599,
          -0.3082,  1.3644,  0.5788, -0.1929, -2.0505,  0.4518,  0.8757,
          -0.2360, -0.4099, -0.5697, -1.5973, -0.6638, -1.1523,  1.4425,
           1.3651,  1.9371,  0.5698, -0.3541, -1.3883, -0.0195, -1.0757,
          -1.4324, -1.6226, -2.4267,  0.3874, -0.7529,  1.4938, -2.5773,
          -1.1962,  0.3759, -0.6143, -1.0444, -0.6443, -0.8130, -1.7283,
           1.4167,  1.3945, -1.2695,  0.7289,  0.7777, -0.0094, -1.8108,
           0.2126, -0.2018, -0.4055, -0.7779, -0.8523,  0.0162,  0.2463,
           0.5588, -0.7250, -0.0128,  0.6272, -0.7729,  0.4259,  0.7596,
          -1.9500,  0.5853,  0.3764, -0.1112,  0.7274, -2.8535, -0.0445,
           0.4225,  1.2179,  0.2219, -0.7064, -0.9654,  1.0501,  1.7142,
           0.5312, -0.8180, -1.5697,  1.3062, -0.9321, -0.1652, -1.5298,
          -0.3575, -1.2046, -0.6571, -0.7689, -0.7032,  1.0727, -1.3259,
           0.1200,  1.9357, -0.2519, -0.3717,  0.8054,  0.1180, -0.6921,
           1.0245, -1.5500, -0.5280, -0.7462,  0.7924,  2.2701, -1.5094,
          -0.1973, -1.5919,  0.4869,  0.6739, -0.5242,  0.2559, -0.0149,
          -0.5332, -1.8313,  0.3598,  0.0804, -0.0780, -0.2930, -0.2844,
          -0.4752, -0.9919,  0.1809,  0.7622, -2.5069, -0.7724, -0.9441,
           1.6101,  0.6461, -0.8932,  0.0600,  0.6911,  0.5191, -0.1719,
          -0.5829, -0.9168,  1.5282,  1.4399,  0.3264, -0.8894,  0.2880,
          -0.0697,  0.8977, -0.5004,  0.3844,  0.0925,  0.5592, -0.1664,
           0.8575, -1.0348,  0.7326, -0.2124,  0.7533,  0.6270, -0.9559,
          -1.4159,  0.6788,  0.6163, -0.5951, -0.1403, -1.6088, -0.7731,
           0.3876,  1.0429, -2.0960,  0.1726,  1.7446, -0.3963,  0.0785,
          -0.4701,  1.0074,  0.3319, -2.2675, -1.6163, -0.4003, -0.5468,
           0.0452, -2.5586,  0.4747, -0.0271, -1.2161,  1.2121,  1.8738,
          -1.2207, -0.9218, -0.1430,  0.2512, -0.5236, -0.2544, -0.5868,
          -0.7086, -1.3328, -0.0243,  0.4759,  1.4125,  0.4947,  0.5054,
           1.6253,  0.4198, -0.9150,  0.6374,  0.4581,  1.1527,  1.4440,
          -0.0590, -0.4601,  0.2490, -0.5739,  0.6798, -0.2156, -1.1386,
          -0.5011, -0.7411,  0.2825, -0.2595,  0.8070,  0.5270,  0.2595,
          -0.1089,  0.4221, -0.7851,  0.7112, -0.3038,  0.6169, -0.1513,
          -0.5872,  0.3974,  0.2431,  0.4934, -0.9406, -0.9372,  1.4525,
           0.1376,  0.2558,  0.0661,  0.3509,  2.1667,  2.8428,  0.9429,
          -0.6143, -1.0969,  0.0955,  0.0914]]], device='cuda:0',
       grad_fn=<ViewBackward>)
 

Код

Этот код работает:

 # Batched 3-D input of shape (seq_len, batch, input_size): accepted by nn.GRU.
 rnn1 = nn.GRU(256, 128, 1)  # input_size=256, hidden_size=128, num_layers=1
input1 = torch.randn(100, 2, 256)  # (seq_len=100, batch=2, input_size=256)
h01 = torch.randn(1, 2, 128)  # (num_layers=1, batch=2, hidden_size=128)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
# output: (seq_len, batch, hidden_size); h_n: (num_layers, batch, hidden_size)
print(output1.shape, hn1.shape)
 

Вывод

 torch.Size([100, 2, 256]) torch.Size([1, 2, 128])
torch.Size([100, 2, 128]) torch.Size([1, 2, 128])
 

Код

Этот код тоже работает:

 # Same layout EncoderRNN produces after view(1, 1, -1): seq_len = 1, batch = 1.
 rnn1 = nn.GRU(256, 256)  # input_size=256, hidden_size=256, num_layers defaults to 1
input1 = torch.randn(1, 1, 256)  # (seq_len=1, batch=1, input_size=256)
h01 = torch.randn(1, 1, 256)  # (num_layers=1, batch=1, hidden_size=256)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)
 

Вывод

 torch.Size([1, 1, 256]) torch.Size([1, 1, 256])
torch.Size([1, 1, 256]) torch.Size([1, 1, 256])
 

Код

А этот код не работает:

 # Deliberately failing example: a 2-D input of shape (1, 256), like the raw
 # nn.Embedding output, paired with a 3-D hidden state.
 rnn1 = nn.GRU(256, 256)
input1 = torch.randn(1, 256)
# Uncommenting the next line reshapes input1 to (1, 1, 256) and fixes the error:
#input1 = input1.view(1, 1, -1)
h01 = torch.randn(1, 1, 256)
# Raises RuntimeError here. Older PyTorch rejects any 2-D input ("input must
# have 3 dimensions, got 2"). Newer versions treat 2-D input as unbatched
# (seq_len, input_size) but then require a 2-D hidden state, so the call still
# fails, with a different message, because h01 is 3-D.
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)
 

Вывод

 RuntimeError: input must have 3 dimensions, got 2