Как сделать замаскированный знак в ПыТорхе?

#pytorch

Вопрос:

Это прямой проход двунаправленного rnn, в котором я хочу взять пул средних выходных функций. Как вы можете видеть, я пытаюсь исключить временные шаги с помощью маркера pad из расчета.

 def forward(self, text):
    # text is shape (B, L)
    embed = self.embed(text)
    rnn_out, _ = self.rnn(embed)  # (B, L, 2*H)
    # Calculate average ignoring the pad token
    with torch.no_grad():
        rnn_out[text == self.pad_token] *= 0
        denom = torch.sum(text != self.pad_token, -1, keepdim=True)
    feat = torch.sum(rnn_out, dim=1) / denom
    feat =  self.dropout(feat)
    return feat
 

При обратном распространении возникает исключение из-за строки rnn_out[text == self.pad_token] *= 0 . Вот как это выглядит:

 RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [32, 21, 128]], which is output 0 of CudnnRnnBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
 

Как правильно это сделать?

Примечание: Я знаю, что могу сделать это, выполнив и/или одно из следующих действий:

  • Укажите длину текста в качестве входных данных.
  • Выполните цикл по измерению пакета, найдя среднее значение для каждой последовательности, а затем сложите результат в стопку.

Но я хочу знать, есть ли более чистый способ не вовлекать их.

Ответ №1:

Вы изменяете вектор в контексте, в котором вы отключаете построение вычислительного графика (и вы изменяете его *= на месте), это приведет к хаосу при вычислении градиента. Вместо этого я бы предложил следующее:

 mask = text != self.pad_token
denom = torch.sum(mask, -1, keepdim=True)
feat = torch.sum(rnn_out * mask.unsqueeze(-1), dim=1) / denom
 

Возможно, вам придется немного подправить этот фрагмент, я не мог его протестировать, так как вы не предоставили полный пример, но, надеюсь, он покажет технику, которую вы можете использовать.

Комментарии:

1. Выглядит хорошо! Попробовал, и это не вызывает моей ошибки. Просто нужно .unsqueeze(-1) mask , чтобы в rnn_out * mask

2. Отлично, я соответствующим образом обновил фрагмент!