#python #pytorch #forecasting #pytorch-lightning
Вопрос:
В настоящее время я работаю с Pytorch Forecasting, который активно использует Pytorch Lightning. Здесь я применяю тренажер молнии Pytorch для обучения модели трансформатора Временного слияния, примерно следуя наброскам этого примера. Мой примерный учебный код и определение модели выглядят следующим образом:
training = TimeSeriesDataSet(
df_train[lambda x: x.time_idx <= training_cutoff],
time_idx="time_idx",
target="target",
group_ids=["group"],
max_prediction_length=90,
min_encoder_length=365 // 2,
max_encoder_length=365,
time_varying_unknown_reals=["target"],
time_varying_known_reals=["time_idx"]
)
validation = TimeSeriesDataSet.from_dataset(training, df_train, predict=True, stop_randomization=True)
# create dataloaders for model
batch_size = 4
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
tft = TemporalFusionTransformer.from_dataset(
training,
learning_rate=res.suggestion(),
hidden_size=16,
attention_head_size=1,
dropout=0.1,
hidden_continuous_size=8,
output_size=7,
loss=QuantileLoss(),
log_interval=10,
reduce_on_plateau_patience=4,
time_varying_reals_encoder=["target"],
time_varying_reals_decoder=["target"]
)
trainer = pl.Trainer(
max_epochs=15,
gpus=0,
weights_summary="top",
gradient_clip_val=0.1,
limit_train_batches=30,
callbacks=[lr_logger, early_stop_callback],
logger=logger,
)
trainer.fit(
tft,
train_dataloader,
val_dataloader
)
Теперь мой вопрос заключается в том, оказывают ли данные валидации какое-либо влияние на оптимизацию модели? Я поиграл с max_prediction_length
параметром, и, похоже, модель работает лучше, когда я устанавливаю временное окно проверки на больший временной интервал. Использует ли Pytorch Lightning Trainer данные проверки для оптимизации модели или я упускаю что-то еще?
Заранее большое спасибо!
Комментарии:
1. Я вижу, вы используете Раннюю Остановку. Вы не указали, как вы создали экземпляр
early_stop_callback
? Он потенциально может использовать показатели проверки для остановки обучения — вот как работает ранняя остановка.2. Спасибо, мне действительно следует более тщательно изучить код, который я копирую!