#python #numpy
#python #numpy
Вопрос:
Предположим, у меня есть массив numpy следующим образом:
>> custom_array = np.array([[[0.1, 0.4, 0.4],
[0.3, 0.2, 0.7]],
[[0.9, 0.4, 0.2],
[0.5, 0.1, 0.1]]])
>> custom_array.shape
(2,2,3)
Приведенный выше массив представляет таблицу вероятностей:
| X1 | X2 | X3 | P |
|---- |---- |---- |----- |
| 0 | 0 | 0 | 0.1 |
| 0 | 0 | 1 | 0.4 |
| 0 | 0 | 2 | 0.4 |
| 0 | 1 | 0 | 0.3 |
| 0 | 1 | 1 | 0.2 |
| 0 | 1 | 2 | 0.7 |
| 1 | 0 | 0 | 0.9 |
| 1 | 0 | 1 | 0.4 |
| 1 | 0 | 2 | 0.2 |
| 1 | 1 | 0 | 0.5 |
| 1 | 1 | 1 | 0.1 |
| 1 | 1 | 2 | 0.1 |
Я хотел бы получить индексы вдоль оси 0 и 2, которые максимизируют значение вдоль оси = 1 (вторая ось). Мой ожидаемый результат должен быть:
(1,0) => Потому что эта комбинация X1 и X3 максимизирует значение P при X2 = 0
(0,2) => Потому что эта комбинация X1 и X3 максимизирует значение P при X2 = 1
Я попробовал следующую команду:
>>> np.unravel_index(custom_array.argmax(), custom_array.shape)
(1,0,0)
Это правильный ответ, но я не могу понять, как изменить эту команду, чтобы получить оба результата, упомянутых выше.
Я был бы признателен за любые подсказки, которые может предоставить сообщество. Спасибо!
Ответ №1:
Чтобы получить (например, распечатать) индексы максимального количества элементов, используйте следующий код:
for i in range(custom_array.shape[1]):
tbl = np.take(custom_array, i, axis=1)
ind = np.unravel_index(np.argmax(tbl), tbl.shape)
print(f'{i}: {ind}')
Подробные сведения:
for i in ...
— перебор значений X2,tbl
— фрагмент вашего исходного массива с определенным значением X2,np.argmax(tbl)
— индекс максимального элемента в сплющенном массиве,np.unravel_index
— преобразовать в индексы в базовом массиве (tbl).
Для вашего образца данных результат:
0: (1, 0)
1: (0, 2)
Ответ №2:
Чтобы найти индексы в срезе, где X2 равно 0, используйте:
from numpy import array, where
array(where(custom_array[:, 0, :] == custom_array[:, 0, :].max())).squeeze()
Чтобы найти индексы в срезе, где X2 равно 1, используйте:
array(where(custom_array[:, 1, :] == custom_array[:, 1, :].max())).squeeze()
Чтобы найти массив с любым количеством значений для X2, используйте:
[array(where(custom_array[:, i, :] == custom_array[:, i, :].max())).squeeze()
for i in range(custom_array.shape[1])]