Проверьте равенство любых верхних k записей в строках тензора A и argmax в строках тензора B

#python #pytorch #tensor

Вопрос:

Новое в тензорах/pytorch.

У меня есть два 2d тензора, A и B.

A содержит поплавки, представляющие вероятность, присвоенную определенному индексу. B содержит один горячий двоичный вектор в правильном индексе.

 A tensor([[0.1, 0.4, 0.5],  [0.5, 0.4, 0.1],  [0.4, 0.5, 0.1]])  B tensor([[0, 0, 1],  [0, 1, 0],  [0, 0, 1]])  

Я хотел бы найти количество строк, в которых индекс любых значений верхнего k A соответствует индексу one-hot в B. В этом случае k=2.

Моя попытка:

 tops = torch.topk(A, 2, dim=1)  top_idx = tops.indices  top_2_matches = torch.where((torch.any(top_idx, 1) == B.argmax(dim=1)))    

Если все сделано правильно, пример должен возвращать тензор([0, 1]), так как первые 2 строки имеют 2 совпадения сверху, но я получаю (tensor([1]),) в качестве возврата.

Не уверен, в чем я здесь ошибаюсь. Спасибо за любую помощь!

Ответ №1:

Попробуй это:

 top_idx = torch.topk(A, 2, dim=1).indices  row_indicator = (top_idx == B.argmax(dim=1).unsqueeze(dim=1)).any(dim=1)  top_2_matches = torch.arange(len(row_indicator))[row_indicator]  

Например:

 gt;gt;gt; import torch gt;gt;gt; A = torch.tensor([[0.1, 0.4, 0.5], ... [0.5, 0.4, 0.1], ... [0.4, 0.5, 0.1]]) gt;gt;gt; B = torch.tensor([[0, 0, 1], ... [0, 1, 0], ... [0, 0, 1]]) gt;gt;gt; tops = torch.topk(A, 2, dim=1) gt;gt;gt;tops torch.return_types.topk( values=tensor([[0.5000, 0.4000],  [0.5000, 0.4000],  [0.5000, 0.4000]]), indices=tensor([[2, 1],  [0, 1],  [1, 0]])) gt;gt;gt; top_idx = tops.indices gt;gt;gt; top_idx tensor([[2, 1],  [0, 1],  [1, 0]]) gt;gt;gt; index_indicator = top_idx == B.argmax(dim=1).unsqueeze(dim=1) gt;gt;gt; index_indicator tensor([[ True, False],  [False, True],  [False, False]]) gt;gt;gt; row_indicator = index_indicator.any(dim=1) gt;gt;gt; row_indicator tensor([ True, True, False]) gt;gt;gt; top_2_matches = torch.arange(len(row_indicator))[row_indicator] gt;gt;gt; top_2_matches tensor([0, 1])