Как передать возможные имена классов в distilbert

#python #machine-learning #pytorch #huggingface-transformers #transformer

Вопрос:

Я пытался заставить дистилберта работать, и я загрузил модель и использовал AutoTokenizer.from_pretrained() и AutoModelForSequenceClassification.from_pretrained(). Я уже пару дней пытался передать параметры из раздела «Возможные имена классов» на странице карточки модели huggingface: https://huggingface.co/typeform/distilbert-base-uncased-mnli?candidateLabels=positive, negative, neutralamp;multiClass=trueamp;text=which stocks will go down during new years

Я пытался:

 from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained('.')
model = AutoModelForSequenceClassification.from_pretrained('.')

text = "Dummy text"
text  = "[SEP]Positive[SEP]Neutral[SEP]Negative"
encodedInput = tokenizer(text, return_tensors="pt")
output = model(**encodedInput)
print(output)
 

Предполагается, что он выводит значения для «Положительных», «Нейтральных» и «Отрицательных».
Кто-нибудь знает, как это сделать? Я использую пайторч.

Ответ №1:

Вам не нужно добавлять классы во входной текст, но вы можете определить их в отдельном списке.

Он AutoModelForSequenceClassification сгенерирует логиты, которые при прохождении softmax дадут вам метку класса.

[Еще одно предложение: посмотрите, как я определил токенизатор и модель. Таким образом, они могут быть загружены по пути при запуске кода.]

Пожалуйста, ознакомьтесь с кодом ниже:

 from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained('typeform/distilbert-base-uncased-mnli')
model     = AutoModelForSequenceClassification.from_pretrained('typeform/distilbert-base-uncased-mnli')
classes   = ['positive', 'negative', 'neutral']
text      = "Dummy text"

encodedInput = tokenizer(text, return_tensors="pt")
output       = model(**encodedInput)
output       = torch.softmax(output[0], dim=1).tolist()[0]

max_idx = output.index(max(output))
print(classes[max_idx])