Почему jax.numpy.dot() работает медленнее, чем numpy.dot() на CPU?

#python #numpy

#python #numpy

Вопрос:

Я хочу использовать JAX для ускорения моего кода numpy на CPU, а затем на GPU. Вот мой пример кода, работающий на моем локальном компьютере (только CPU):

 import jax.numpy as jnp
from jax import random, jix
import numpy as np
import time

size = 3000

key = random.PRNGKey(0)
x =  random.normal(key, (size,size), dtype=jnp.float64)

start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))

x2=np.random.normal((size,size))

start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))
  

Я получил предупреждение, и временные затраты следующие:

 /.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: 
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s
  

Я сделал что-то не так здесь? Должен ли JAX также ускорять код numpy на процессорах?

Комментарии:

1. Велика вероятность, что numpy использует (Open-) BLAS , и оптимизировать особо нечего np.dot() .

2. @sascha Но не имеет смысла, что JAX намного медленнее, чем NumPy. Я еще не выяснил причину.

3. Это меня не удивляет. Точка (векторная или matmul; не имеет значения) полностью закодирована вручную для любого типа процессорной арки, и вы не превзойдете это с помощью автоматических компиляторов. Без особых знаний о JAX, вероятно, речь идет о планировании, оптимизации временных ресурсов и некоторых других вещах -> что приводит к отличному коду, когда несколько «ядер» объединены . Но точка настолько элементарна, что нет никаких шансов добраться до нее. Дополнительное замечание: я упомянул openblas: но в некоторых версиях используется Intels MKL: вы ожидаете, что некоторые автоматические компиляторы превзойдут код matmul, закодированный вручную (разработчиками Intel) на вашем (возможно) процессоре Intel?

4. Также читайте: это

5. Возможно, numpy также намного быстрее, потому что форма x есть (3000, 3000) , а форма x2 есть (2,) 🙂

Ответ №1:

Вероятно, существуют различия в производительности между Jax и Numpy, но в исходном сообщении разница во времени в основном сводится к ошибке при создании массива. Массив, используемый Jax, имеет форму 3000×3000, тогда как массив, используемый Numpy, представляет собой одномерный массив длиной 2. Первым аргументом numpy.random.normal является loc (т. Е. Среднее значение гауссова, из которого выполняется выборка). Аргумент ключевого слова size= следует использовать для указания формы массива.

 numpy.random.normal(loc=0.0, scale=1.0, size=None)
  

После внесения этого изменения производительность между Jax и Numpy меньше отличается.

 import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)

start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))

x2 = np.random.normal(size=(size, size)).astype(np.float64)

start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))
  

Результат одного запуска равен

 Time of jnp: 2.3315 s
Time of np: 2.8811 s
  

При измерении временной производительности следует собирать несколько запусков, потому что производительность функции — это разброс во времени, а не одно значение. Это можно сделать с помощью timeit.timeit функции стандартной библиотеки Python или %timeit магии в IPython и Jupyter Notebook.

 import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  

Похоже, что в Numpy оптимизирована точечная операция для 32-разрядных чисел с плавающей точкой.

Комментарии:

1. Большое спасибо. Единственное, что я нашел по-другому, это производительность по времени. На моем ноутбуке (8 ядер) numpy.dot для запуска кода (dtype = float64) требуется 0,3176 с, что быстрее, чем jax.numpy.dot (0,4529 с). Если я изменю dtype на float32, numpy.dot будет еще быстрее, что займет всего 0,2112 с, но jax.numpy.dot по-прежнему занимает 0,4537 с.