ключевой набор данных, потерянный во время обучения с использованием тренажера Hugging Face

#python #huggingface-transformers

#python #huggingface-трансформеры

Вопрос:

Я следую материалу курса Hugging Face: https://huggingface.co/course/chapter7/3?fw=pt (кстати, отличная штука!). Однако теперь я столкнулся с проблемой.

Когда я запускаю обучение и оцениваю, используя data_collator по умолчанию, все идет нормально. Но когда я использую пользовательский whole_word_masking_data_collator, он не работает, потому что он пропускает ключ «word_ids».

Мои данные следующие:

 DatasetDict({
train: Dataset({
    features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids', 'word_ids'],
    num_rows: 30639
})
test: Dataset({
    features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids', 'word_ids'],
    num_rows: 29946
})
unsupervised: Dataset({
    features: ['attention_mask', 'input_ids', 'labels', 'token_type_ids', 'word_ids'],
    num_rows: 61465
})
})
 

Когда я использую свой whole_word_masking_data_collator следующим образом, все в порядке:

 whole_word_masking_data_collator([lm_datasets["train"][0]])
 

Однако, когда я использую его в таком тренажере:

 from transformers import Trainer

trainer = Trainer(
    model=masked_model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["test"],
    data_collator=whole_word_masking_data_collator,
)
 

Это дает мне следующую ошибку:

 KeyError: 'word_ids'
 

Что я нахожу странным, потому что этот ключ явно нажимается в данных, а функция whole_word_masking_data_collator отлично работает автономно.

Когда я проверил ключи в своей функции, я обнаружил, что ключ действительно отсутствует. Я получил только эти ключи:

 dict_keys(['attention_mask', 'input_ids', 'labels', 'token_type_ids'])
 

Итак, мой вопрос: был ли в моем коде пропущен ключ «word_ids»?

Комментарии:

1. Я уже нашел, где что-то пошло не так. Но я пока не знаю, как это исправить. Кажется, что тренер игнорирует это. Смотрите сообщение: следующие столбцы в наборе оценок не имеют соответствующего аргумента BertForMaskedLM.forward и были проигнорированы: word_ids .