Как выполнить срез из многомерного массива вдоль оси в NumPy?

#python #numpy

#python #numpy

Вопрос:

Предположим, у меня есть NumPy ndarray x формы (n,) higher_dims , где n — положительное значение int и higher_dims — кортеж положительных целых чисел любой длины. То есть n — это размер первой оси, и осей может быть произвольно много.

Предположим, у меня также есть ndarray indices формы (k,) higher_dims , где k — положительное значение int. То есть, indices имеет ту же форму, что и x за исключением, возможно, первой оси. Предположим, что каждая запись indices является значением int между 0 и n - 1 .

Я хочу создать массив, y который имеет ту же форму, что и indices , и который удовлетворяет

 y[i, ...] = x[indices[i, ...], ...]
  

для каждого i между 0 и n - 1 . Здесь ... обозначает произвольную комбинацию индексов для остальных осей, а не объект с многоточием.

Например, вот как я мог бы создать, чтобы y if x был трехмерным, используя for-циклы:

 import numpy as np

x = np.arange(24).reshape((4, 2, 3))
print('x =', x, sep='n')

indices = np.asarray([[[1, 0, 1], [2, 1, 2]], [[3, 1, 2], [0, 0, 1]]])
print('indices =', indices, sep='n')

y = np.empty(indices.shape, dtype=x.dtype)
for i in range(indices.shape[0]):
    for j in range(indices.shape[1]):
        for k in range(indices.shape[2]):
            y[i, j, k] = x[indices[i, j, k], j, k]  # Defining property of y
print('y =', y, sep='n')
  

Вывод:

 x =
[[[ 0  1  2]
  [ 3  4  5]]
 [[ 6  7  8]
  [ 9 10 11]]
 [[12 13 14]
  [15 16 17]]
 [[18 19 20]
  [21 22 23]]]
indices =
[[[1 0 1]
  [2 1 2]]
 [[3 1 2]
  [0 0 1]]]
y =
[[[ 6  1  8]
  [15 10 17]]
 [[18  7 14]
  [ 3  4 11]]]
  

Я ищу функцию или трюк с индексированием для достижения такого поведения в целом (для ndarrays произвольной размерности), по возможности без циклов Python.

Комментарии:

1. Просто используйте np.take_along_axis(x,indices,axis=0) .

2. @Divakar спасибо, это именно то, что мне было нужно! Я знал о take() , но не о take_along_axis() 🙂