#python #jax
Вопрос:
У меня есть функция, которая выглядит следующим образом
@jax.jit
def f(R):
tr = jnp.trace(R)
r00 = R[0, 0]
r01 = R[0, 1]
r02 = R[0, 2]
r10 = R[1, 0]
r11 = R[1, 1]
r12 = R[1, 2]
r20 = R[2, 0]
r21 = R[2, 1]
r22 = R[2, 2]
condw = tr > 0
condx = (r00 > r11) and (r00 > r22)
condy = (r11 > r22)
# ... do some more things based on the conditions
где R
3×3 DeviceArray
. Когда я пытаюсь выполнить эту функцию, как показано выше, я получаю следующую ошибку:
File "/path/to/my/file", line 90, in f
condx = (r00 > r11) and (r00 > r22)
File "/Users/me/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/core.py", line 544, in __bool__
def __bool__(self): return self.aval._bool(self)
File "/Users/me/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/core.py", line 989, in error
raise ConcretizationTypeError(arg, fname_context)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function f at /path/to/my/file:75 for jit, this concrete value was not available in Python because it depends on the value of the argument 'R'.
Я не совсем уверен, что не так с вычислением этого логического значения, которое предотвращает прерывание этой функции.
condx = (r00 > r11) and (r00 > r22)
Любые подсказки будут высоко оценены — спасибо!
Ответ №1:
Начиная с #3761, используйте побитовые операторы вместо логических операторов.
Это работает.
condx = (r00 > r11) amp; (r00 > r22)