я строю свою модель обучения передаче keras, как показано ниже, Одна вещь, которую я не могу сделать, это установить training=False для слоев BatchNorm в xception

#tensorflow #keras #transfer-learning #batch-normalization

Вопрос:

Как вы можете видеть, я не хочу менять способ, которым я построил эту модель, есть способ изменить, но это преобразует модель xception в некоторую функциональную модель, и в сводке модели она просто показывает Xception вместо всех ее слоев, а также я не могу применить градуированную камеру. Так что, пожалуйста, помогите кому-нибудь.

код

 def build_model():
    # use imagenet - pre-trainined weights for images
    baseModel =Xception(weights= 'imagenet', include_top = False, input_shape=(224, 224, 3))
    for layer in baseModel.layers:
      layer.trainable = False
      bn_layer.trainable=False
   
    headModel =baseModel.output 
    headModel = Flatten()(headModel)
    headModel = Dense(64,activation="LeakyReLU")(headModel)
    headModel = Dropout(0.5)(headModel)
    headModel = Dense(32,activation="LeakyReLU")(headModel)
    headModel = Dropout(0.4)(headModel)
    headModel = Dense(16, activation="LeakyReLU")(headModel)
    headModel = Dropout(0.3)(headModel)
    headModel = Dense(8, activation="LeakyReLU")(headModel)
    headModel = Dropout(0.2)(headModel)
    headModel = Dense(3, activation="softmax")(headModel)
    
    x = Model(baseModel.inputs,outputs=headModel)

    optimizers = Adam(learning_rate=  0.001)
    x.compile(loss = 'categorical_crossentropy', optimizer = optimizers, metrics = ['accuracy'])
    return x

x= build_model()
sum=x.summary()
 

Ответ №1:

Насколько я понял, вы хотите заморозить весь слой в базовой модели, верно? Вы можете сделать следующее:

     for layer in baseModel.layers:
            layer.trainable = False
 

В этом случае ваш код верен, за исключением того, что вам не нужна вторая строка.

Вместо этого, если вы хотите обучить все слои, кроме слоев BatchNormalization, вы можете сделать следующее:

     for layer in model.layers:
            if not isinstance(layer, layers.BatchNormalization):
                    layer.trainable = True
 

Для получения дополнительной информации прочитайте: https://keras.io/guides/transfer_learning/ Особенно в разделе Пример: уровень нормализации пакетов.