PyTorch: объединение и сглаживание входных данных разной формы

#python #pytorch #concatenation

#python #pytorch #объединение

Вопрос:

У меня есть несколько входных данных разной формы: (7,), (), (6,) , как я могу объединить и сгладить их до одного сглаживающего ввода. Моя желаемая форма вывода (14,) .

Например: arr1= [1, 2, 3], arr2=6, arr3=[6,7], output=[1,2,3,6,6,7] . Хотя я могу использовать несколько numpy.append, но это было бы некрасиво.

Ответ №1:

Вы можете использовать torch.cat :

 import torch

arr1 = torch.tensor([1, 2, 3])
arr2 = torch.tensor([6])
arr3 = torch.tensor([6,7])

torch.cat((arr1,arr2,arr3))
>>> tensor([1, 2, 3, 6, 6, 7])