Приведите функцию потерь наименьших квадратов в Jax

#python #jax

Вопрос:

У меня есть простая функция потерь, которая выглядит так

         def loss(r, x, y):
            resid = f(r, x) - y
            return jnp.mean(jnp.square(resid))
 

Я хотел бы оптимизировать параметр r и использовать некоторые статические параметры x , а y также вычислить остаток. Все параметры, о которых идет речь, таковы DeviceArrays .

Чтобы справиться с этим, я попытался сделать следующее

         @partial(jax.jit, static_argnums=(1, 2))
        def loss(r, x, y):
            resid = f(r, x) - y
            return jnp.mean(jnp.square(resid))
 

но я получаю эту ошибку

 jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'numpy.ndarray'> for function loss is non-hashable.
 

Я понимаю, что из #6233, что это сделано специально, но мне было интересно, как здесь обойти проблему, так как это похоже на очень распространенный случай использования, когда у вас есть несколько фиксированных (входных, выходных) пар обучающих данных и некоторая свободная переменная.

Спасибо за любые советы!

РЕДАКТИРОВАТЬ: это ошибка, которую я получаю, когда просто пытаюсь использовать jax.jit

 jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function loss at /path/to/my/script:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'r'.`
 

Ответ №1:

Похоже, вы думаете о статических аргументах как о «значениях, которые не меняются между вычислениями». В JIT JAX статические аргументы лучше рассматривать как «хешируемые константы времени компиляции». В вашем случае у вас нет хешируемых констант времени компиляции; у вас есть массивы, поэтому вы можете просто выполнить JIT-компиляцию без статических аргументов:

 @jit
def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))
 

Если вы действительно хотите, чтобы механизм JAX знал, что ваши массивы постоянны, вы можете сделать это, передав их через закрытие или частичное закрытие; например:

 from functools import partial

def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))
loss = jit(partial(loss, x=x, y=y))
 

Однако для того типа вычислений, которые вы выполняете, где константами являются массивы, управляемые функциями массива JAX, эти два подхода приводят к в основном одинаковому коду XLA с пониженным значением, поэтому вы можете также использовать более простой.

Комментарии:

1. Спасибо за подробный ответ. Однако я получаю ошибку, когда пытаюсь запустить его с помощью just @jit . (в ОП для лучшей читаемости). Следуя трассировке стека, похоже , что что-то не так f , но мне было интересно, могу ли я получить дополнительную информацию о том, что означает этот тип ошибки. Переход по ссылке, которую он предоставляет, приводит меня к аннотации статических аргументов.

2. На самом деле, я сделаю отдельный пост об этом, так как это не имеет отношения к операции. Спасибо за вашу помощь!