#python #pytorch #numpy-ndarray #tensor #binary-matrix
#python #pytorch #numpy-ndarray #тензор #двоичная матрица
Вопрос:
Мне дан 2-D тензор pytorch с целыми числами и 2 целыми числами, которые всегда появляются в каждой строке тензора. Я хочу создать двоичную маску, которая будет содержать 1 между двумя появлениями этих 2 целых чисел, в противном случае 0. Например, если целые числа равны 4 и 2, а 1-D массив равен [1,1,9,4,6,5,1,2,9,9,11,4,3,6,5,2,3,4]
, возвращаемая маска будет: [0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0].
Есть ли какой-либо эффективный и быстрый способ вычисления этой маски без итераций?
Ответ №1:
Возможно, немного запутанно, но это работает без итераций. В следующем я предполагаю пример тензора m
, к которому я применяю решение, его легче объяснить с помощью этого вместо использования общих обозначений.
import torch
vals=[2,8]#let's assume those are the constant values that appear in each row
#target tensor
m=torch.tensor([[1., 2., 7., 8., 5.],
[4., 7., 2., 1., 8.]])
#let's find the indexes of those values
k=m==vals[0]
p=m==vals[1]
v=(k.int() p.int()).bool()
nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2)
#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1)
#you only need two masks, no matter the size of m. see explanation below
msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q
final_mask=msk_0.int() * msk_1.int()
print(final_mask)
и мы получаем
tensor([[0, 1, 1, 1, 0],
[0, 0, 1, 1, 1]], dtype=torch.int32)
Что касается двух масок mask_0
, и mask_1
в случае, если неясно, что это такое, обратите nz_indexes[:,0]
m
внимание, что для каждой строки содержит индекс столбца, в котором vals[0]
найден, и nz_indexes[:,1]
аналогично содержит для каждой строки индекс столбца, m
в котором vals[1]
найден.
Комментарии:
1. мой код завершается ошибкой в строке {nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2) }, почему вы используете ненулевое значение над v и берете только первый столбец?
2. @Codevan Хммм… Да, я думаю, что это приведет к сбою, если одно из значений
vals[0]
илиvals[1]
встречается более одного раза в одной строке. Мое предположение при написании строки, о которой вы упомянули, заключалось в том, что каждое значение будет происходить только один раз в строке, и так начиная сv.nonzero()[:,0] is the row indexes, I could discard this column as I would already know that
v.ненулевой()[0,:]` иv.nonzero()[1:]
соответствует строке0
m
и так далее.3. Я думаю, что возможный способ исправить это — изменить
nonzero()
с помощью функций, которые вместо этого возвращают первое и последнее вхождения ненулевых элементов.4. Я соответствующим образом отредактировал свой вопрос. Знаете ли вы, как подогнать свой ответ сейчас? Спасибо!
Ответ №2:
Полностью основываясь на предыдущем решении, вот пересмотренное:
import torch
vals=[2,8]#let's assume those are the constant values that appear in each row
#target tensor
m=torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
[4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])
#let's find the indexes of those values
k=m==vals[0]
p=m==vals[1]
v=(k.int() p.int()).bool()
nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],4)
#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1)
#you only need two masks, no matter the size of m. see explanation below
msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q
msk_2=(nz_indexes[:,2].repeat(m.shape[1],1).transpose(0,1))<=q
msk_3=(nz_indexes[:,3].repeat(m.shape[1],1).transpose(0,1))>=q
final_mask=msk_0.int() * msk_1.int() msk_2.int() * msk_3.int()
print(final_mask)
и мы, наконец, получаем
tensor([[0, 1, 1, 1, 0, 1, 1, 1, 1, 0],
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
Комментарии:
1. На будущее, пожалуйста, не меняйте вопрос после того, как вы уже приняли ответ, это не только грубо, но и заставляет людей, которые вам ответили, выглядеть полными идиотами. Что касается самого решения, было бы лучше найти что-то, что лучше обобщает: что, если в массиве N повторений? Наличие N * 2 массивов mask_i не кажется супер элегантным / эффективным.
2. @Ash прав, и вы должны хотя бы проголосовать за его ответ.