#python #signature #jit #numba
#python #подпись #jit #numba
Вопрос:
Я просмотрел документацию numba, но ничего не смог найти.
У меня есть функция для jit, которая принимает jitted_function в качестве аргумента. Я хочу сделать быструю компиляцию, добавив подпись, как:
@jit(float64('jit_func.type', int32, int32...))
‘jitted_func.type’ должен быть «типом функции»
когда я делаю:
type(jitted_func)
Я получаю объект CPUDispatcher
спасибо за вашу помощь!
Ответ №1:
Итак, я не уверен, как сгенерировать подпись, которую вы ищете de novo, но если у вас есть образец скомпилированной функции с нужной вам подписью, которую вы можете использовать numba.typeof(...)
для получения предполагаемой подписи, рассмотрим, например:
import numba
@numba.njit(numba.int32(numba.int32))
def x(a):
return a 1
@numba.njit(numba.int32(numba.typeof(x), numba.int32))
def y(fn,a):
return fn(a)
print(y(x,3))
Я проверил, что это нетерпеливая компиляция. Если вы хотите возиться с этим дальше, правильное место для начала numba.core.types.functions
, и Dispatcher
тип обрабатывается специально при компиляции, см. numba.core.typing.context.BaseContext
‘s ._resolve_user_function_type
.
Ответ №2:
Я также ищу решение этой проблемы. К сожалению, предложение @Carbon не работает, потому что тип, возвращаемый numba.typeof
для функции bar
, отличается от типа функции baz
, даже если подписи bar
и baz
совпадают.
Пример:
import numba
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def bar(a):
return 2 * a
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def baz(a):
return 3 * a
@numba.jit(
numba.int32(numba.typeof(bar), numba.int32),
nopython=True,
nogil=True,
)
def foo(fn, a):
return fn(a)
foo(bar, 2)
возвращает 4
foo(baz, 2)
возвращает следующее исключение:
Traceback (most recent call last):
File "test_numba.py", line 33, in <module>
print(foo(baz, 2))
File "<snip>Python38libsite-packagesnumbacoredispatcher.py", line 656, in _explain_matching_error
raise TypeError(msg)
TypeError: No matching definition for argument type(s) type(CPUDispatcher(<function baz at 0x000001DFA8C2D1F0>)), int64
Единственное решение, которое я нашел, — полностью исключить сигнатуру функции foo
и позволить numba разобраться в этом. Я не знаю, какие негативные последствия (если таковые имеются) могут привести к запуску вашего кода.
Пример:
import numba
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def bar(a):
return 2 * a
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def baz(a):
return 3 * a
@numba.jit(
nopython=True,
nogil=True,
)
def foo(fn, a):
return fn(a)
foo(bar, 2)
возвращает 4
foo(baz, 2)
возвращает 6