#machine-learning #deep-learning #pytorch #lstm #recurrent-neural-network
Вопрос:
Не мог бы кто-нибудь объяснить мне, пожалуйста, приведенный ниже код:
import torch
import torch.nn as nn
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
rnn = nn.LSTM(10,20,2)
output, (hn, cn) = rnn(input, (h0, c0))
print(input)
При вызове rnn rnn(input, (h0, c0))
мы привели аргументы h0 и c0 в круглых скобках. Что это должно означать? если (h0, c0) представляет одно значение, то что это за значение и какой третий аргумент передается здесь?
Однако в строке rnn = nn.LSTM(10,20,2)
мы передаем аргументы в функции LSTM без парантезиса.
Может ли кто-нибудь объяснить мне, как работает этот вызов функции?
Комментарии:
1. Скобки в python используются для создания кортежей, вы должны прочитать об этом в документации по python.
Ответ №1:
Назначение rnn = nn.LSTM(10, 20, 2)
создает экземпляр нового nn.Module
с использованием nn.LSTM
класса. Это первые три аргумента input_size
(здесь 10
), hidden_size
(здесь 20
) и num_layers
(здесь 2
).
С другой стороны rnn(input, (h0, c0))
, это соответствует фактическому вызову экземпляра класса, i.e.
запуск __call__
которого примерно эквивалентен forward
функции этого модуля. __call__
Метод nn.LSTM
принимает два параметра: input
(фигурный (sequnce_length, batch_size, input_size)
и кортеж из двух тензоров (h_0, c_0)
(оба имеют форму (num_layers, batch_size, hidden_size)
в базовом случае использования nn.LSTM
)
Пожалуйста, обратитесь к документации PyTorch всякий раз, когда вы используете встроенные модули, вы найдете точное определение списка параметров (аргументы, используемые для инициализации экземпляра класса), а также спецификации ввода/вывода (при выводе с помощью указанного модуля).
Вы можете быть смущены обозначением, вот небольшой пример, который может помочь:
- кортеж в качестве входных данных:
def fn1(x, p): a, b = p # unpack input return a*x b >>> fn1(2, (3, 1)) >>> 7
- кортеж в качестве вывода
def fn2(x): return x, (3*x, x**2) # actually output is a tuple of int and tuple >>> x, (a, b) = fn2(2) # unpacking (2, (6, 4)) >>> x, a, b (2, 6, 4)
Комментарии:
1. Большое вам спасибо за разъяснения 🙂
2. Кроме того, не могли бы вы объяснить, пожалуйста, как вы пришли к такому выводу, что__вызов__ — это тот, который принимает параметры? Я прочитал исходный код pytorch.org/docs/stable/_modules/torch/nn/modules/rnn.html#LSTM по этой ссылке и я не могу связаться. Хотя я совершенно ясно понял ваше объяснение. Я хочу знать, как ты это выяснил? Заранее спасибо 🙂
3.
__call__
Действительно, не появляется в torch/nn/modules/rnn.py, это связано с тем , что метод реализован в суперклассеnn.Module
, а не в дочернем классеnn.LSTM
(обратите внимание, что в коде, который вы связали сclass LSTM(RNNBase)
L470 иclass RNNBase(Module)
L24: иерархия классовLSTM
<RNNBase
<Module
. Последнее является тем, что реализуется__call__
здесь. Дайте мне знать, если вам это ясно!4. Я забыл упомянуть,
__call__
что метод представляет собой специальную функцию Python, которая позволяет вызывать экземпляр класса (как вы делаете с функциями). В свою очередь, если у вас есть классA
, определенный вместе с__call__
определением, то вы можете вызвать его экземпляры:a = A(); a()
. Если этот конкретный метод определен в суперклассеnn.Module
, всеnn.Module
экземпляры являются вызываемыми объектами…5. Большое вам спасибо за подробное объяснение 🙂