#deep-learning #pytorch #bert-language-model #word-embedding
Вопрос:
У меня есть модель PyTorch со следующей архитектурой: BERT -> dropout -> classifier -> loss_function
. Я точно настроил модель в своем наборе данных и использовал прогнозы. Теперь я хочу получить доступ к выходным данным модели BERT перед слоем классификатора. Я хочу использовать их в качестве вложения моего предложения.
Мой вопрос в следующем: 1) как мне сохранить выходные данные каждого из 12 скрытых слоев для моего ввода? 2) Какое из полей, output.dense.weight
, output.dense.bias
, output.LayerNorm.weight
output.LayerNorm.bias
я должен использовать?
Вот что я уже пробовал
Я создал модель БЕРТА из предварительно bert_base
подготовленной и использованной model.load_state_dict(ckpt)
. Теперь я могу делать прогнозы с помощью этой модели:
tokenized_text = np.array(tokenizer.encode(text))
padded_text = np.zeroes(50)[:tokenized_text.shape[0]] = tokenized_text
input_ids = torch.tensor(padded_text).reshape(1,50)
input_ids = input_ids.type(torch.long)
model(input_ids)
Приведенный выше код даст мне [1,50,768]
и [1, 768]
тензор.
Дополнительная информация
Когда я загружаю контрольную точку с torch.load
помощью и у меня есть доступ к следующим векторам bert.embeddings.*
и bert.encoder.layer.*
. Для каждого слоя у меня есть
bert.encoder.layer.x.attention.self.*
(6 векторов) bert.encoder.layer.x.attention.output.*
(4 вектора) bert.encoder.layer.x.intermediate.dense.weight
bert.encoder.layer.x.intermediate.dense.bias
bert.encoder.layer.x.output.dense.weight
bert.encoder.layer.x.output.dense.bias
bert.encoder.layer.x.output.LayerNorm.weight
bert.encoder.layer.x.output.LayerNorm.bias
Комментарии:
1. Ответ может быть похож на этот вопрос: datascience.stackexchange.com/questions/62658/… … но любая другая информация о слое, упомянутом выше, приветствуется.
2. Ничто из
output.dense.weight, output.dense.bias, output.LayerNorm.weight output.LayerNorm.bias
этого не является «выходом».3. Если вы хотите использовать «вывод модели БЕРТА перед слоем классификатора», вы должны сделать это в
forward
функции. На самом деле вам не нужно (и в большинстве случаев не следует) менять способ загрузки веса.4. Но вам придется разобраться, как
forward
BERT
написана используемая вами реализация. Или укажите, какая это реализация.