#pytorch
Вопрос:
У меня есть тензор из 3 последовательностей, где каждая последовательность имеет длину 2 и состоит из векторов размера 2:
import torch
t = torch.Tensor([[[11,12],[21,22]], [[31,32],[41,42]], [[51,52],[61,62]]])
>>> t
tensor([[[11., 12.],
[21., 22.]],
[[31., 32.],
[41., 42.]],
[[51., 52.],
[61., 62.]]])
Так t
же как и структура t[batch, sequencePos, dataPos]
. Как я могу расширить каждую последовательность так, чтобы она была дополнена новым элементом [01, 02]
(последовательность имеет длину 3), чтобы я получил:
tensor([[[01., 02.],
[11., 12.],
[21., 22.]],
[[01., 02.],
[31., 32.]
[41., 42.]],
[[01., 02.],
[51., 52.],
[61., 62.]]])
Ответ №1:
Вы хотите объединить два тензора axis=1
, первый из которых t
:
>>> t
tensor([[[11., 12.],
[21., 22.]],
[[31., 32.],
[41., 42.]],
[[51., 52.],
[61., 62.]]])
Второе — это договоренность:
>>> arr = torch.arange(t.size(-1))
tensor([0, 1])
Однако сначала нам нужно передать его в правильную форму с помощью torch.reshape
и torch.repeat
:
>>> arr = arr.reshape(1, 1, -1).repeat(len(t), 1, 1)
tensor([[[0, 1]],
[[0, 1]],
[[0, 1]]])
На данный момент arr.shape
это torch.Size([3, 1, 4])
так .
Мы настроены на объединение arr
и t
вместе, либо с torch.cat
:
>>> torch.cat((arr, t), dim=1)
или более элегантно с torch.hstack
:
>>> torch.hstack((arr, t))
tensor([[[ 0., 1.],
[11., 12.],
[21., 22.]],
[[ 0., 1.],
[31., 32.],
[41., 42.]],
[[ 0., 1.],
[51., 52.],
[61., 62.]]])
Обратите внимание, как эта реализация будет работать с любыми 3-мерными входными данными. В следующем примере t
три столбца вместо двух:
>>> t
tensor([[[11., 12., 13.],
[21., 22., 23.]],
[[31., 32., 33.],
[41., 42., 43.]],
[[51., 52., 53.],
[61., 62., 63.]]])
>>> arr = torch.arange(t.size(-1)).reshape(1, 1, -1).repeat(len(t), 1, 1)
>>> torch.cat((arr, t), dim=1)
tensor([[[ 0., 1., 2.],
[11., 12., 13.],
[21., 22., 23.]],
[[ 0., 1., 2.],
[31., 32., 33.],
[41., 42., 43.]],
[[ 0., 1., 2.],
[51., 52., 53.],
[61., 62., 63.]]])
Это можно даже обобщить на n-мерные тензоры. Вам просто нужно позаботиться о reshape
repeat
вызовах и, где количество аргументов зависит от количества измерений.
>>> ones = (1,)*(t.ndim-1)
>>> arr = torch.arange(t.size(-1)).reshape(*ones, -1).repeat(len(t), *ones)
>>> torch.cat((arr, t), dim=-2)
Комментарии:
1. Спасибо за ответ, это решает мою проблему. Мне просто интересно, почему это так сложно в PyTorch, в numpy то же самое может быть достигнуто только одним вызовом
numpy.insert
.