#keras #neural-network #conv-neural-network
#keras #нейронная сеть #conv-нейронная сеть
Вопрос:
Я хочу обучить модель, которая принимает изображение в качестве входных данных и предсказывает 8 чисел с плавающей точкой. Вот архитектура модели:
from tensorflow.python.keras.applications.efficientnet import EfficientNetB4 from tensorflow.keras import models, layers def prepare_model_eff(input_shape): conv_base = EfficientNetB4(include_top=False, input_shape=input_shape) conv_base.trainable = True # That's done deliberately! model = models.Sequential() model.add(conv_base) model.add(layers.GlobalMaxPooling2D()) model.add(layers.Dropout(rate=0.2, )) model.add(layers.Dense(8, bias_initializer=Constant(0.0))) return model
Функция потерь: MSE
Показатель: RMSE
Чтобы протестировать архитектуру, я пытаюсь предсказать постоянные значения: 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0.
В наборе поездов ~1500 изображений и ~300 изображений в наборе проверки, так что есть 1800 изображений с одинаковыми постоянными выходами.
Я ожидаю, что модель поймет идею прогнозирования тех же значений и даст RMSE 0,0000000, но удивительно, что после 80 эпох RMSE моделей составляет всего 0,3.
Поскольку тестовое задание очень простое, я полагаю, что с архитектурой может быть что-то не так. Может быть, я стреляю себе в ногу? Вот краткое описание модели для вашей справки:
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= efficientnetb4 (Functional) (None, 12, 12, 1792) 17673823 _________________________________________________________________ global_max_pooling2d (Global (None, 1792) 0 _________________________________________________________________ dropout (Dropout) (None, 1792) 0 _________________________________________________________________ dense (Dense) (None, 8) 14344 ================================================================= Total params: 17,688,167 Trainable params: 17,562,960 Non-trainable params: 125,207