Пользовательская функция потерь: Как добавить выходные данные скрытого слоя в функцию потерь в keras с помощью Tensorflow

#tensorflow #keras #layer #loss-function #loss

#tensorflow #keras #слой #функция потерь #потеря

Вопрос:

В моей модели выходные данные скрытого слоя, а именно «закодированный», имеют два канала (например, форма: [нет, 128, 128, 2]). Я надеюсь добавить SSIM между этими двумя каналами в функцию потерь:

потеря = ssim (ввод, вывод) тета * ssim (закодированный (канал1), закодированный (канал2)).

Как я мог бы это реализовать? Ниже приведена архитектура моей модели.

 def structural_similarity_index(y_true, y_pred):
    loss = 1 - tf.image.ssim(y_true, y_pred, max_val=1.0) 
    return loss

def mymodel():
    input_img = Input(shape=(256, 256, 1))

    # encoder
    x = Conv2D(4, (3, 3), activation='relu', padding='same')(input_img)
    x = MaxPooling2D((2, 2), padding='same')(x)
    encoded = Conv2D(2, (3, 3), activation='relu', padding='same', name='encoder')(x)

    # decoder    
    x = Conv2D(4, (3, 3), activation='relu', padding='same')(encoded)
    x = UpSampling2D((2, 2))(x)
    decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

    autoencoder = Model(input_img, decoded)    
    autoencoder.compile(optimizer = 'adadelta', loss = structural_similarity_index)
    autoencoder.summary()        
    return autoencoder
  

Я попытался определить функцию ‘loss_warper’, как показано ниже, но это не сработало. Вот как я добавил эту функцию потерь:

 autoencoder.add_loss(loss_wrapper(encoded[:,:,:,0],encoded[:,:,:,1])(input_img, decoded))
  

функция ‘loss_warper’:

 def loss_wrapper(CH1, CH2):
    def structural_similarity_index(y_true, y_pred):
        regweight = 0.01
        loss = 1 - tf.image.ssim(y_true, y_pred, max_val=1.0)
        loss = loss   regweight*(1-tf.image.ssim(CH1, CH2, max_val=1.0))
        return loss
    return structural_similarity_index
  

Сообщение об ошибке:

 File "E:/Autoencoder.py", line 160, in trainprocess
    validation_data= (x_validate, x_validate))
...
ValueError: ('Error when checking model target: expected no data, but got:', array([...]...[...]))
  

Кто-нибудь знает, как это реализовать? Любая помощь высоко ценится!