Обрезка модели тензорного потока дает » nan » для потерь при обучении и проверке

#python #tensorflow #keras #deep-learning #model

Вопрос:

Я пытаюсь обрезать базовую модель, состоящую из нескольких слоев поверх сети VGG. Он также содержит определенный пользователем слой с именем instance_normalization . Для успешной обрезки я определил get_prunable_weights функцию этого слоя следующим образом:

 ### defined for model pruning
    def get_prunable_weights(self):
        return self.weights
 

Я использовал следующую функцию для получения структуры модели, подлежащей сокращению, с использованием базовой модели с именем model :

 def define_prune_model(self, model, img_shape, epochs, batch_size, validation_split=0.1):
        num_images = img_shape[0] * (1 - validation_split)
        end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

        # Define model for pruning.
        pruning_params = {
            'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.5,
                                                                    final_sparsity=0.80,
                                                                    begin_step=0,
                                                                    end_step=end_step)
        }

        model_for_pruning = prune_low_magnitude(model, **pruning_params)

        model_for_pruning.compile(optimizer='adam',
                    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])

        model_for_pruning.summary()

        return model_for_pruning
 

Затем я написал следующую функцию для выполнения обучения на этой модели обрезки:

 def train_prune_model(self, model_for_pruning, train_images, train_labels,
                     epochs, batch_size, validation_split=0.1):
    callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='./models/pruned'),
    ]
    model_for_pruning.fit(train_images, train_labels,
                batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                callbacks=callbacks)
    return model_for_pruning
 

Однако при обучении я обнаружил, что потери при обучении и проверке были все nan , а итоговый результат прогнозирования модели был полностью равен нулю. Однако базовая модель, к которой перешел define_prune_model , успешно обучена и правильно предсказана.

Как я могу это решить? Заранее спасибо.

Ответ №1:

Трудно точно определить проблему без дополнительной информации. В частности, не могли бы вы предоставить более подробную информацию (желательно в виде кода) о вашем пользовательском instance_normalization слое ?

Предполагая, что код в порядке: поскольку вы упомянули, что модель правильно обучается без обрезки, может ли быть так, что эти параметры обрезки слишком жесткие ? В конце концов, эти параметры устанавливают 50% весовые коэффициенты равными нулю прямо с первого шага обучения.

Вот что я бы попробовал:

  • Поэкспериментируйте с более низким уровнем разреженности (особенно initial_sparsity ).
  • Начните применять обрезку позже во время тренировки ( begin_step аргумент графика обрезки). Некоторые даже предпочитают обучать модель один раз, вообще не применяя обрезку. Затем снова тренируйтесь с prune_low_magnitude() помощью .
  • Обрезайте только на некоторых этапах, давая модели время для восстановления между обрезками ( frequency аргумент).
  • Наконец, если он все равно не сработает, обычные методы лечения при возникновении потерь nan: уменьшите скорость обучения, используйте регуляризацию или отсечение градиента …