#text #classification #bert-language-model
Question:
I am training a BERT model for text classification in tensorflow, but I am getting this error:
ValueError: logits and labels must have the same shape ((50, 2) vs (2, 1))
I should add that the shape of my labels is actually (678, 2).
input_ids = tf.keras.layers.Input(shape=(SEQ_LEN,), name='input_ids', dtype='int32')
mask = tf.keras.layers.Input(shape=(SEQ_LEN,), name='attention_mask', dtype='int32')
embeddings = bert(input_ids, attention_mask=mask)[0]
X = tf.keras.layers.GlobalMaxPool1D()(embeddings)
X = tf.keras.layers.BatchNormalization()(X)  # for normalization
X = tf.keras.layers.Dense(128, activation='relu')(X)
X = tf.keras.layers.Dropout(0.1)(X)
X = tf.keras.layers.Dense(32, activation='relu')(X)
y = tf.keras.layers.Dense(2, activation='softmax', name='output')(X)

model = tf.keras.Model(inputs=[input_ids, mask], outputs=y)
# model.layers[2].trainable = False  # to freeze bert because it takes too long

model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_ids (InputLayer)          [(None, 50)]         0
__________________________________________________________________________________________________
attention_mask (InputLayer)     [(None, 50)]         0
__________________________________________________________________________________________________
tf_bert_model (TFBertModel)     TFBaseModelOutputWit 109482240   input_ids[0][0]
                                                                 attention_mask[0][0]
__________________________________________________________________________________________________
global_max_pooling1d (GlobalMax (None, 768)          0           tf_bert_model[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 768)          3072        global_max_pooling1d[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 128)          98432       batch_normalization[0][0]
__________________________________________________________________________________________________
dropout_37 (Dropout)            (None, 128)          0           dense[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 32)           4128        dropout_37[0][0]
__________________________________________________________________________________________________
output (Dense)                  (None, 2)            66          dense_1[0][0]
==================================================================================================
Total params: 109,587,938
Trainable params: 109,586,402
Non-trainable params: 1,536

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model.fit(train, validation_data=val, epochs=3, batch_size=4)
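Note: the construction of train and val is not shown above, and the (2, 1) label shape in the error suggests the labels are being fed to the loss transposed or unbatched. For comparison, here is a minimal sketch of a tf.data pipeline whose label batches have shape (batch_size, 2), matching the model's 2-unit softmax output; the tensors Xids, Xmask, and labels below are placeholders, not the question's actual data.

import tensorflow as tf

SEQ_LEN = 50

# Placeholder tensors standing in for the real tokenized inputs:
# Xids/Xmask have shape (678, SEQ_LEN), labels has shape (678, 2).
Xids = tf.zeros((678, SEQ_LEN), dtype=tf.int32)
Xmask = tf.ones((678, SEQ_LEN), dtype=tf.int32)
labels = tf.one_hot(tf.zeros(678, dtype=tf.int32), depth=2)  # one-hot, shape (678, 2)

# Keying the inputs by the Input layer names ('input_ids', 'attention_mask')
# lets Keras route each tensor to the matching input of the model.
dataset = tf.data.Dataset.from_tensor_slices((
    {'input_ids': Xids, 'attention_mask': Xmask},
    labels,
))
train = dataset.batch(4)  # each label batch now has shape (4, 2)

Two further points: with one-hot labels of shape (batch, 2) and a softmax output, loss='categorical_crossentropy' is the conventional choice rather than 'binary_crossentropy'; and if train is a tf.data.Dataset, batching must happen in the dataset itself, since Keras does not apply the batch_size argument of model.fit to dataset inputs.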
Comments:
1. Could you please post what bert is and the full error log?