Кондиционирование элементов матрицы в JIT-ted функции

#python #jax

Вопрос:

У меня есть функция, которая выглядит следующим образом

         @jax.jit
        def f(R):
            tr = jnp.trace(R)

            r00 = R[0, 0]
            r01 = R[0, 1]
            r02 = R[0, 2]
            r10 = R[1, 0]
            r11 = R[1, 1]
            r12 = R[1, 2]
            r20 = R[2, 0]
            r21 = R[2, 1]
            r22 = R[2, 2]

            condw = tr > 0
            condx = (r00 > r11) and (r00 > r22)
            condy = (r11 > r22)
            # ... do some more things based on the conditions

 

где R 3×3 DeviceArray . Когда я пытаюсь выполнить эту функцию, как показано выше, я получаю следующую ошибку:

 File "/path/to/my/file", line 90, in f
    condx = (r00 > r11) and (r00 > r22)
  File "/Users/me/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/core.py", line 544, in __bool__
    def __bool__(self): return self.aval._bool(self)
  File "/Users/me/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/core.py", line 989, in error
    raise ConcretizationTypeError(arg, fname_context)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at /path/to/my/file:75 for jit, this concrete value was not available in Python because it depends on the value of the argument 'R'.
 

Я не совсем уверен, что не так с вычислением этого логического значения, которое предотвращает прерывание этой функции.

         condx = (r00 > r11) and (r00 > r22)
 

Любые подсказки будут высоко оценены — спасибо!

Ответ №1:

Начиная с #3761, используйте побитовые операторы вместо логических операторов.

Это работает.

 condx = (r00 > r11) amp; (r00 > r22)