#python #pytorch #torch #autograd
Вопрос:
У меня есть модель, которая выводит последовательность векторов для каждого элемента в пакете, например [Batch size, Sequence Length, Hidden size]
. Затем я хочу выбрать переменное число векторов для каждого элемента в пакете и скопировать эти векторы в тензор, где requires_grad = True
. Пример кода приведен ниже:
from torch import nn
from typing import List
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(8,8)
def forward(self, x: torch.Tensor, indices: List[torch.Tensor]):
# Example indices: [torch.tensor([0,1]), torch.tensor([2,3,4])]
out = self.fc(x)
batch_size, _, hidden_size = out.size()
max_num_hidden_states = max([ind.size(0) for ind in indices])
selected_hidden_states = torch.zeros(batch_size, max_num_hidden_states, hidden_size, requires_grad=True)
for i in range(batch_size):
selected_hidden_states.data[i, :indices[i].size(0)] = out[i, indices[i]]
return selected_hidden_states
model = MyModel()
with torch.no_grad():
output = model(torch.rand(2, 5, 8), [torch.tensor([0,1]), torch.tensor([2,3,4])])
Вопросы, которые у меня возникают в связи с этим, следующие:
- Если я обучу такую модель, будут ли градиенты распространяться обратно в остальных параметрах модели?
- Почему
output.requires_grad = True
, когда я прямо заявляюtorch.no_grad()
? - То, как я это делаю (что, похоже, сейчас работает не так, как ожидалось), кажется слишком банальным и неправильным. Каков правильный способ достичь того, чего я хочу?
Я знаю этот ответ, который одобряет мой способ делать это (по крайней мере, мне так кажется), но все равно он кажется мне банальным.
Ваше здоровье!
Комментарии:
1.
requires_grad=True
этого будет недостаточно, чтобы сделать вывод вашей модели доступным для обратного распространения. Операторы факела должны связать его с параметрами вашей модели, чего здесь нет.
Ответ №1:
Скопируйте и вставьте ответ с форума PyTorch.
Это из давно минувших времен и ответ на другой вопрос.
- Нет. Между созданием нового тензора, требующего градации, и использованием .data, чего в наши дни никогда не следует делать, вы создали новый лист, который будет накапливать .grad.
- Поскольку вы запросили его. no_grad сигнализирует о том, что вам не нужен градус, он не содержит гарантий относительно требуемого градуса результата.
- Если функция утилиты не работает для вас, удаление requires_grad и .data должно помочь.