#tensorflow #keras #tf.keras
#tensorflow #keras #tf.keras
Вопрос:
Я Model
использую Model.fit()
TensorFlow Keras. Я также использую обратные вызовы для регистрации моих показателей точности обучения после каждого пакета, используя on_train_batch_end()
синтаксис TensorFlow. Кроме того, я использую другой обратный вызов для запуска Model.evaluate()
каждые 1000 пакетов, чтобы вычислить точность набора проверки и обновить logs
dict, переданный во время обратных вызовов Model.fit()
.
Просмотр зарегистрированных показателей в зависимости от номера партии показывает очень странные результаты. После Model.evaluate()
запуска точность обучения испытывает значительный «толчок», первоначально вызывающий быстрое увеличение точности зарегистрированного обучения и впоследствии вызывающий значительное снижение точности обучения с последующим более медленным восстановлением (см. Прикрепленные изображения).
Я предполагаю, что это как-то связано с вызовом Model.evaluate() reset_metrics()
, который выполняет цикл и вызывает метод reset_states() для каждой метрики. Я не могу понять, что reset_states()
делает, и имеет ли это отношение к поведению, которое я наблюдаю. Похоже, это относится к среднему родительскому классу CategoricalAccuracy
. Я пока не смог найти ничего полезного в документах TensorFlow.
Действительно ли показатели, показанные во время Model.fit()
, представляют собой какую-то форму скользящих средних, а не пакетную метрику? В этом случае reset_states()
методом будет сброс скользящего среднего, что, возможно, приведет к тряске.
Может ли кто-нибудь, кто лучше разбирается во внутренней работе TensorFlow, помочь?
Комментарии:
1. Вы уже проверили, может ли ваша промежуточная оценка привести к нежелательному накоплению потерь или градиентов, которые затем (возможно) могут быть применены во время первого обновления веса после оценки?
2. Спасибо @DanielB. Я, конечно, обеспокоен тем, что вызов
Model.evaluate()
автоматически обучается моим данным проверки. Это моя другая гипотеза. Однако, если true, я бы ожидал более высокой точности по сравнению с набором проверки, чем я наблюдаю. Кроме того,Model.evaluate()
предполагается перевести сеть в «тестовый режим», что, надеюсь , означает, что градиенты не вычисляются. Приветствуются мысли.