Загрузка слоя пула простого трансформатора

#python #nlp #pytorch #bert-language-model #simpletransformers

#python #nlp #pytorch #bert-language-model #simpletransformers

Вопрос:

У меня есть точно настроенная модель представления простого трансформатора. Теперь я хочу сохранить веса только слоя пула в формате pickle и поместить его в слой пула другого пользовательского автоматического кодировщика, который я разрабатываю.Как я могу это сделать, используя pytorch и python?

Ответ №1:

Рядом с каждым модулем PyTorch вызывается объект state_dict , который позволяет сопоставлять любой параметр с соответствующей тензорной переменной (подробнее об этом здесь ). С помощью этой утилиты вы можете легко сохранять и загружать параметры, но имейте в виду, что вы должны быть уверены в том, что вы хотите сделать заранее, как семантически (с точки зрения машинного обучения), так и синтаксически (совместимость формы и …)! Приведенная ниже реализация заменит любой параметр словом pooling в его имени на соответствующую переменную из модели, которую мы сохранили ранее.

 finetuned_model = BertLMHeadModel.from_pretrained('bert-base-cased')
torch.save(finetuned_model.state_dict(), "finetuned_model.pth")
finetuned_model_state_dict = torch.load("finetuned_model.pth")
new_model = BertLMHeadModel.from_pretrained('bert-base-cased')
new_model_state_dict = new_model.state_dict()
for key, value in new_model_state_dict.items():
  if key.find('pooling')!=-1:
    new_model_state_dict.update({key: value})
 

Комментарии:

1. @Parmida Granfar помог ли ответ?