Модель не может предсказать постоянные значения

#keras #neural-network #conv-neural-network

#keras #нейронная сеть #conv-нейронная сеть

Вопрос:

Я хочу обучить модель, которая принимает изображение в качестве входных данных и предсказывает 8 чисел с плавающей точкой. Вот архитектура модели:

 from tensorflow.python.keras.applications.efficientnet import EfficientNetB4 from tensorflow.keras import models, layers   def prepare_model_eff(input_shape):  conv_base = EfficientNetB4(include_top=False, input_shape=input_shape)  conv_base.trainable = True # That's done deliberately!  model = models.Sequential()  model.add(conv_base)  model.add(layers.GlobalMaxPooling2D())  model.add(layers.Dropout(rate=0.2, ))  model.add(layers.Dense(8, bias_initializer=Constant(0.0)))  return model  

Функция потерь: MSE

Показатель: RMSE

Чтобы протестировать архитектуру, я пытаюсь предсказать постоянные значения: 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0.

В наборе поездов ~1500 изображений и ~300 изображений в наборе проверки, так что есть 1800 изображений с одинаковыми постоянными выходами.

Я ожидаю, что модель поймет идею прогнозирования тех же значений и даст RMSE 0,0000000, но удивительно, что после 80 эпох RMSE моделей составляет всего 0,3.

Поскольку тестовое задание очень простое, я полагаю, что с архитектурой может быть что-то не так. Может быть, я стреляю себе в ногу? Вот краткое описание модели для вашей справки:

 Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param #  ================================================================= efficientnetb4 (Functional) (None, 12, 12, 1792) 17673823  _________________________________________________________________ global_max_pooling2d (Global (None, 1792) 0  _________________________________________________________________ dropout (Dropout) (None, 1792) 0  _________________________________________________________________ dense (Dense) (None, 8) 14344  ================================================================= Total params: 17,688,167 Trainable params: 17,562,960 Non-trainable params: 125,207