Использует ли тренер Pytorch Lightning данные проверки для оптимизации весов моделей?

#python #pytorch #forecasting #pytorch-lightning

Вопрос:

В настоящее время я работаю с Pytorch Forecasting, который активно использует Pytorch Lightning. Здесь я применяю тренажер молнии Pytorch для обучения модели трансформатора Временного слияния, примерно следуя наброскам этого примера. Мой примерный учебный код и определение модели выглядят следующим образом:

 training = TimeSeriesDataSet(
    df_train[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="target",
    group_ids=["group"],
    max_prediction_length=90,
    min_encoder_length=365 // 2,
    max_encoder_length=365, 
    time_varying_unknown_reals=["target"], 
    time_varying_known_reals=["time_idx"]
)

validation = TimeSeriesDataSet.from_dataset(training, df_train, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 4  
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=res.suggestion(),
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,  
    loss=QuantileLoss(),
    log_interval=10,  
    reduce_on_plateau_patience=4,
    time_varying_reals_encoder=["target"],
    time_varying_reals_decoder=["target"]
)

trainer = pl.Trainer(
    max_epochs=15,
    gpus=0,
    weights_summary="top",
    gradient_clip_val=0.1,
    limit_train_batches=30,
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

trainer.fit(
    tft,
    train_dataloader,
    val_dataloader
)
 

Теперь мой вопрос заключается в том, оказывают ли данные валидации какое-либо влияние на оптимизацию модели? Я поиграл с max_prediction_length параметром, и, похоже, модель работает лучше, когда я устанавливаю временное окно проверки на больший временной интервал. Использует ли Pytorch Lightning Trainer данные проверки для оптимизации модели или я упускаю что-то еще?

Заранее большое спасибо!

Комментарии:

1. Я вижу, вы используете Раннюю Остановку. Вы не указали, как вы создали экземпляр early_stop_callback ? Он потенциально может использовать показатели проверки для остановки обучения — вот как работает ранняя остановка.

2. Спасибо, мне действительно следует более тщательно изучить код, который я копирую!