Максимум по оси в numpy

#python #numpy

#python #numpy

Вопрос:

Предположим, у меня есть массив numpy следующим образом:

 >> custom_array = np.array([[[0.1, 0.4, 0.4],
        [0.3, 0.2, 0.7]],

       [[0.9, 0.4, 0.2],
        [0.5, 0.1, 0.1]]])
>> custom_array.shape
(2,2,3)
  

Приведенный выше массив представляет таблицу вероятностей:

 | X1    | X2    | X3    | P     |
|----   |----   |----   |-----  |
| 0     | 0     | 0     | 0.1   |
| 0     | 0     | 1     | 0.4   |
| 0     | 0     | 2     | 0.4   |
| 0     | 1     | 0     | 0.3   |
| 0     | 1     | 1     | 0.2   |
| 0     | 1     | 2     | 0.7   |
| 1     | 0     | 0     | 0.9   |
| 1     | 0     | 1     | 0.4   |
| 1     | 0     | 2     | 0.2   |
| 1     | 1     | 0     | 0.5   |
| 1     | 1     | 1     | 0.1   |
| 1     | 1     | 2     | 0.1   |
  

Я хотел бы получить индексы вдоль оси 0 и 2, которые максимизируют значение вдоль оси = 1 (вторая ось). Мой ожидаемый результат должен быть:

(1,0) => Потому что эта комбинация X1 и X3 максимизирует значение P при X2 = 0

(0,2) => Потому что эта комбинация X1 и X3 максимизирует значение P при X2 = 1

Я попробовал следующую команду:

 >>> np.unravel_index(custom_array.argmax(), custom_array.shape)
(1,0,0)
  

Это правильный ответ, но я не могу понять, как изменить эту команду, чтобы получить оба результата, упомянутых выше.

Я был бы признателен за любые подсказки, которые может предоставить сообщество. Спасибо!

Ответ №1:

Чтобы получить (например, распечатать) индексы максимального количества элементов, используйте следующий код:

 for i in range(custom_array.shape[1]):
    tbl = np.take(custom_array, i, axis=1)
    ind = np.unravel_index(np.argmax(tbl), tbl.shape)
    print(f'{i}: {ind}')
  

Подробные сведения:

  • for i in ... — перебор значений X2,
  • tbl — фрагмент вашего исходного массива с определенным значением X2,
  • np.argmax(tbl) — индекс максимального элемента в сплющенном массиве,
  • np.unravel_index — преобразовать в индексы в базовом массиве (tbl).

Для вашего образца данных результат:

 0: (1, 0)
1: (0, 2)
  

Ответ №2:

Чтобы найти индексы в срезе, где X2 равно 0, используйте:

 from numpy import array, where

array(where(custom_array[:, 0, :] == custom_array[:, 0, :].max())).squeeze()
  

Чтобы найти индексы в срезе, где X2 равно 1, используйте:

 array(where(custom_array[:, 1, :] == custom_array[:, 1, :].max())).squeeze()
  

Чтобы найти массив с любым количеством значений для X2, используйте:

 [array(where(custom_array[:, i, :] == custom_array[:, i, :].max())).squeeze() 
 for i in range(custom_array.shape[1])]