Почему эта функция медленнее в JAX против numpy?

#python #performance #numpy #optimization #jax

#python #Производительность #numpy #оптимизация #jax

Вопрос:

У меня есть следующая функция numpy, как показано ниже, которую я пытаюсь оптимизировать с помощью JAX, но по какой-то причине она работает медленнее.

Может кто-нибудь указать, что я могу сделать, чтобы улучшить производительность здесь? Я подозреваю, что это связано с пониманием списка, происходящим для Cg_new, но разбиение этого на части не приводит к дальнейшему увеличению производительности в JAX.

 import numpy as np 

def testFunction_numpy(C, Mi, C_new, Mi_new):
    Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = np.zeros((1, len(Mi[0])))
    invertCsensor_new = np.linalg.inv(C_new)

    Wg_new = np.dot(invertCsensor_new, Mi_new)
    Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))] 

    return C_new, Mi_new, Wg_new, Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
#1000 loops, best of 3: 1.73 ms per loop
  

Вот эквивалент JAX:

 import jax.numpy as jnp
import numpy as np
import jax

def testFunction_JAX(C, Mi, C_new, Mi_new):
    Wg_new = jnp.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = jnp.zeros((1, len(Mi[0])))
    invertCsensor_new = jnp.linalg.inv(C_new)

    Wg_new = jnp.dot(invertCsensor_new, Mi_new)
    Cg_new = [jnp.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))] 

    return C_new, Mi_new, Wg_new, Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)

jitter = jax.jit(testFunction_JAX) 

%timeit jitter(C, Mi, C_new, Mi_new)
#1 loop, best of 3: 4.96 ms per loop
  

Ответ №1:

Общие соображения по сравнительным сравнениям между JAX и NumPy см. В разделе https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy

Что касается вашего конкретного кода: когда JAX jit-компиляция сталкивается с потоком управления Python, включая понимание списка, она эффективно сглаживает цикл и выполняет полную последовательность операций. Это может привести к медленному времени компиляции jit и неоптимальному коду. К счастью, понимание списка в вашей функции легко выразить в терминах собственной трансляции numpy. Кроме того, вы можете внести два других улучшения:

  • нет необходимости пересылать объявления Wg_new и Cg_new перед их вычислением
  • при вычислениях dot(inv(A), B) гораздо эффективнее и точнее использовать np.linalg.solve , чем явно вычислять обратное.

Внесение этих трех улучшений в версии numpy и JAX приводит к следующему:

 def testFunction_numpy_v2(C, Mi, C_new, Mi_new):
    Wg_new = np.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new, Mi_new, Wg_new, Cg_new

@jax.jit
def testFunction_JAX_v2(C, Mi, C_new, Mi_new):
    Wg_new = jnp.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new, Mi_new, Wg_new, Cg_new

%timeit testFunction_numpy_v2(C, Mi, C_new, Mi_new)
# 1000 loops, best of 3: 1.11 ms per loop
%timeit testFunction_JAX_v2(C_jax, Mi_jax, C_new_jax, Mi_new_jax)
# 1000 loops, best of 3: 1.35 ms per loop
  

Обе функции работают немного быстрее, чем раньше, благодаря улучшенной реализации. Однако вы заметите, что JAX здесь все еще медленнее, чем numpy; этого следовало ожидать, потому что для функции такого уровня простоты JAX и numpy эффективно генерируют одну и ту же короткую серию вызовов BLAS и LAPACK, выполняемых на архитектуре процессора. Просто не так много возможностей для улучшения по сравнению с эталонной реализацией numpy, и с такими маленькими массивами накладные расходы JAX очевидны.

Комментарии:

1. Вау, спасибо вам еще раз! Оптимизированная версия numpy действительно очень удобна.

2. Из любопытства я протестировал numba на этом, и это заняло в 10 раз больше времени, чем numpy. О боже.

Ответ №2:

Я протестировал проблему с помощью perfplot в диапазоне размеров проблемы. Результат: jax работает немного быстрее. Причина, по которой jax здесь не превосходит numpy, заключается в том, что он запускается на процессоре (точно так же, как NumPy), и здесь NumPy уже довольно оптимизирован. (Он использует BLAS / LAPACK под капотом.)

введите описание изображения здесь

Код для воспроизведения графика:

 import jax.numpy as jnp
import jax
import numpy as np
import perfplot


def setup(n):
    C_new = np.random.rand(n, n)
    Mi_new = np.random.rand(n, 8)
    return C_new, Mi_new


def testFunction_numpy_v2(C_new, Mi_new):
    Wg_new = np.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return Wg_new, Cg_new


@jax.jit
def testFunction_JAX_v2(C_new, Mi_new):
    Wg_new = jnp.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return Wg_new, Cg_new


b = perfplot.bench(
    setup=setup,
    kernels=[testFunction_numpy_v2, testFunction_JAX_v2],
    n_range=[2 ** k for k in range(14)],
    equality_check=None
)
b.save("out.png")
b.show()