Вычисление промежуточных градиентов с использованием обратного метода в Pytorch

#pytorch

#pytorch

Вопрос:

У меня возникли проблемы с пониманием backward метода в pytorch

 x1 = tensor(2.).requires_grad_()
x2 = tensor(3.).requires_grad_() # or x2 = tensor(3.)
x3 = x1   x2

l = (x3**2).sum()
l.backward()

print(x1)
print(x3)
print(x1.grad)
print(x3.grad)
 

Результаты

 tensor(2., requires_grad=True)
tensor(5., grad_fn=<AddBackward0>)
tensor(10.)
None
 

Почему все x3.grad еще None ? Разве это не должно быть tensor(10.) ?

Когда я запускаю следующие строки кода, x3.grad вычисляется tensor(10.)

 x3 = tensor(5.).requires_grad_()
l = (x3**2).mean()
l.backward()
print(x3.grad)
 

Ответ №1:

Если вы печатаете x3.grad в своем первом примере, вы можете заметить, что torch выводит предупреждение:

Предупреждение пользователя: .grad осуществляется доступ к атрибуту тензора, который не является конечным тензором. Его .grad атрибут не будет заполнен во autograd.backward() время. Если вам действительно нужен градиент для нелистового тензора, используйте .retain_grad() для нелистового тензора. Если вы по ошибке обращаетесь к тензору, отличному от листа, убедитесь, что вместо этого вы обращаетесь к тензору листа. Смотрите здесь для получения дополнительной информации.

Для экономии памяти градиенты нелистовых тензоров (тензоров, не созданных пользователем) не буферизуются.

Если вы хотите увидеть эти градиенты, вы можете сохранить градиент включенным x3 , вызвав .retain_grad() перед созданием графика (т. Е. Перед вызовом .backward() .

 x3.retain_grad()
l.backward()
print(x3.grad)
 

действительно выведет tensor(10.)

Комментарии:

1. Спасибо @Ivan. Что значит быть не-листовым (или не созданным пользователем) тензором?

2. созданный пользователем , как в не определен буквально (как x1 и x2 есть). Смотрите, x3 не является листовым, потому что оно зависит от других тензоров на графике. На обратном проходе поток сначала проходит через l then x3 и, наконец, находит x1 and x2 , два листа вашего графика.

3. Кажется, я понял. Несмотря на то, что pytorch использовал x3 в графе вычислений, он не сохранил градиенты на промежуточных узлах для экономии памяти. Большое спасибо @Ivan