#pytorch
Вопрос:
Я пытаюсь клонировать тензор в pytorch и хотел бы также клонировать атрибуты тензора. Вот пример:
import torch
from torch import nn
a = nn.Parameter(torch.rand(1))
a.adapt = True # define tensor attribute
b = a.clone() # clone
В приведенном выше примере я хотел print(b.adapt)
бы вернуться True
; однако я получаю следующую ошибку:
Traceback (most recent call last): File "scratch.py", line 13, in <module> print(b.adapt) AttributeError: 'Tensor' object has no attribute 'adapt'
Мне интересно, почему атрибуты тензорного объекта удаляются при клонировании и как это исправить.
Ответ №1:
Функция torch.Tensor.clone
выполняет копию данных тензора, а не копию объекта Python. Это причина, по которой атрибут adapt of a
недоступен b
. Кроме того, он сохранит то же grad_fn
самое для вновь созданного тензора:
Комментарии:
1. Спасибо. Есть ли альтернатива
.clone
, чтобы исправить это? или единственный способ — просто добавитьb.adapt = a.adapt
?2. Если вы хотите сохранить график включенным
b
, то использованиеclone
— единственный способ. Действительно, если вы хотитеadapt
продолжитьb
, вам придется скопировать его самостоятельно.