Jax vmap для простого обновления массива

#python #jax

Вопрос:

Я новичок в Jax и работаю над преобразованием чужого кода, который использовал функцию numba «fastmath» и полагался на множество вложенных циклов for без значительной потери производительности. Я пытаюсь воссоздать то же поведение, используя функцию vmap Jax. Однако в настоящее время я много борюсь с некоторыми фундаментальными вопросами. Вот простой пример того, что я пытаюсь векторизовать с помощью vmap:

 import jax.numpy as jnp
from jax import vmap
import jax.ops

a = jnp.arange(20).reshape((4, 5))
b = jnp.arange(5)
c = jnp.arange(4)
d = jnp.zeros(20)
e = jnp.zeros((4, 5))

for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        a = jax.ops.index_add(a, jax.ops.index[i, j], b[j]   c[i])
        d = jax.ops.index_update(d, jax.ops.index[i*a.shape[1]   j], b[j] * c[i])
        e = jax.ops.index_update(e, jax.ops.index[i, j], 2*b[j])

 

Как бы я переписал такой код с помощью vmap? Хотя этот код было бы относительно легко векторизовать вручную, я хочу лучше понять, как работает vmap, и надеюсь, что любой ответ поможет мне. Документы, похоже, сейчас мне не очень помогают. Я действительно ценю любую помощь, которую вы можете мне оказать.

Ответ №1:

Вот как вы можете выполнить примерно те же вычисления, используя vmap :

 from jax import vmap, partial

@partial(vmap, in_axes=(0, None, 0))
@partial(vmap, in_axes=(0, 0, None))
def f(a, b, c):
  return a   b   c, b * c, 2 * b

a, d, e = f(a, b, c)
d = d.ravel()
 

Комментарии:

1. Большое вам спасибо! Этот пример действительно помог мне учиться. Эпический.