Как отключить автоматическую загрузку контрольных точек

#python #pytorch #pytorch-lightning

#python #pytorch #pytorch-молния

Вопрос:

Я пытаюсь запустить цикл по набору параметров, и я не хочу создавать новую сеть для каждого параметра и позволять ей изучать несколько эпох.

В настоящее время мой код выглядит следующим образом:

 def optimize_scale(self, epochs=5, comp_scale=100, scale_list=[1, 100]):
    trainer = pyli.Trainer(gpus=1, max_epochs=epochs)
    
    for scale in scale_list:
        test_model = CustomNN(num_layers=1, scale=scale, lr=1, pad=True, batch_size=1)
        trainer.fit(test_model)
        trainer.test(verbose=True)
        
        del test_model
 

Все работает нормально для первого элемента scale_list , сеть изучает 5 эпох и завершает тест. Все это можно увидеть в консоли. Однако для всех следующих элементов scale_list это не работает, поскольку старая сеть не перезаписывается, а вместо этого автоматически загружается старая контрольная точка при trainer.fit(model) вызове. В консоли это указывается через:

 C:UsersXXXXAppDataRoamingPythonPython39site-packagespytorch_lightningcallbacksmodel_checkpoint.py:623: UserWarning:
Checkpoint directory D:XXXXsrclightning_logsversion_0checkpoints exists and is not empty.
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
train_size = 8   val_size = 1    test_size = 1
Restoring states from the checkpoint path at D:XXXXsrclightning_logsversion_0checkpointsepoch=4-step=39.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at D:XXXXsrclightning_logsversion_0checkpointsepoch=4-step=39.ckpt
 

Следствием этого является то, что второй тест выдает тот же результат, что и загруженная контрольная точка из старой сети, которая уже завершила все 5 эпох. Я подумал, что добавление del test_model может помочь полностью удалить модель, но это не сработало.

В своем поиске я нашел несколько тесно связанных проблем, например: https://github.com/PyTorchLightning/pytorch-lightning/issues/368 . Однако мне не удалось решить мою проблему. Я предполагаю, что это связано с тем фактом, что новая сеть, которая должна перезаписать старую, имеет то же имя / версию и, следовательно, ищет те же контрольные точки.

Если у кого-нибудь есть идея или он знает, как обойти это, я был бы очень благодарен.

Ответ №1:

Я думаю, в ваших настройках вы хотите отключить автоматическую контрольную точку:

 trainer = pyli.Trainer(gpus=1, max_epochs=epochs,enable_checkpointing=False)
 

Возможно, вам потребуется явно сохранять контрольную точку (с другим именем) для каждого сеанса обучения, который вы запускаете.

Вы можете вручную сохранить контрольную точку с помощью:

 trainer.save_checkpoint(f'checkpoint_for_scale_{scale}.pth')
 

Комментарии:

1. Большое вам спасибо. Это вместе с изменением trainer.test(verbose=True) на trainer.test(model=test_model, verbose=True) заставило его работать. У вас случайно нет идеи, как задать имя контрольной точки? Я еще не нашел его в документации Trainer.

2. @MEisebitt пожалуйста, посмотрите мое обновление

3. Возможно, я слишком быстро сказал, что это работает. Интересно, что каждая сеть сейчас обучается, но только первая выполняет эпохи 0-4, все последующие по какой-то причине выполняют только эпоху 4.

4. @MEisebitt Я не очень разбираюсь в lightning, может быть, вам следует создавать Trainer на каждой итерации?

5. Да, я попробовал это, и это работает. К сожалению, это также замедляет работу, по этой причине я хотел, чтобы это было вне цикла. Тем не менее, большое вам спасибо за ваше время и советы 🙂