Отредактируйте архитектуру предварительно обученной модели BERT

#tensorflow #keras #nlp #bert-language-model #transfer-learning

Вопрос:

Я нашел модель BERT из репозитория Google на GitHub. Скачал его, получил конфигурационный файл json и загрузил модель.

 import json

bert_config_file = os.path.join(gs_folder_bert, "/content/drive/My Drive/Colab Notebooks/bert_config.json")

config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())

bert_config = bert.configs.BertConfig.from_dict(config_dict)

bert_classifier, bert_encoder = bert.bert_models.classifier_model(bert_config, num_labels=2)
print(bert_encoder.summary())

Model: "bert_encoder_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_word_ids (InputLayer)     [(None, None)]       0                                            
__________________________________________________________________________________________________
word_embeddings (OnDeviceEmbedd (None, None, 128)    3906816     input_word_ids[0][0]             
__________________________________________________________________________________________________
input_type_ids (InputLayer)     [(None, None)]       0                                            
__________________________________________________________________________________________________
position_embedding (PositionEmb (None, None, 128)    65536       word_embeddings[0][0]            
__________________________________________________________________________________________________
type_embeddings (OnDeviceEmbedd (None, None, 128)    256         input_type_ids[0][0]             
__________________________________________________________________________________________________
add (Add)                       (None, None, 128)    0           word_embeddings[0][0]            
                                                                 position_embedding[0][0]         
                                                                 type_embeddings[0][0]            
__________________________________________________________________________________________________
embeddings/layer_norm (LayerNor (None, None, 128)    256         add[0][0]                        
__________________________________________________________________________________________________
dropout (Dropout)               (None, None, 128)    0           embeddings/layer_norm[0][0]      
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, None)]       0                                            
__________________________________________________________________________________________________
self_attention_mask (SelfAttent (None, None, None)   0           dropout[0][0]                    
                                                                 input_mask[0][0]                 
__________________________________________________________________________________________________
transformer/layer_0 (Transforme (None, None, 128)    198272      dropout[0][0]                    
                                                                 self_attention_mask[0][0]        
__________________________________________________________________________________________________
transformer/layer_1 (Transforme (None, None, 128)    198272      transformer/layer_0[0][0]        
                                                                 self_attention_mask[0][0]        
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici (None, 128)          0           transformer/layer_1[0][0]        
__________________________________________________________________________________________________
pooler_transform (Dense)        (None, 128)          16512       tf.__operators__.getitem[0][0]   
==================================================================================================
Total params: 4,385,920
Trainable params: 4,385,920
Non-trainable params: 0
 

Мне нужна модель, в которую я мог бы вводить свои собственные векторы встраивания непосредственно на слой 0 трансформатора, который является 11-м слоем в модели. Поэтому мне не нужны первые 10 слоев предварительной обработки. И, наконец, я хочу изменить активацию вывода, чтобы я мог выполнить регрессию.

Как я могу отредактировать архитектуру модели или создать новую модель, содержащую эти слои трансформатора? Заранее спасибо.