Pytorch: загрузка контрольной точки из пакета без повторного перебора набора данных

#neural-network #pytorch

Вопрос:

Вместо загрузки с контрольной точки по эпохе мне нужно иметь возможность загружать из пакета. Я понимаю, что это не оптимально, но, поскольку у меня есть ограниченное время для обучения, прежде чем мое обучение будет прервано (бесплатная версия Google colab), мне нужно иметь возможность загружать из пакета, который он остановил, или вокруг этого пакета.

Я также не хочу повторять все данные снова, но продолжаю с данными, которые модель еще не видела.

Мой нынешний подход, который не работает:

 def save_checkpoint(state, file=checkpoint_file):
  torch.save(state, file)

def load_checkpoint(checkpoint):
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  train_loss = checkpoint['train_loss']
  val_loss = checkpoint['val_loss']
  epoch = checkpoint['epoch']
  step = checkpoint['step']
  batch = checkpoint['batch']
  return model, optimizer, train_loss, val_loss, epoch, step, batch
 

Хотя он загружает веса с того места, где остановился, он снова повторяет все данные.

Кроме того, нужно ли мне вообще захватывать train_loss и val_loss ? Я не вижу разницы в выводимых потерях, когда я включаю их или нет. Таким образом, я предполагаю, что он уже включен в model.load_state_dict (?)

Я предполагаю, что захват шага и пакета не будет работать таким образом, и мне действительно нужно включить в свой class DataSet ? У меня уже есть это в DataSet классе

    def __getitem__(self, idx):
    question = self.data_qs[idx]
    answer1 = self.data_a1s[idx]
    answer2 = self.data_a2s[idx]
    target = self.targets[idx]
 

Итак, может ли это быть полезно?

Ответ №1:

Вы можете достичь своей цели, создав пользовательский класс набора данных со свойством self.start_index=step*batch , и в вашей __getitem__ функции должен быть новый индекс (self.start_index idx)%len(self.data_qs) , если вы создадите свой загрузчик shuffle=False данных, тогда эти трюки будут работать.

Кроме того, с shuffle=True помощью вы можете поддерживать сопоставитель индексов и его необходимо проверить.