Каков самый быстрый способ выбора подмножества матрицы JAX?

#python #jax

#python #jax

Вопрос:

Допустим, у меня есть 2D-матрица, и я хочу отобразить ее значения в виде гистограммы. Для этого мне нужно сделать что-то вроде:

 list_1d = matrix_2d.reshape((-1,)).tolist()
  

А затем используйте список для построения гистограммы. Пока все хорошо, просто в исходной матрице есть элементы, которые я хочу исключить. Для простоты предположим, что у меня есть такой список:

 exclude = [(2, 5), (3, 4), (6, 1)]
  

Итак, в list_1d матрице должны быть все элементы без элементов, на которые указывает exclude (элементами exclude являются индексы строк и столбцов).

И, кстати, matrix_2d это массив JAX, что означает, что его содержимое находится в графическом процессоре.

Ответ №1:

Один из способов сделать это — создать массив масок, который вы используете для выбора нужного подмножества массива. Операция индексации маски возвращает 1D копию выбранных данных:

 import jax.numpy as jnp
from jax import random
matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
exclude = [(2, 5), (3, 4), (6, 1)]

ind = tuple(jnp.array(exclude).T)
mask = jnp.ones_like(matrix_2d, dtype=bool).at[ind].set(False)

list_1d = matrix_2d[mask].tolist()
len(list_1d)
# 97