mt5forceconditionalgeneration с Pytorch-lightning выдает attribute_error

#optimization #pytorch #pytorch-lightning

#оптимизация #pytorch #pytorch-lightning

Вопрос:

Не могу справиться с этой проблемой в течение нескольких дней, так как я новичок в NLP, и фактическое решение может быть действительно простым

 class QAModel(pl.LightningDataModule):

  def __init__(self):
    super().__init__()
    self.model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

  def forward(self, input_ids, attention_mask, labels=None):
    output = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels
    )

    return output.loss, output.logits
  
  def training_step(self, batch, batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log('train_loss', loss, prog_bar=True, logger=True)
    return loss
  
  def validation_step(self, batch, batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log('val_loss', loss, prog_bar=True, logger=True)
    return loss

  def test_step(self, batch, batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log('test_loss', loss, prog_bar=True, logger=True)
    return loss

  def configure_optimizers(self):
    return AdamW(self.parameters(), lr=0.0001)
model = QAModel()
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='/content/checkpoints',
    filename='best-checkpoint',
    save_top_k=1,
    verbose=True,
    monitor='val_loss',
    mode='min'
)
trainer = pl.Trainer(
    checkpoint_callback=checkpoint_callback,
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30
)
trainer.fit(model, data_module)
 

Запуск этого кода дает мне
Ошибка атрибута: объект ‘QAModel’ не имеет атрибута ‘automatic_optimization’
после функции fit ()
Вероятно, проблема в MT5ForConditionalGeneration, так как после передачи его в funtion() мы получили ту же ошибку

Ответ №1:

Попробуйте наследовать pl.LightingModule вместо pl.LightningDataModule . Это правильный выбор для определения класса модели.