torch.jit.script (модуль) против декоратора @torch.jit.script

#pytorch #torchscript

#pytorch #torchscript

Вопрос:

Почему добавление декоратора «@torch.jit.script» приводит к ошибке, в то время как я могу вызвать torch.jit.script в этом модуле, например, это сбой:

 import torch
    
@torch.jit.script
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x)   h)
        return new_h, new_h
    
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
  
 "C:UsersAdministratorAppDataLocalPackagesPythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0LocalCachelocal-packagesPython38site-packagestorchjit__init__.py", line 1262, in script
    raise RuntimeError("Type '{}' cannot be compiled since it inherits"
RuntimeError: Type '<class '__main__.MyCell'>' cannot be compiled since it inherits from nn.Module, pass an instance instead
  

Хотя следующий код работает хорошо:

 class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
    
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x)   h)
        return new_h, new_h
    
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
  

Этот вопрос также обсуждается на форумах PyTorch.

Комментарии:

1. Если вы получите более подробное объяснение на форумах PyTorch, пожалуйста, разместите его здесь в качестве самостоятельного ответа, спасибо.

Ответ №1:

Причина вашей ошибки здесь, именно в этом пункте:

Нет поддержки наследования или любой другой стратегии полиморфизма, за исключением наследования от object для указания класса нового стиля.

Кроме того, как указано вверху:

Поддержка класса TorchScript является экспериментальной. В настоящее время он лучше всего подходит для простых типов, похожих на записи (например, NamedTuple с прикрепленными методами).

В настоящее время он предназначен для простых классов на Python (см. Другие пункты по ссылке, которую я предоставил) и функций, смотрите ссылку, которую я предоставил для получения дополнительной информации.

Вы также можете проверить torch.jit.script исходный код, чтобы лучше понять, как это работает.

Судя по всему, когда вы передаете экземпляр, все attributes , которые должны быть сохранены, рекурсивно анализируются (исходный код). Вы можете следовать этой функции (довольно прокомментированной, но слишком длинной для ответа, смотрите Здесь), хотя точная причина, почему это так (и почему это было разработано таким образом), находится за пределами моих знаний (так что, надеюсь, кто-то с опытом во внутренней работе torch.jit расскажет об этом подробнее).