Хочу точно настроить предварительно обученную Роберту из Huggingface на мои собственные данные для суммирования текста

#text #pytorch #summarization #roberta-language-model

#текст #pytorch #суммирование #roberta-language-model

Вопрос:

Я новичок в этом. Пожалуйста, помогите мне найти решение. Я использовал RobertaTokenizerFast для разметки текста и резюме (max_token_length 200 и 50 соответственно). План состоит в том, чтобы использовать RoBERTa в качестве первого слоя. Затем уплотните его вывод, чтобы он соответствовал целевой сводке, используя conv2d, maxpool2d и dense. Вывод последнего плотного слоя представляет собой вектор с плавающей запятой. Итак, я нормализовал целевой вектор, содержащий длинные входные идентификаторы, в значения с плавающей запятой (от 0 до 1). Наконец, я использовал функцию CrossEntropy для получения потерь.

 class Summarizer(pl.LightningModule):
  def __init__(self):
    super().__init__()
    self.roberta = RobertaModel.from_pretrained('roberta-base', return_dict = True, is_decoder=True, use_cache=False) 
    self.convlayer = torch.nn.Conv2d(in_channels=BATCH_SIZE, out_channels=1, kernel_size=4) 
                                              ## BATCH_SIZE=20
    self.relu = torch.nn.ReLU()
    self.fc = torch.nn.Linear(in_features=97*381, out_features=50)
    self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

  def forward(self, input_ids, attention_mask, labels):
    output = self.roberta(input_ids=input_ids, attention_mask=attention_mask) 
    x = output['last_hidden_state']
    x = torch.unsqueeze(x, 0)
    x = self.convlayer(x)
    x = self.relu(x)
    x = F.max_pool2d(x, kernel_size=4, stride=2)
    x = x.squeeze().flatten()
    x = self.fc(x)
    output = self.relu(x)
    crossent_loss = self.cross_entropy_loss(labels, output)
    return crossent_loss, output
  
  def training_step(self, batch, batch_idx):
    input_ids = batch['text_input_ids']
    attention_mask = batch['text_attention_mask']

    l = batch['labels'].float()
    l = torch.tensor(l/torch.linalg.norm(l))

    labels = l # normalized labels in (0,1)
    labels_attention_mask = batch['labels_attention_mask']


    loss, outputs = self(
                         input_ids = input_ids,
                         attention_mask = attention_mask,
                         labels = labels
                         )
    self.log('train_loss', loss, prog_bar = True, logger = True)
    return loss

  def validation_step(self, batch, batch_idx): 
    input_ids = batch['text_input_ids']
    attention_mask = batch['text_attention_mask']

    l = batch['labels'].float()
    l = torch.tensor(l/torch.linalg.norm(l))
    
    labels = l
    labels_attention_mask = batch['labels_attention_mask']

    loss, outputs = self(
                         input_ids = input_ids,
                         attention_mask = attention_mask,
                         labels = labels
                         )
    self.log('val_loss', loss, prog_bar = True, logger = True)
    return loss

  def test_step(self, batch, batch_idx):
    input_ids = batch['text_input_ids']
    attention_mask = batch['text_attention_mask']
    
    l = batch['labels'].float()
    l = torch.tensor(l/torch.linalg.norm(l))
    
    labels = l
    labels_attention_mask = batch['labels_attention_mask']

    loss, outputs = self(
                         input_ids = input_ids,
                         attention_mask = attention_mask,
                         labels = labels
                         )
    self.log('test_loss', loss, prog_bar = True, logger = True)
    return loss

  def configure_optimizers(self):
    return AdamW(self.parameters(), lr=0.0001)
 

Обучение с использованием pl.Trainer возвратов ValueError: Expected input batch_size (20) to match target batch_size (50).
Я не смог получить ошибку.

Комментарии:

1. Обобщение текста — это проблема seq2seq, то, что вы делаете, ближе к классификации. Вы можете взглянуть на это huggingface.co/transformers/model_doc/encoderdecoder.html , чтобы создать пользовательскую модель кодера-декодера

2. Если у вас есть ограничение по времени, попробуйте использовать эту библиотеку, которая оборачивает библиотеку transformers simpletransformers.ai/docs/seq2seq-model