Как я могу найти n индексов минимальных элементов для каждой строки с помощью numpy?

#python #numpy

Вопрос:

Например:

 n = 2
p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])
 

В результате я хочу:

 [ [0, 0, 2],
  [1, 0, 1],
  [2, 1, 2] ]
            
 

Первое число в каждой строке-это просто номер строки в p1. Оставшиеся n чисел являются индексами минимальных элементов строки. Так:

 [0, 0, 2]
 # 0 is the index of the first row in p1.
 # (0, 2 are the indices of minimum elements of the row)


[1, 0, 1]
# 1 is the index of the second row in p1.
# (0, 1 are the indices of minimum elements of the row)

[2, 1, 2]
# 2 is the index of the third row in p1.
# (1, 2 are the indices of minimum elements of the row)
 

Большое спасибо!!!

Комментарии:

1. Что minimum elements такое?

Ответ №1:

Используйте np.argpartition , чтобы найти два верхних минимума:

 import numpy as np

n = 2
p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])

pos = np.argpartition(p1, axis=1, kth=2)

res = np.hstack([np.arange(3)[:, None], np.sort(pos[:, :2])])
print(res)
 

Выход

 [[0 0 2]
 [1 0 1]
 [2 1 2]]
 

Как только вы найдете минимумы, используйте np.hstack их для объединения индекса строк и к нему.