#pytorch #huggingface-transformers #huggingface-tokenizers
Вопрос:
Я пытаюсь использовать модель hugging face без базы Берта для обучения предсказанию смайликов в твитах, и кажется, что после первой эпохи модель сразу же начинает перестраиваться. Я попробовал следующее:
- Увеличение обучающих данных (я увеличил их с 1x до 10x без какого-либо эффекта)
- Изменение скорости обучения (никаких различий нет)
- Используя разные модели от обнимающего лица (результаты снова были одинаковыми)
- Изменение размера пакета (пошло от 32, 72, 128, 256, 512, 1024)
- Создание модели с нуля, но я столкнулся с проблемами и решил сначала опубликовать здесь, чтобы посмотреть, не упускаю ли я что-нибудь очевидное.
На данный момент я обеспокоен тем, что отдельные твиты не дают достаточной информации для того, чтобы модель могла сделать хорошее предположение, но не будет ли это случайным в этом случае, а не чрезмерным?
Кроме того, время обучения, по-видимому, составляет ~4,5 часа на бесплатных графических процессорах Colab, есть ли способ ускорить это? Я попробовал их TPU, но, похоже, он не распознается.
Вот как выглядят эти данные
И это мой код ниже:
import pandas as pd
import json
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split
import torch
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
from sklearn.metrics import accuracy_score,precision_score, recall_score, f1_score
import numpy as np
# opening up the data and removing all symbols
df = pd.read_json('/content/drive/MyDrive/computed_results.json.bz2')
df['text_no_emoji'] = df['text_no_emoji'].apply(lambda text: re.sub(r'[^ws]', '', text))
# loading the tokenizer and the model from huggingface
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5).to('cuda')
# test train split
train, test = train_test_split(df[['text_no_emoji', 'emoji_codes']].sample(frac=1), test_size=0.2)
# defining a dataset class that generates the encoder and labels on the fly to minimize memory usage
class Dataset(torch.utils.data.Dataset):
def __init__(self, input, labels=None):
self.input = input
self.labels = labels
def __getitem__(self, pos):
encoded = tokenizer(self.input[pos], truncation=True, max_length=15, padding='max_length')
label = self.labels[pos]
ret = {key: torch.tensor(val) for key, val in encoded.items()}
ret['labels'] = torch.tensor(label)
return ret
def __len__(self):
return len(self.labels)
# training and validation datasets are defined here
train_dataset = Dataset(train['text_no_emoji'].tolist(), train['emoji_codes'].tolist())
val_dataset = Dataset(train['text_no_emoji'].tolist(), test['emoji_codes'].tolist())
# defining the training arguments
args = TrainingArguments(
output_dir="output",
evaluation_strategy="epoch",
logging_steps = 10,
per_device_train_batch_size=1024,
per_device_eval_batch_size=1024,
num_train_epochs=5,
save_steps=3000,
seed=0,
load_best_model_at_end=True,
weight_decay=0.2,
)
# defining the model trainer
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=val_dataset
)
# Training the model
trainer.train()
Результаты: После этого тренировка обычно довольно быстро прекращается из-за ранней остановки
Набор данных можно найти здесь (сжатый размер 39 Мб).
Комментарии:
1. Какие курсы обучения вы использовали? Откуда вы знаете, что он переоборудован?
2. Для скорости обучения я использовал значение по умолчанию, затем либо увеличил, либо уменьшил его на порядок. Что касается переоснащения, я отредактировал вопрос, чтобы показать вам результаты (я ждал, когда будет обновлена последняя версия).
3. Доступен ли этот набор данных публично?
4. Ага. Это из интернет-архива, с тонной постобработки, чтобы сделать все более плавным. Ознакомьтесь с публичным анализом социальных сетей на github, где мы работаем над тем, чтобы делать удивительные вещи с многолетними данными Twitter!
5. Я также могу загрузить образец данных где-нибудь, если вы хотите