аргументы и вызов функции LSTM в pytorch

#machine-learning #deep-learning #pytorch #lstm #recurrent-neural-network

Вопрос:

Не мог бы кто-нибудь объяснить мне, пожалуйста, приведенный ниже код:

 import torch
import torch.nn as nn

input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)

rnn = nn.LSTM(10,20,2)

output, (hn, cn) = rnn(input, (h0, c0))
print(input)
 

При вызове rnn rnn(input, (h0, c0)) мы привели аргументы h0 и c0 в круглых скобках. Что это должно означать? если (h0, c0) представляет одно значение, то что это за значение и какой третий аргумент передается здесь?
Однако в строке rnn = nn.LSTM(10,20,2) мы передаем аргументы в функции LSTM без парантезиса.
Может ли кто-нибудь объяснить мне, как работает этот вызов функции?

Комментарии:

1. Скобки в python используются для создания кортежей, вы должны прочитать об этом в документации по python.

Ответ №1:

Назначение rnn = nn.LSTM(10, 20, 2) создает экземпляр нового nn.Module с использованием nn.LSTM класса. Это первые три аргумента input_size (здесь 10 ), hidden_size (здесь 20 ) и num_layers (здесь 2 ).

С другой стороны rnn(input, (h0, c0)) , это соответствует фактическому вызову экземпляра класса, i.e. запуск __call__ которого примерно эквивалентен forward функции этого модуля. __call__ Метод nn.LSTM принимает два параметра: input (фигурный (sequnce_length, batch_size, input_size) и кортеж из двух тензоров (h_0, c_0) (оба имеют форму (num_layers, batch_size, hidden_size) в базовом случае использования nn.LSTM )

Пожалуйста, обратитесь к документации PyTorch всякий раз, когда вы используете встроенные модули, вы найдете точное определение списка параметров (аргументы, используемые для инициализации экземпляра класса), а также спецификации ввода/вывода (при выводе с помощью указанного модуля).


Вы можете быть смущены обозначением, вот небольшой пример, который может помочь:

  • кортеж в качестве входных данных:
     def fn1(x, p):
        a, b = p # unpack input
        return a*x   b
    
    >>> fn1(2, (3, 1))
    >>> 7
     
  • кортеж в качестве вывода
     def fn2(x):
        return x, (3*x, x**2) # actually output is a tuple of int and tuple 
    
    >>> x, (a, b) = fn2(2) # unpacking
    (2, (6, 4))
    
    >>> x, a, b
    (2, 6, 4)
     

Комментарии:

1. Большое вам спасибо за разъяснения 🙂

2. Кроме того, не могли бы вы объяснить, пожалуйста, как вы пришли к такому выводу, что__вызов__ — это тот, который принимает параметры? Я прочитал исходный код pytorch.org/docs/stable/_modules/torch/nn/modules/rnn.html#LSTM по этой ссылке и я не могу связаться. Хотя я совершенно ясно понял ваше объяснение. Я хочу знать, как ты это выяснил? Заранее спасибо 🙂

3. __call__ Действительно, не появляется в torch/nn/modules/rnn.py, это связано с тем , что метод реализован в суперклассе nn.Module , а не в дочернем классе nn.LSTM (обратите внимание, что в коде, который вы связали с class LSTM(RNNBase) L470 и class RNNBase(Module) L24: иерархия классов LSTM < RNNBase < Module . Последнее является тем, что реализуется __call__ здесь. Дайте мне знать, если вам это ясно!

4. Я забыл упомянуть, __call__ что метод представляет собой специальную функцию Python, которая позволяет вызывать экземпляр класса (как вы делаете с функциями). В свою очередь, если у вас есть класс A , определенный вместе с __call__ определением, то вы можете вызвать его экземпляры: a = A(); a() . Если этот конкретный метод определен в суперклассе nn.Module , все nn.Module экземпляры являются вызываемыми объектами…

5. Большое вам спасибо за подробное объяснение 🙂