#optimization #pytorch #pytorch-lightning
#оптимизация #pytorch #pytorch-lightning
Вопрос:
Не могу справиться с этой проблемой в течение нескольких дней, так как я новичок в NLP, и фактическое решение может быть действительно простым
class QAModel(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
def forward(self, input_ids, attention_mask, labels=None):
output = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
return output.loss, output.logits
def training_step(self, batch, batch_idx):
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
loss, outputs = self(input_ids, attention_mask, labels)
self.log('train_loss', loss, prog_bar=True, logger=True)
return loss
def validation_step(self, batch, batch_idx):
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
loss, outputs = self(input_ids, attention_mask, labels)
self.log('val_loss', loss, prog_bar=True, logger=True)
return loss
def test_step(self, batch, batch_idx):
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
loss, outputs = self(input_ids, attention_mask, labels)
self.log('test_loss', loss, prog_bar=True, logger=True)
return loss
def configure_optimizers(self):
return AdamW(self.parameters(), lr=0.0001)
model = QAModel()
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
dirpath='/content/checkpoints',
filename='best-checkpoint',
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min'
)
trainer = pl.Trainer(
checkpoint_callback=checkpoint_callback,
max_epochs=N_EPOCHS,
gpus=1,
progress_bar_refresh_rate=30
)
trainer.fit(model, data_module)
Запуск этого кода дает мне
Ошибка атрибута: объект ‘QAModel’ не имеет атрибута ‘automatic_optimization’
после функции fit ()
Вероятно, проблема в MT5ForConditionalGeneration, так как после передачи его в funtion() мы получили ту же ошибку
Ответ №1:
Попробуйте наследовать pl.LightingModule
вместо pl.LightningDataModule
. Это правильный выбор для определения класса модели.