Подпись Numba для jitted функции в качестве аргумента

#python #signature #jit #numba

#python #подпись #jit #numba

Вопрос:

Я просмотрел документацию numba, но ничего не смог найти.

У меня есть функция для jit, которая принимает jitted_function в качестве аргумента. Я хочу сделать быструю компиляцию, добавив подпись, как:

 @jit(float64('jit_func.type', int32, int32...))
  

‘jitted_func.type’ должен быть «типом функции»

когда я делаю:

 type(jitted_func)
  

Я получаю объект CPUDispatcher

спасибо за вашу помощь!

Ответ №1:

Итак, я не уверен, как сгенерировать подпись, которую вы ищете de novo, но если у вас есть образец скомпилированной функции с нужной вам подписью, которую вы можете использовать numba.typeof(...) для получения предполагаемой подписи, рассмотрим, например:

 import numba

@numba.njit(numba.int32(numba.int32))
def x(a):
    return a 1

@numba.njit(numba.int32(numba.typeof(x), numba.int32))
def y(fn,a):
    return fn(a)
    
print(y(x,3))
  

Я проверил, что это нетерпеливая компиляция. Если вы хотите возиться с этим дальше, правильное место для начала numba.core.types.functions , и Dispatcher тип обрабатывается специально при компиляции, см. numba.core.typing.context.BaseContext ‘s ._resolve_user_function_type .

Ответ №2:

Я также ищу решение этой проблемы. К сожалению, предложение @Carbon не работает, потому что тип, возвращаемый numba.typeof для функции bar , отличается от типа функции baz , даже если подписи bar и baz совпадают.

Пример:

 import numba 

@numba.jit(
    numba.int32(numba.int32),
    nopython=True,
    nogil=True,
)
def bar(a):

    return 2 * a

@numba.jit(
    numba.int32(numba.int32),
    nopython=True,
    nogil=True,
)
def baz(a):

    return 3 * a

@numba.jit(
    numba.int32(numba.typeof(bar), numba.int32),
    nopython=True,
    nogil=True,
)
def foo(fn, a):

    return fn(a)
  

foo(bar, 2) возвращает 4

foo(baz, 2) возвращает следующее исключение:

 Traceback (most recent call last):
  File "test_numba.py", line 33, in <module>
    print(foo(baz, 2))
  File "<snip>Python38libsite-packagesnumbacoredispatcher.py", line 656, in _explain_matching_error
    raise TypeError(msg)
TypeError: No matching definition for argument type(s) type(CPUDispatcher(<function baz at 0x000001DFA8C2D1F0>)), int64
  

Единственное решение, которое я нашел, — полностью исключить сигнатуру функции foo и позволить numba разобраться в этом. Я не знаю, какие негативные последствия (если таковые имеются) могут привести к запуску вашего кода.

Пример:

 import numba 

@numba.jit(
    numba.int32(numba.int32),
    nopython=True,
    nogil=True,
)
def bar(a):

    return 2 * a

@numba.jit(
    numba.int32(numba.int32),
    nopython=True,
    nogil=True,
)
def baz(a):

    return 3 * a

@numba.jit(
    nopython=True,
    nogil=True,
)
def foo(fn, a):

    return fn(a)
  

foo(bar, 2) возвращает 4

foo(baz, 2) возвращает 6