Создание двоичной маски тензора pytorch с использованием определенных значений

#python #pytorch #numpy-ndarray #tensor #binary-matrix

#python #pytorch #numpy-ndarray #тензор #двоичная матрица

Вопрос:

Мне дан 2-D тензор pytorch с целыми числами и 2 целыми числами, которые всегда появляются в каждой строке тензора. Я хочу создать двоичную маску, которая будет содержать 1 между двумя появлениями этих 2 целых чисел, в противном случае 0. Например, если целые числа равны 4 и 2, а 1-D массив равен [1,1,9,4,6,5,1,2,9,9,11,4,3,6,5,2,3,4] , возвращаемая маска будет: [0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0]. Есть ли какой-либо эффективный и быстрый способ вычисления этой маски без итераций?

Ответ №1:

Возможно, немного запутанно, но это работает без итераций. В следующем я предполагаю пример тензора m , к которому я применяю решение, его легче объяснить с помощью этого вместо использования общих обозначений.

 import torch

vals=[2,8]#let's assume those are the constant values that appear in each row

#target tensor
m=torch.tensor([[1., 2., 7., 8., 5.],
    [4., 7., 2., 1., 8.]])

#let's find the indexes of those values
k=m==vals[0]
p=m==vals[1]

v=(k.int() p.int()).bool()
nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2)

#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1)

#you only need two masks, no matter the size of m. see explanation below
msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q

final_mask=msk_0.int() * msk_1.int()

print(final_mask)
  

и мы получаем

 tensor([[0, 1, 1, 1, 0],
        [0, 0, 1, 1, 1]], dtype=torch.int32)
  

Что касается двух масок mask_0 , и mask_1 в случае, если неясно, что это такое, обратите nz_indexes[:,0] m внимание, что для каждой строки содержит индекс столбца, в котором vals[0] найден, и nz_indexes[:,1] аналогично содержит для каждой строки индекс столбца, m в котором vals[1] найден.

Комментарии:

1. мой код завершается ошибкой в строке {nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2) }, почему вы используете ненулевое значение над v и берете только первый столбец?

2. @Codevan Хммм… Да, я думаю, что это приведет к сбою, если одно из значений vals[0] или vals[1] встречается более одного раза в одной строке. Мое предположение при написании строки, о которой вы упомянули, заключалось в том, что каждое значение будет происходить только один раз в строке, и так начиная с v.nonzero()[:,0] is the row indexes, I could discard this column as I would already know that v.ненулевой()[0,:]` и v.nonzero()[1:] соответствует строке 0 m и так далее.

3. Я думаю, что возможный способ исправить это — изменить nonzero() с помощью функций, которые вместо этого возвращают первое и последнее вхождения ненулевых элементов.

4. Я соответствующим образом отредактировал свой вопрос. Знаете ли вы, как подогнать свой ответ сейчас? Спасибо!

Ответ №2:

Полностью основываясь на предыдущем решении, вот пересмотренное:

 import torch

vals=[2,8]#let's assume those are the constant values that appear in each row

#target tensor
m=torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
    [4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])

#let's find the indexes of those values
k=m==vals[0]
p=m==vals[1]

v=(k.int() p.int()).bool()
nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],4)

#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1)

#you only need two masks, no matter the size of m. see explanation below
msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q
msk_2=(nz_indexes[:,2].repeat(m.shape[1],1).transpose(0,1))<=q
msk_3=(nz_indexes[:,3].repeat(m.shape[1],1).transpose(0,1))>=q

final_mask=msk_0.int() * msk_1.int()   msk_2.int() * msk_3.int()

print(final_mask)
  

и мы, наконец, получаем

 tensor([[0, 1, 1, 1, 0, 1, 1, 1, 1, 0],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
  

Комментарии:

1. На будущее, пожалуйста, не меняйте вопрос после того, как вы уже приняли ответ, это не только грубо, но и заставляет людей, которые вам ответили, выглядеть полными идиотами. Что касается самого решения, было бы лучше найти что-то, что лучше обобщает: что, если в массиве N повторений? Наличие N * 2 массивов mask_i не кажется супер элегантным / эффективным.

2. @Ash прав, и вы должны хотя бы проголосовать за его ответ.