Как правильно упаковать дополненную последовательность с измененным классом набора данных

#neural-network #pytorch #lstm #tensor #zero-padding

Вопрос:

Итак, у меня есть класс набора данных, который я создал, который принимает массив 3D numpy и длины фигур для pack_padded_sequence :

 class MyDataset(data.Dataset):
     def __init__(self, dataset, data_shape):
          self.dataset = dataset
          self.transform = MyToTensor(data_shape)
 

и я создал свой собственный класс ToTensor:

 class MyToTensor(object):
     def __init__(self, data_shape):
          self.data_shape = data_shape
     def __call__(self, data):
          data = torch.from_numpy(data)
          return rnn.pack_padded_sequence(data, lengths=self.data_shape, batch_first=True)
 

Но по какой-то причине print(list(MyDataset(dataset, data_shape))) , когда я получаю обычный тензорный объект, он возвращается без удаления отступа.

Для получения дополнительной информации о моих входных данных, dataset это массив 3D numpy в порядке: batch size, sequence length, features и data_shape-это список, соответствующий размеру batch_size с числом, представляющим длину последовательности.

Длина последовательности также находится в порядке от наибольшей последовательности до наименьшего размера последовательности

Пример моих вкладов:

 [[[0 0.33000001311302185 1]
  [0 0.4300000071525574 1]
  [0 0.3799999952316284 1]
  ...
  [0 0.33000001311302185 1]
  [0 0.28999999165534973 1]
  [0 0.33000001311302185 1]]

 [[6 0.800000011920929 3]
  [5 0.7300000190734863 3]
  [7 0.8199999928474426 3]
  ...
  [4 0.699999988079071 3]
  [5 0.7799999713897705 3]
  [5 0.7799999713897705 3]]

 [[3 1.0 5]
  [3 1.0 5]
  [3 1.0 5]
  ...
  [3 1.0 5]
  [3 1.0 5]
  [3 1.0 5]]

 ...

 [[4.0 0.7599999904632568 3.0]
  [6.0 0.8100000023841858 3.0]
  [6.0 1.0 3.0]
  ...
  [nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[8.0 1.0 0.0]
  [8.0 0.9100000262260437 0.0]
  [9.0 1.0 0.0]
  ...
  [nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[5.0 1.0 1.0]
  [4.0 1.0 1.0]
  [4.0 1.0 1.0]
  ...
  [nan nan nan]
  [nan nan nan]
  [nan nan nan]]]
 

И соответствующая форма данных:

 (235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 235, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 18, 18, 18)