Фильтровать массив Numpy с необязательным аргументом

#python #numpy #indexing

#python #numpy #индексирование

Вопрос:

Я создаю функцию, которая должна подготавливать мои данные в зависимости от входных данных. Переменная x_imp содержит индексы, для которых важны функции. Однако иногда мне все еще нужны все функции, поэтому, если ‘x_imp = None’, ничего не должно произойти.

Мое решение было таким (это не вся функция, а только входные данные):

 def get_train_data(x_cat, x_num,x_imp = None):
        x_cat = x_cat[:,x_imp]
        x_num = x_num[:,x_imp]
    return x_train
 

Но это меняет форму данных.
Например, если data.shape = (4, 5) тогда data[:,None].shape = (4, 1, 5)

Как мне избежать этой проблемы?

Ответ №1:

Это происходит потому , что нарезка по None является псевдонимом для np.newaxis . Есть ли причина не просто добавлять явное if утверждение?

 def get_train_data(x_cat, x_num,x_imp = None):
    if x_imp is not None:
        x_cat = x_cat[:,x_imp]
        x_num = x_num[:,x_imp]
    return x_train