Расширьте тензор последовательностей так, чтобы каждая последовательность была предварена

#pytorch

Вопрос:

У меня есть тензор из 3 последовательностей, где каждая последовательность имеет длину 2 и состоит из векторов размера 2:

 import torch
t = torch.Tensor([[[11,12],[21,22]], [[31,32],[41,42]], [[51,52],[61,62]]])
>>> t
tensor([[[11., 12.],
         [21., 22.]],

        [[31., 32.],
         [41., 42.]],

        [[51., 52.],
         [61., 62.]]])
 

Так t же как и структура t[batch, sequencePos, dataPos] . Как я могу расширить каждую последовательность так, чтобы она была дополнена новым элементом [01, 02] (последовательность имеет длину 3), чтобы я получил:

 tensor([[[01., 02.],
         [11., 12.],             
         [21., 22.]],

        [[01., 02.],
         [31., 32.]
         [41., 42.]],

        [[01., 02.],
         [51., 52.],
         [61., 62.]]])
 

Ответ №1:

Вы хотите объединить два тензора axis=1 , первый из которых t :

  >>> t
tensor([[[11., 12.],
         [21., 22.]],

        [[31., 32.],
         [41., 42.]],

        [[51., 52.],
         [61., 62.]]])
 

Второе — это договоренность:

 >>> arr = torch.arange(t.size(-1))
tensor([0, 1])
 

Однако сначала нам нужно передать его в правильную форму с помощью torch.reshape и torch.repeat :

 >>> arr = arr.reshape(1, 1, -1).repeat(len(t), 1, 1)
tensor([[[0, 1]],

        [[0, 1]],

        [[0, 1]]])
 

На данный момент arr.shape это torch.Size([3, 1, 4]) так .

Мы настроены на объединение arr и t вместе, либо с torch.cat :

 >>> torch.cat((arr, t), dim=1)
 

или более элегантно с torch.hstack :

 >>> torch.hstack((arr, t))
tensor([[[ 0.,  1.],
         [11., 12.],
         [21., 22.]],

        [[ 0.,  1.],
         [31., 32.],
         [41., 42.]],

        [[ 0.,  1.],
         [51., 52.],
         [61., 62.]]])
 

Обратите внимание, как эта реализация будет работать с любыми 3-мерными входными данными. В следующем примере t три столбца вместо двух:

 >>> t
tensor([[[11., 12., 13.],
         [21., 22., 23.]],

        [[31., 32., 33.],
         [41., 42., 43.]],

        [[51., 52., 53.],
         [61., 62., 63.]]])

>>> arr = torch.arange(t.size(-1)).reshape(1, 1, -1).repeat(len(t), 1, 1)
>>> torch.cat((arr, t), dim=1)
tensor([[[ 0.,  1.,  2.],
         [11., 12., 13.],
         [21., 22., 23.]],

        [[ 0.,  1.,  2.],
         [31., 32., 33.],
         [41., 42., 43.]],

        [[ 0.,  1.,  2.],
         [51., 52., 53.],
         [61., 62., 63.]]])
 

Это можно даже обобщить на n-мерные тензоры. Вам просто нужно позаботиться о reshape repeat вызовах и, где количество аргументов зависит от количества измерений.

 >>> ones = (1,)*(t.ndim-1)
>>> arr = torch.arange(t.size(-1)).reshape(*ones, -1).repeat(len(t), *ones)
>>> torch.cat((arr, t), dim=-2)
 

Комментарии:

1. Спасибо за ответ, это решает мою проблему. Мне просто интересно, почему это так сложно в PyTorch, в numpy то же самое может быть достигнуто только одним вызовом numpy.insert .