Почему, изменив тензор с помощью метода отсоединения, обратное распространение не всегда может работать в pytorch?

#graph #pytorch #tensor #detach

Вопрос:

После построения графика при использовании detach метода для изменения некоторого значения тензора ожидается, что при вычислении обратного распространения появится ошибка. Однако это не всегда так. В следующих двух блоках кода: первый выдает ошибку, а второй-нет. Почему это происходит?

 x = torch.tensor(3.0, requires_grad=True)
y = x   1
z = y**2
c = y.detach()
c.zero_()
z.backward(retain_graph=True)
print(x.grad) # errors pop up
 
 x = torch.tensor(3.0, requires_grad=True)
y1 = x 1
y2 = x**2
z = 3*y1   4*y2
c = y2.detach()
c.zero_()
z.backward(retain_graph=True)
print(x.grad) # no errors. The printed value is 27
 

Ответ №1:

TLDR; В предыдущем примере z = y**2 , so dz/dy = 2*y , т. Е. это функция y и требует, чтобы ее значения оставались неизменными для правильного вычисления обратного распространения, отсюда и сообщение об ошибке при применении операции на месте. В последнем z = 3*y1 4*y2 случае , таким dz/dy2 = 4 образом , т. е. y2 значения не нужны для вычисления градиента, так как такие его значения могут быть свободно изменены.


  • В первом примере у вас есть следующий график вычислений:
     x ---> y = x   1 ---> z = y**2
            
              ---> c = y.detach().zero_()
     

    Соответствующий код:

     x = torch.tensor(3.0, requires_grad=True)
    y = x   1
    z = y**2
    c = y.detach()
    c.zero_()
    z.backward() # errors pop up
     

    При вызове c = y.detach() вы эффективно отсоединяетесь c от графика вычислений, y оставаясь при этом прикрепленным. Однако c использует те же данные, y что и . Это означает , что, когда вы вызываете операцию на месте c.zero_ , вы в конечном итоге оказываете влияние y . Это недопустимо, поскольку y это часть графика вычислений, и его значения будут необходимы для потенциального обратного распространения из переменной z .


  • Второй сценарий соответствует этой схеме:
       /--> y1 = x   1 
    x                  ---> z = 3*y1   4*y2
      --> y2 = x**2  /
             
               ---> c = y2.detach().zero_()
     

    Соответствующий код:

     x = torch.tensor(3.0, requires_grad=True)
    y1 = x   1
    y2 = x**2
    z = 3*y1   4*y2
    c = y2.detach()
    c.zero_()
    z.backward()
    print(x.grad) # no errors. The printed value is 27
     

    Здесь снова у нас та же настройка, вы отсоединяете, затем изменяете на месте c и y с zero_ помощью .


Единственное различие заключается в операции, выполняемой на y и y2 (в 1-м и 2-м примере соответственно).

  • В первом случае у вас есть z = y**2 , поэтому производная есть 2*y , следовательно, значение y необходимо для вычисления градиента этой операции.
  • В последнем примере , однако z(y2) = constant 4*y2 , производная по отношению к y2 является просто константой: 4 , т. Е. y2 для вычисления ее производной не требуется значение. Вы можете проверить это, например, определив во 2-м примере z с z = 3*y1 4*y2**2 : это вызовет ошибку.

Комментарии:

1. Большое спасибо за подробное объяснение, которое очень понятно.