Пытаюсь понять, что такое «save_for_backward» в Pytorch

#python #class #pytorch

#python #класс #pytorch

Вопрос:

У меня есть некоторые знания в Pytorch, но я не совсем понимаю механизмы классов в Pytorch. Например, в ссылке: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html вы можете найти следующий код:

 import torch

class MyReLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):       
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
  

я фокусируюсь только на методе forward этого класса, и мне интересно, что

 ctx.save_for_backward(input)
  

делает.Присутствует ли предыдущая строка кода или нет, не имеет значения, когда я пытаюсь использовать метод forward на конкретном примере:

 a=torch.eye(3)
rel=MyReLU()
print(rel.forward(rel,a))
  

поскольку я получаю одинаковый результат в обоих случаях.Может кто-нибудь объяснить мне, что происходит и почему полезно добавить save_for_backward?
Заранее благодарю вас.

Ответ №1:

ctx.save_for_backward Метод используется для хранения значений, сгенерированных во forward() время выполнения, которые понадобятся позже при выполнении backward() . Доступ к сохраненным значениям можно получить во backward() время из ctx.saved_tensors атрибута.

Комментарии:

1. большое вам спасибо за ваш ответ. Однако у меня есть еще один вопрос: как методы save_for_backward и saved_tensors связаны друг с другом? Я не нахожу никакого кода, определяющего эти методы на веб-сайте Pytorch.

2. каков эквивалентный метод save_for_backward() в tensorflow?