#graph #pytorch #tensor #detach
Вопрос:
После построения графика при использовании detach
метода для изменения некоторого значения тензора ожидается, что при вычислении обратного распространения появится ошибка. Однако это не всегда так. В следующих двух блоках кода: первый выдает ошибку, а второй-нет. Почему это происходит?
x = torch.tensor(3.0, requires_grad=True)
y = x 1
z = y**2
c = y.detach()
c.zero_()
z.backward(retain_graph=True)
print(x.grad) # errors pop up
x = torch.tensor(3.0, requires_grad=True)
y1 = x 1
y2 = x**2
z = 3*y1 4*y2
c = y2.detach()
c.zero_()
z.backward(retain_graph=True)
print(x.grad) # no errors. The printed value is 27
Ответ №1:
TLDR; В предыдущем примере z = y**2
, so dz/dy = 2*y
, т. Е. это функция y
и требует, чтобы ее значения оставались неизменными для правильного вычисления обратного распространения, отсюда и сообщение об ошибке при применении операции на месте. В последнем z = 3*y1 4*y2
случае , таким dz/dy2 = 4
образом , т. е. y2
значения не нужны для вычисления градиента, так как такие его значения могут быть свободно изменены.
- В первом примере у вас есть следующий график вычислений:
x ---> y = x 1 ---> z = y**2 ---> c = y.detach().zero_()
Соответствующий код:
x = torch.tensor(3.0, requires_grad=True) y = x 1 z = y**2 c = y.detach() c.zero_() z.backward() # errors pop up
При вызове
c = y.detach()
вы эффективно отсоединяетесьc
от графика вычислений,y
оставаясь при этом прикрепленным. Однакоc
использует те же данные,y
что и . Это означает , что, когда вы вызываете операцию на местеc.zero_
, вы в конечном итоге оказываете влияниеy
. Это недопустимо, посколькуy
это часть графика вычислений, и его значения будут необходимы для потенциального обратного распространения из переменнойz
.
- Второй сценарий соответствует этой схеме:
/--> y1 = x 1 x ---> z = 3*y1 4*y2 --> y2 = x**2 / ---> c = y2.detach().zero_()
Соответствующий код:
x = torch.tensor(3.0, requires_grad=True) y1 = x 1 y2 = x**2 z = 3*y1 4*y2 c = y2.detach() c.zero_() z.backward() print(x.grad) # no errors. The printed value is 27
Здесь снова у нас та же настройка, вы отсоединяете, затем изменяете на месте
c
иy
сzero_
помощью .
Единственное различие заключается в операции, выполняемой на y
и y2
(в 1-м и 2-м примере соответственно).
- В первом случае у вас есть
z = y**2
, поэтому производная есть2*y
, следовательно, значениеy
необходимо для вычисления градиента этой операции. - В последнем примере , однако
z(y2) = constant 4*y2
, производная по отношению кy2
является просто константой:4
, т. Е.y2
для вычисления ее производной не требуется значение. Вы можете проверить это, например, определив во 2-м примереz
сz = 3*y1 4*y2**2
: это вызовет ошибку.
Комментарии:
1. Большое спасибо за подробное объяснение, которое очень понятно.