Модель подходит после первой эпохи

#pytorch #huggingface-transformers #huggingface-tokenizers

Вопрос:

Я пытаюсь использовать модель hugging face без базы Берта для обучения предсказанию смайликов в твитах, и кажется, что после первой эпохи модель сразу же начинает перестраиваться. Я попробовал следующее:

  1. Увеличение обучающих данных (я увеличил их с 1x до 10x без какого-либо эффекта)
  2. Изменение скорости обучения (никаких различий нет)
  3. Используя разные модели от обнимающего лица (результаты снова были одинаковыми)
  4. Изменение размера пакета (пошло от 32, 72, 128, 256, 512, 1024)
  5. Создание модели с нуля, но я столкнулся с проблемами и решил сначала опубликовать здесь, чтобы посмотреть, не упускаю ли я что-нибудь очевидное.

На данный момент я обеспокоен тем, что отдельные твиты не дают достаточной информации для того, чтобы модель могла сделать хорошее предположение, но не будет ли это случайным в этом случае, а не чрезмерным?

Кроме того, время обучения, по-видимому, составляет ~4,5 часа на бесплатных графических процессорах Colab, есть ли способ ускорить это? Я попробовал их TPU, но, похоже, он не распознается.

Вот как выглядят эти данные

снимок экрана набора данных

И это мой код ниже:

 import pandas as pd
import json
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split
import torch
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
from sklearn.metrics import accuracy_score,precision_score, recall_score, f1_score
import numpy as np

# opening up the data and removing all symbols
df = pd.read_json('/content/drive/MyDrive/computed_results.json.bz2')
df['text_no_emoji'] = df['text_no_emoji'].apply(lambda text: re.sub(r'[^ws]', '', text))


# loading the tokenizer and the model from huggingface
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5).to('cuda')

# test train split
train, test = train_test_split(df[['text_no_emoji', 'emoji_codes']].sample(frac=1), test_size=0.2)

# defining a dataset class that generates the encoder and labels on the fly to minimize memory usage
class Dataset(torch.utils.data.Dataset):    
    def __init__(self, input, labels=None):
        self.input = input
        self.labels = labels

    def __getitem__(self, pos):
        encoded = tokenizer(self.input[pos], truncation=True, max_length=15, padding='max_length')
        label = self.labels[pos]
        ret = {key: torch.tensor(val) for key, val in encoded.items()}

        ret['labels'] = torch.tensor(label)
        return ret

    def __len__(self):
        return len(self.labels)

# training and validation datasets are defined here
train_dataset = Dataset(train['text_no_emoji'].tolist(), train['emoji_codes'].tolist())
val_dataset = Dataset(train['text_no_emoji'].tolist(), test['emoji_codes'].tolist())

# defining the training arguments
args = TrainingArguments(
    output_dir="output",
    evaluation_strategy="epoch",
    logging_steps = 10,
    per_device_train_batch_size=1024,
    per_device_eval_batch_size=1024,
    num_train_epochs=5,
    save_steps=3000,
    seed=0,
    load_best_model_at_end=True,
    weight_decay=0.2,
)

# defining the model trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

# Training the model
trainer.train()
 

Результаты: После этого тренировка обычно довольно быстро прекращается из-за ранней остановки

Набор данных можно найти здесь (сжатый размер 39 Мб).

Результаты 3 эпох

Комментарии:

1. Какие курсы обучения вы использовали? Откуда вы знаете, что он переоборудован?

2. Для скорости обучения я использовал значение по умолчанию, затем либо увеличил, либо уменьшил его на порядок. Что касается переоснащения, я отредактировал вопрос, чтобы показать вам результаты (я ждал, когда будет обновлена последняя версия).

3. Доступен ли этот набор данных публично?

4. Ага. Это из интернет-архива, с тонной постобработки, чтобы сделать все более плавным. Ознакомьтесь с публичным анализом социальных сетей на github, где мы работаем над тем, чтобы делать удивительные вещи с многолетними данными Twitter!

5. Я также могу загрузить образец данных где-нибудь, если вы хотите