numba и numpy.expand_dims

#python #numba

#python #numba

Вопрос:

Я переписываю некоторые из своих функций, чтобы они подходили для Numba. Теперь у меня есть функция, которую я вызываю несколько раз в моем скрипте с входными массивами разных размеров.

 def FormHistMatrix2(x,Whc,Lm):
    if x.ndim == 1:
       x = np.expand_dims(x,axis=1)
    [N,Ncells] = x.shape
  

Это начало моей функции, и Numba выдает следующую ошибку:

 TypingError: Cannot unify array(float64, 2d, A) and array(float64, 3d, A) for 'x', defined at C:/Users/DNP_Student_3/Documents/Python Scripts/GCFuncsTests.py (332)
  

В этом случае ‘x’ — это 2D-массив, но в других случаях это может быть одномерный массив.
Так что, Numba не нравится цикл if? Или что здесь происходит?

Ответ №1:

В Numba, в отличие от стандартного python, переменная не может изменять свой тип во время выполнения функции. Вы должны иметь возможность присвоить результат вызова np.expand_dims другой переменной, и это будет работать. Это нормально, если иногда x равно 1d, а иногда 2d, если существует согласованность типов всех переменных во время выполнения функции.

Комментарии:

1. Да, я знаю, что numba не любит изменять типы переменных, но я ожидал, что он пропустит цикл if, когда x будет 2-D, но, видимо, это не так. Спасибо!

Ответ №2:

То, что сказал Джошадель, в целом верно, но проблема в этом случае заключается в том, что вам нужна другая реализация / специализация вашей функции в зависимости от типа ввода.

Для этого случая в Numba есть @generated_jit -decorator.

В вашем случае вам нужно было бы написать специализированную функцию expand-dims, которая зависит от размеров входного массива:

 import numba as nb
@nb.generated_jit(nopython=True)
def nb_expander(x):
    if x.ndim == 1:
        return lambda x: np.expand_dims(x, axis=1)
    else:
        return lambda x: x
  

Эта функция должна вызываться из вашей другой функции:

 @nb.njit
def FormHistMatrix2(x, Whc, Lm):
    x = nb_expander(x)
    [N, Ncells] = x.shape
  

Теперь это будет работать для x размеров 1 и 2. Для x.ndim==3 вам также необходимо реализовать аналогичный метод для фигуры.