#pytorch
Вопрос:
Это прямой проход двунаправленного rnn, в котором я хочу взять пул средних выходных функций. Как вы можете видеть, я пытаюсь исключить временные шаги с помощью маркера pad из расчета.
def forward(self, text):
# text is shape (B, L)
embed = self.embed(text)
rnn_out, _ = self.rnn(embed) # (B, L, 2*H)
# Calculate average ignoring the pad token
with torch.no_grad():
rnn_out[text == self.pad_token] *= 0
denom = torch.sum(text != self.pad_token, -1, keepdim=True)
feat = torch.sum(rnn_out, dim=1) / denom
feat = self.dropout(feat)
return feat
При обратном распространении возникает исключение из-за строки rnn_out[text == self.pad_token] *= 0
. Вот как это выглядит:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [32, 21, 128]], which is output 0 of CudnnRnnBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Как правильно это сделать?
Примечание: Я знаю, что могу сделать это, выполнив и/или одно из следующих действий:
- Укажите длину текста в качестве входных данных.
- Выполните цикл по измерению пакета, найдя среднее значение для каждой последовательности, а затем сложите результат в стопку.
Но я хочу знать, есть ли более чистый способ не вовлекать их.
Ответ №1:
Вы изменяете вектор в контексте, в котором вы отключаете построение вычислительного графика (и вы изменяете его *=
на месте), это приведет к хаосу при вычислении градиента. Вместо этого я бы предложил следующее:
mask = text != self.pad_token
denom = torch.sum(mask, -1, keepdim=True)
feat = torch.sum(rnn_out * mask.unsqueeze(-1), dim=1) / denom
Возможно, вам придется немного подправить этот фрагмент, я не мог его протестировать, так как вы не предоставили полный пример, но, надеюсь, он покажет технику, которую вы можете использовать.
Комментарии:
1. Выглядит хорошо! Попробовал, и это не вызывает моей ошибки. Просто нужно
.unsqueeze(-1)
mask
, чтобы вrnn_out * mask
2. Отлично, я соответствующим образом обновил фрагмент!