Bert для классификации текста — логиты и метки должны иметь одинаковую форму ((50, 2) против (2, 1))

#text #classification #bert-language-model

Вопрос:

Я тренирую модель БЕРТА для классификации текста в tensorflow, но я получаю эту ошибку:

Ошибка значения: логины и метки должны иметь одинаковую форму ((50, 2) против (2, 1))

Я должен добавить, что форма моих меток на самом деле (678, 2)

 input_ids = tf.keras.layers.Input(shape=(SEQ_LEN,), name = 'input_ids', dtype='int32') mask = tf.keras.layers.Input(shape=(SEQ_LEN,), name = 'attention_mask', dtype='int32')  embeddings = bert(input_ids, attention_mask=mask)[0]  X=tf.keras.layers.GlobalMaxPool1D()(embeddings)  X = tf.keras.layers.BatchNormalization()(X) #for normalization X = tf.keras.layers.Dense(128, activation = 'relu')(X) X = tf.keras.layers.Dropout(0.1)(X) X = tf.keras.layers.Dense(32, activation='relu')(X) y = tf.keras.layers.Dense(2, activation='softmax', name='output')(X) model = tf.keras.Model(inputs=[input_ids, mask], outputs=y)  #model.layers[2].trainable = False #to freeze bert because it takes to long model.summary() Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to  ================================================================================================== input_ids (InputLayer) [(None, 50)] 0  __________________________________________________________________________________________________ attention_mask (InputLayer) [(None, 50)] 0  __________________________________________________________________________________________________ tf_bert_model (TFBertModel) TFBaseModelOutputWit 109482240 input_ids[0][0]   attention_mask[0][0]  __________________________________________________________________________________________________ global_max_pooling1d (GlobalMax (None, 768) 0 tf_bert_model[0][0]  __________________________________________________________________________________________________ batch_normalization (BatchNorma (None, 768) 3072 global_max_pooling1d[0][0]  __________________________________________________________________________________________________ dense (Dense) (None, 128) 98432 batch_normalization[0][0]  __________________________________________________________________________________________________ dropout_37 (Dropout) (None, 128) 0 dense[0][0]  __________________________________________________________________________________________________ dense_1 (Dense) (None, 32) 4128 dropout_37[0][0]  __________________________________________________________________________________________________ output (Dense) (None, 2) 66 dense_1[0][0]  ================================================================================================== Total params: 109,587,938 Trainable params: 109,586,402 Non-trainable params: 1,536  model.compile(loss='binary_crossentropy',  optimizer='adam',  metrics=['accuracy'])  history = model.fit(train, validation_data=val, epochs=3, batch_size=4)  

Комментарии:

1. не могли бы вы, пожалуйста, записать, что bert такое и журнал ошибок?