#python #class #pytorch
#python #класс #pytorch
Вопрос:
У меня есть некоторые знания в Pytorch, но я не совсем понимаю механизмы классов в Pytorch. Например, в ссылке: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html вы можете найти следующий код:
import torch
class MyReLU(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input
я фокусируюсь только на методе forward этого класса, и мне интересно, что
ctx.save_for_backward(input)
делает.Присутствует ли предыдущая строка кода или нет, не имеет значения, когда я пытаюсь использовать метод forward на конкретном примере:
a=torch.eye(3)
rel=MyReLU()
print(rel.forward(rel,a))
поскольку я получаю одинаковый результат в обоих случаях.Может кто-нибудь объяснить мне, что происходит и почему полезно добавить save_for_backward?
Заранее благодарю вас.
Ответ №1:
ctx.save_for_backward
Метод используется для хранения значений, сгенерированных во forward()
время выполнения, которые понадобятся позже при выполнении backward()
. Доступ к сохраненным значениям можно получить во backward()
время из ctx.saved_tensors
атрибута.
Комментарии:
1. большое вам спасибо за ваш ответ. Однако у меня есть еще один вопрос: как методы save_for_backward и saved_tensors связаны друг с другом? Я не нахожу никакого кода, определяющего эти методы на веб-сайте Pytorch.
2. каков эквивалентный метод save_for_backward() в tensorflow?