tf.keras.Точность обучения модели скачет после запуска Model.evaluate() во время обучения

#tensorflow #keras #tf.keras

#tensorflow #keras #tf.keras

Вопрос:

Я Model использую Model.fit() TensorFlow Keras. Я также использую обратные вызовы для регистрации моих показателей точности обучения после каждого пакета, используя on_train_batch_end() синтаксис TensorFlow. Кроме того, я использую другой обратный вызов для запуска Model.evaluate() каждые 1000 пакетов, чтобы вычислить точность набора проверки и обновить logs dict, переданный во время обратных вызовов Model.fit() .

Просмотр зарегистрированных показателей в зависимости от номера партии показывает очень странные результаты. После Model.evaluate() запуска точность обучения испытывает значительный «толчок», первоначально вызывающий быстрое увеличение точности зарегистрированного обучения и впоследствии вызывающий значительное снижение точности обучения с последующим более медленным восстановлением (см. Прикрепленные изображения).

Я предполагаю, что это как-то связано с вызовом Model.evaluate() reset_metrics() , который выполняет цикл и вызывает метод reset_states() для каждой метрики. Я не могу понять, что reset_states() делает, и имеет ли это отношение к поведению, которое я наблюдаю. Похоже, это относится к среднему родительскому классу CategoricalAccuracy . Я пока не смог найти ничего полезного в документах TensorFlow.

Действительно ли показатели, показанные во время Model.fit() , представляют собой какую-то форму скользящих средних, а не пакетную метрику? В этом случае reset_states() методом будет сброс скользящего среднего, что, возможно, приведет к тряске.

Может ли кто-нибудь, кто лучше разбирается во внутренней работе TensorFlow, помочь?

Точность встряхивания.

Точность встряхивания (увеличена).

Комментарии:

1. Вы уже проверили, может ли ваша промежуточная оценка привести к нежелательному накоплению потерь или градиентов, которые затем (возможно) могут быть применены во время первого обновления веса после оценки?

2. Спасибо @DanielB. Я, конечно, обеспокоен тем, что вызов Model.evaluate() автоматически обучается моим данным проверки. Это моя другая гипотеза. Однако, если true, я бы ожидал более высокой точности по сравнению с набором проверки, чем я наблюдаю. Кроме того, Model.evaluate() предполагается перевести сеть в «тестовый режим», что, надеюсь , означает, что градиенты не вычисляются. Приветствуются мысли.