Почему атрибуты тензорного объекта удаляются при клонировании?

#pytorch

Вопрос:

Я пытаюсь клонировать тензор в pytorch и хотел бы также клонировать атрибуты тензора. Вот пример:

 import torch
from torch import nn

a = nn.Parameter(torch.rand(1))
a.adapt = True                      # define tensor attribute

b = a.clone()                       # clone
 

В приведенном выше примере я хотел print(b.adapt) бы вернуться True ; однако я получаю следующую ошибку:

 Traceback (most recent call last):
  File "scratch.py", line 13, in <module>
    print(b.adapt)
AttributeError: 'Tensor' object has no attribute 'adapt'
 

Мне интересно, почему атрибуты тензорного объекта удаляются при клонировании и как это исправить.

Ответ №1:

Функция torch.Tensor.clone выполняет копию данных тензора, а не копию объекта Python. Это причина, по которой атрибут adapt of a недоступен b . Кроме того, он сохранит то же grad_fn самое для вновь созданного тензора:

Комментарии:

1. Спасибо. Есть ли альтернатива .clone , чтобы исправить это? или единственный способ — просто добавить b.adapt = a.adapt ?

2. Если вы хотите сохранить график включенным b , то использование clone — единственный способ. Действительно, если вы хотите adapt продолжить b , вам придется скопировать его самостоятельно.