Почему torch.scatter требует для индексов меньшую форму, чем значения?

#pytorch

Вопрос:

Аналогичный вопрос уже задавался здесь, но я думаю, что решение не подходит для моего случая.

Я просто удивляюсь, почему невозможно выполнить torch.scatter операцию, в которой мой тензор индекса больше, чем мой тензор значений. В моем случае у меня есть дубликаты индексов, например, следующий тензор значений a и тензор индексов idx :

 a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])

idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])
 

a.scatter(-1, idx, 1) ВОЗВРАТ:

Ошибка времени выполнения: Ожидаемый индекс [2, 5] должен быть меньше, чем self [2, 4], кроме измерения 1, и быть меньше, чем src [2, 4]

Есть ли другой способ достичь этого?

Ответ №1:

Не решение, а обходной путь:

 a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])

idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

rows = torch.arange(0, a.size(0))[:,None]
n_col = idx.size(1)
a[rows.repeat(1, n_col), idx] = 1
 

rows.repeat(1, n_col) присваивает индекс строки соответствующему индексу столбца в idx .