Модель свертки и классификации в основной модели

#python #machine-learning #keras #neural-network

#python #машинное обучение #keras #нейронная сеть

Вопрос:

Я должен создать модель нейронной сети, подобную этой:

 convolution --> classification
                   /
                  /
        _|      |/_
         third model
       with one output
  

Свертка выводит данные, которые используются в качестве входных данных для модели классификации. После этого выходные данные свертки и классификации заполняются (объединяются) в третью модель. Третья модель выведет прогноз 0..1, который используется для обучения всей сети.

  • Прежде всего: возможно ли в этой ситуации корректное обратное распространение модели классификации? Или для этого требуется создать три отдельные модели?
  • Я пытался объединить свертку и классификацию, но без хороших результатов. Я получил ошибку «График отключен».

Полный журнал ошибок: «График отключен: не удается получить значение для тензорного тензора («classification_prediction_Input_2:0», shape=(1, 512), dtype=float32) на уровне «classification_prediction_Input». Доступ к следующим предыдущим слоям был выполнен без проблем: []».

Если идея верна, как подключить модели, подобные «графическим»?

Мой код на данный момент:

 # state convolution
state_input = Input(shape=INPUT_SHAPE, name='state_input', batch_shape=(1, 210, 160, 3))
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)
state_convolution_model = Model(state_input, state_outputs, name='state_convolution_model')
state_convolution_model.compile(optimizer='adam', loss='mean_squared_error', metrics=['acc'])

state_convolution_model_input = Input(shape=INPUT_SHAPE, name='state_convolution_model_input', batch_shape=(1, 210, 160, 3))
state_convolution = state_convolution_model(state_convolution_model_input)

# classification output
classficication_Input = Input(shape=(1, LSTM_OUTPUT_DIM), batch_shape=(1, LSTM_OUTPUT_DIM), name='classification_prediction_Input')
classficication_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(classficication_Input)
classficication_output_raw = Dense(ACTIONS, activation='sigmoid', name='classification_output_raw')(classficication_Dense_1)
classficication_output = Reshape((ACTIONS,), name='classification_output')(classficication_output_raw)
classficication_model = Model(classficication_Input, classficication_output, name='classificationPrediction_model')
classficication_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

classficicationPrediction = classficication_model(state_convolution)

i = keras.layers.concatenate([state_outputs, classficication_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)                  # <-- graph error is here
plot_model(model, to_file='model.png', show_shapes=True)
  

Комментарии:

1. Что такое LSTM_OUTPUT_DIM, ДЕЙСТВИЯ? Не могли бы вы опубликовать минимальный рабочий пример, чтобы я мог запустить его и воспроизвести ошибку?

2. Целые числа. Может быть 64 (LSTM_OUTPUT_DIM) и 4 (ДЕЙСТВИЯ)

Ответ №1:

Да, вы можете построить подобную структуру и обучать ее сквозным способом. Однако вам необходимо создать единую модель, которая имеет несколько ветвей. Другая проблема, которую я вижу, заключается в том, что вы компилируете модель до того, как она будет полностью определена. Вот рабочий код:

 # state convolution                                                                                                                                                                                                                                                   
state_input = Input(shape=INPUT_SHAPE, name='state_input')
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)

# classification output                                                                                                                                                                                                                                               
classification_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(state_outputs)
classification_output_raw = Dense(ACTIONS,                                                                                                                                                                                                                            
                                  activation='sigmoid',                                                                                                                                                                                                               
                                  name='classification_output_raw')(classification_Dense_1)
classification_output = Reshape((ACTIONS,), name='classification_output')(classification_output_raw)


i = concatenate([state_outputs, classification_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)                  # <-- no graph error anymore here                                                                                                                                                                                      
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
model.summary()
  

Вывод:

 Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
state_input (InputLayer)        (None, 210, 160, 3)  0                                            
__________________________________________________________________________________________________
state_Conv2D_1 (Conv2D)         (None, 51, 39, 8)    1544        state_input[0][0]                
__________________________________________________________________________________________________
state_MaxPooling2D_1 (MaxPoolin (None, 25, 19, 8)    0           state_Conv2D_1[0][0]             
__________________________________________________________________________________________________
state_Flatten (Flatten)         (None, 3800)         0           state_MaxPooling2D_1[0][0]       
__________________________________________________________________________________________________
classification_prediction_Dense (None, 32)           121632      state_Flatten[0][0]              
__________________________________________________________________________________________________
classification_output_raw (Dens (None, 4)            132         classification_prediction_Dense_1
__________________________________________________________________________________________________
classification_output (Reshape) (None, 4)            0           classification_output_raw[0][0]  
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 3804)         0           state_Flatten[0][0]              
                                                                 classification_output[0][0]      
__________________________________________________________________________________________________
dense (Dense)                   (None, 32)           121760      concatenate[0][0]                
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            33          dense[0][0]                      
==================================================================================================
  

Дополнительные примеры см. в Руководстве по функциональному API.