#python #pytorch #tensor
Вопрос:
Новое в тензорах/pytorch.
У меня есть два 2d тензора, A и B.
A содержит поплавки, представляющие вероятность, присвоенную определенному индексу. B содержит один горячий двоичный вектор в правильном индексе.
A tensor([[0.1, 0.4, 0.5], [0.5, 0.4, 0.1], [0.4, 0.5, 0.1]]) B tensor([[0, 0, 1], [0, 1, 0], [0, 0, 1]])
Я хотел бы найти количество строк, в которых индекс любых значений верхнего k A соответствует индексу one-hot в B. В этом случае k=2.
Моя попытка:
tops = torch.topk(A, 2, dim=1) top_idx = tops.indices top_2_matches = torch.where((torch.any(top_idx, 1) == B.argmax(dim=1)))
Если все сделано правильно, пример должен возвращать тензор([0, 1]), так как первые 2 строки имеют 2 совпадения сверху, но я получаю (tensor([1]),)
в качестве возврата.
Не уверен, в чем я здесь ошибаюсь. Спасибо за любую помощь!
Ответ №1:
Попробуй это:
top_idx = torch.topk(A, 2, dim=1).indices row_indicator = (top_idx == B.argmax(dim=1).unsqueeze(dim=1)).any(dim=1) top_2_matches = torch.arange(len(row_indicator))[row_indicator]
Например:
gt;gt;gt; import torch gt;gt;gt; A = torch.tensor([[0.1, 0.4, 0.5], ... [0.5, 0.4, 0.1], ... [0.4, 0.5, 0.1]]) gt;gt;gt; B = torch.tensor([[0, 0, 1], ... [0, 1, 0], ... [0, 0, 1]]) gt;gt;gt; tops = torch.topk(A, 2, dim=1) gt;gt;gt;tops torch.return_types.topk( values=tensor([[0.5000, 0.4000], [0.5000, 0.4000], [0.5000, 0.4000]]), indices=tensor([[2, 1], [0, 1], [1, 0]])) gt;gt;gt; top_idx = tops.indices gt;gt;gt; top_idx tensor([[2, 1], [0, 1], [1, 0]]) gt;gt;gt; index_indicator = top_idx == B.argmax(dim=1).unsqueeze(dim=1) gt;gt;gt; index_indicator tensor([[ True, False], [False, True], [False, False]]) gt;gt;gt; row_indicator = index_indicator.any(dim=1) gt;gt;gt; row_indicator tensor([ True, True, False]) gt;gt;gt; top_2_matches = torch.arange(len(row_indicator))[row_indicator] gt;gt;gt; top_2_matches tensor([0, 1])