#python #performance #numpy #optimization #jax
#python #Производительность #numpy #оптимизация #jax
Вопрос:
У меня есть следующая функция numpy, как показано ниже, которую я пытаюсь оптимизировать с помощью JAX, но по какой-то причине она работает медленнее.
Может кто-нибудь указать, что я могу сделать, чтобы улучшить производительность здесь? Я подозреваю, что это связано с пониманием списка, происходящим для Cg_new, но разбиение этого на части не приводит к дальнейшему увеличению производительности в JAX.
import numpy as np
def testFunction_numpy(C, Mi, C_new, Mi_new):
Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
Cg_new = np.zeros((1, len(Mi[0])))
invertCsensor_new = np.linalg.inv(C_new)
Wg_new = np.dot(invertCsensor_new, Mi_new)
Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
return C_new, Mi_new, Wg_new, Cg_new
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
#1000 loops, best of 3: 1.73 ms per loop
Вот эквивалент JAX:
import jax.numpy as jnp
import numpy as np
import jax
def testFunction_JAX(C, Mi, C_new, Mi_new):
Wg_new = jnp.zeros((len(Mi_new[:,0]), len(Mi[0])))
Cg_new = jnp.zeros((1, len(Mi[0])))
invertCsensor_new = jnp.linalg.inv(C_new)
Wg_new = jnp.dot(invertCsensor_new, Mi_new)
Cg_new = [jnp.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
return C_new, Mi_new, Wg_new, Cg_new
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)
jitter = jax.jit(testFunction_JAX)
%timeit jitter(C, Mi, C_new, Mi_new)
#1 loop, best of 3: 4.96 ms per loop
Ответ №1:
Общие соображения по сравнительным сравнениям между JAX и NumPy см. В разделе https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
Что касается вашего конкретного кода: когда JAX jit-компиляция сталкивается с потоком управления Python, включая понимание списка, она эффективно сглаживает цикл и выполняет полную последовательность операций. Это может привести к медленному времени компиляции jit и неоптимальному коду. К счастью, понимание списка в вашей функции легко выразить в терминах собственной трансляции numpy. Кроме того, вы можете внести два других улучшения:
- нет необходимости пересылать объявления
Wg_new
иCg_new
перед их вычислением - при вычислениях
dot(inv(A), B)
гораздо эффективнее и точнее использоватьnp.linalg.solve
, чем явно вычислять обратное.
Внесение этих трех улучшений в версии numpy и JAX приводит к следующему:
def testFunction_numpy_v2(C, Mi, C_new, Mi_new):
Wg_new = np.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return C_new, Mi_new, Wg_new, Cg_new
@jax.jit
def testFunction_JAX_v2(C, Mi, C_new, Mi_new):
Wg_new = jnp.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return C_new, Mi_new, Wg_new, Cg_new
%timeit testFunction_numpy_v2(C, Mi, C_new, Mi_new)
# 1000 loops, best of 3: 1.11 ms per loop
%timeit testFunction_JAX_v2(C_jax, Mi_jax, C_new_jax, Mi_new_jax)
# 1000 loops, best of 3: 1.35 ms per loop
Обе функции работают немного быстрее, чем раньше, благодаря улучшенной реализации. Однако вы заметите, что JAX здесь все еще медленнее, чем numpy; этого следовало ожидать, потому что для функции такого уровня простоты JAX и numpy эффективно генерируют одну и ту же короткую серию вызовов BLAS и LAPACK, выполняемых на архитектуре процессора. Просто не так много возможностей для улучшения по сравнению с эталонной реализацией numpy, и с такими маленькими массивами накладные расходы JAX очевидны.
Комментарии:
1. Вау, спасибо вам еще раз! Оптимизированная версия numpy действительно очень удобна.
2. Из любопытства я протестировал numba на этом, и это заняло в 10 раз больше времени, чем numpy. О боже.
Ответ №2:
Я протестировал проблему с помощью perfplot в диапазоне размеров проблемы. Результат: jax работает немного быстрее. Причина, по которой jax здесь не превосходит numpy, заключается в том, что он запускается на процессоре (точно так же, как NumPy), и здесь NumPy уже довольно оптимизирован. (Он использует BLAS / LAPACK под капотом.)
Код для воспроизведения графика:
import jax.numpy as jnp
import jax
import numpy as np
import perfplot
def setup(n):
C_new = np.random.rand(n, n)
Mi_new = np.random.rand(n, 8)
return C_new, Mi_new
def testFunction_numpy_v2(C_new, Mi_new):
Wg_new = np.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return Wg_new, Cg_new
@jax.jit
def testFunction_JAX_v2(C_new, Mi_new):
Wg_new = jnp.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return Wg_new, Cg_new
b = perfplot.bench(
setup=setup,
kernels=[testFunction_numpy_v2, testFunction_JAX_v2],
n_range=[2 ** k for k in range(14)],
equality_check=None
)
b.save("out.png")
b.show()