Разъяснение в понимании TorchScripts и JIT на PyTorch

#python #pytorch #jit #torchscript #tvm

#python #pytorch #jit #torchscript #tvm

Вопрос:

Просто хотел прояснить мое понимание того, как работают JIT и TorchScripts, и пояснить конкретный пример.

Итак, если я не ошибаюсь torch.jit.script , преобразует мой метод или модуль в TorchScript. Я могу использовать свой скомпилированный модуль TorchScript в среде вне python, но также могу просто использовать его в python с предполагаемыми улучшениями и оптимизациями. Аналогичный случай, torch.jit.trace когда вместо этого отслеживаются веса и операции, но следует примерно аналогичной идее.

Если это так, модуль с TorchScripted должен, как правило, быть по крайней мере таким же быстрым, как типичное время вывода интерпретатора python. Немного поэкспериментировав, я заметил, что он чаще всего медленнее, чем типичное время вывода интерпретатора, и, немного почитав, обнаружил, что, по-видимому, модуль TorchScripted необходимо немного «разогреть», чтобы достичь его наилучшей производительности. При этом я не увидел никаких изменений как таковых во времени вывода, стало лучше, но недостаточно, чтобы вызвать улучшение по сравнению с типичным способом выполнения действий (интерпретатор python). Кроме того, я использовал стороннюю библиотеку called torch_tvm , которая при включении предположительно вдвое сокращает время вывода для любого способа jit-ing модуля.

Ничего из этого не происходило до сих пор, и я действительно не могу сказать, почему.

Ниже приведен мой пример кода на случай, если я что-то сделал неправильно —

 class TrialC(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1024, 2048)
        self.l2 = nn.Linear(2048, 4096)
        self.l3 = nn.Linear(4096, 4096)
        self.l4 = nn.Linear(4096, 2048)
        self.l5 = nn.Linear(2048, 1024)

    def forward(self, input):
        out = self.l1(input)
        out = self.l2(out)
        out = self.l3(out)
        out = self.l4(out)
        out = self.l5(out)
        return out 

if __name__ == '__main__':
    # Trial inference input 
    TrialC_input = torch.randn(1, 1024)
    warmup = 10

    # Record time for typical inference 
    model = TrialC()
    start = time.time()
    model_out = model(TrialC_input)
    elapsed = time.time() - start 

    # Record the 10th inference time (10 warmup) for the optimized model in TorchScript 
    script_model = torch.jit.script(TrialC())
    for i in range(warmup):
        start_2 = time.time()
        model_out_check_2 = script_model(TrialC_input)
        elapsed_2 = time.time() - start_2

    # Record the 10th inference time (10 warmup) for the optimized model in TorchScript   TVM optimization
    torch_tvm.enable()
    script_model_2 = torch.jit.trace(TrialC(), torch.randn(1, 1024))
    for i in range(warmup):
        start_3 = time.time()
        model_out_check_3 = script_model_2(TrialC_input)
        elapsed_3 = time.time() - start_3 
    
    print("Regular model inference time: {}snJIT compiler inference time: {}snJIT Compiler with TVM: {}s".format(elapsed, elapsed_2, elapsed_3))
  

И ниже приведены результаты приведенного выше кода на моем процессоре —

 Regular model inference time: 0.10335588455200195s
JIT compiler inference time: 0.11449170112609863s
JIT Compiler with TVM: 0.10834860801696777s
  

Любая помощь или ясность по этому вопросу будут действительно оценены!