#pytorch
Вопрос:
Аналогичный вопрос уже задавался здесь, но я думаю, что решение не подходит для моего случая.
Я просто удивляюсь, почему невозможно выполнить torch.scatter
операцию, в которой мой тензор индекса больше, чем мой тензор значений. В моем случае у меня есть дубликаты индексов, например, следующий тензор значений a
и тензор индексов idx
:
a = torch.tensor([[0, 1, 0, 0],
[0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
[0, 0, 1, 2, 2]])
a.scatter(-1, idx, 1)
ВОЗВРАТ:
Ошибка времени выполнения: Ожидаемый индекс [2, 5] должен быть меньше, чем self [2, 4], кроме измерения 1, и быть меньше, чем src [2, 4]
Есть ли другой способ достичь этого?
Ответ №1:
Не решение, а обходной путь:
a = torch.tensor([[0, 1, 0, 0],
[0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
[0, 0, 1, 2, 2]])
rows = torch.arange(0, a.size(0))[:,None]
n_col = idx.size(1)
a[rows.repeat(1, n_col), idx] = 1
rows.repeat(1, n_col)
присваивает индекс строки соответствующему индексу столбца в idx
.