#python #keras #deep-learning #tensorflow2.0 #tensorboard
#python #keras #глубокое обучение #tensorflow2.0 #tensorboard
Вопрос:
Когда я обучаю свою модель с помощью TensorFlow 2.3, я хочу визуализировать некоторые промежуточные тензоры, вычисленные с использованием веса в графе вычислений моего настроенного tf.keras.layers.Layer
.
Поэтому я использую tf.summary.image()
для записи этих тензоров и визуализации их в виде таких изображений:
class CustomizedLayer(tf.keras.layers.Layer):
def call(self, inputs, training=None):
# ... some code ...
tf.summary.image(name="some_weight_map", data=some_weight_map)
# ... some code ...
Но в TensorBoard, независимо от того, сколько пройдено шагов, отображается только одно изображение шага 0.
И я попытался установить параметр step of tf.summary.image()
на значение, полученное из tf.summary.experimental.get_step()
:
tf.summary.image(name="weight_map", data=weight_map, step=tf.summary.experimental.get_step())
И обновите шаг, вызвав tf.summary.experimental.set_step из настроенного обратного вызова, используя tf.Variable, подобные кодам, показанным ниже:
class SummaryCallback(tf.keras.callbacks.Callback):
def __init__(self, step_per_epoch):
super().__init__()
self.global_step = tf.Variable(initial_value=0, trainable=False, name="global_step")
self.global_epoch = 0
self.step_per_epoch = step_per_epoch
tf.summary.experimental.set_step(self.global_step)
def on_batch_end(self, batch, logs=None):
self.global_step = batch self.step_per_epoch * self.global_epoch
tf.summary.experimental.set_step(self.global_step)
# whether the line above is commented, calling tf.summary.experimental.get_step() in computation graph code always returns 0.
# tf.print(self.global_step)
def on_epoch_end(self, epoch, logs=None):
self.global_epoch = 1
Экземпляр этого обратного вызова передается в обратных вызовах аргументов в model.fit()
функции.
Но tf.summary.experimental.get_step()
возвращаемое значение по-прежнему равно 0.
В документе TensorFlow из « tf.summary.experimental.set_step()
» говорится:
при использовании этого с @tf.functions значение шага будет записано во время трассировки функции, поэтому изменения шага вне функции не будут отражены внутри функции, если не использовать шаг tf.Variable .
Согласно документу, я уже использую переменную для хранения шагов, но ее изменения по-прежнему не отражаются внутри функции (или keras.Модель).
Примечание: мой код выдает ожидаемые результаты в TensorFlow 1.x с помощью простой строки tf.summary.image()
, прежде чем я перенесу его в TensorFlow 2.
Итак, я хочу знать, неверен ли мой подход в TensorFlow 2?
В TF2, как я могу получить шаги обучения внутри графа вычислений?
Или есть другое решение для суммирования тензоров (как скалярных, изображений и т. Д.) Внутри модели в TensorFlow 2?
Ответ №1:
Я обнаружил, что об этой проблеме сообщалось в репозитории Tensorflow на Github: https://github.com/tensorflow/tensorflow/issues/43568
Это вызвано использованием tf.summary в модели во время обратных вызовов tf.keras.Обратный вызов TensorBoard также включен, и шаг всегда будет равен нулю. Отчет о проблемах дает временное решение.
Чтобы исправить это, унаследуйте tf.keras.callbacks .Класс TensorBoard и перезапишите метод on_train_begin и метод on_test_begin следующим образом:
class TensorBoardFix(tf.keras.callbacks.TensorBoard):
"""
This fixes incorrect step values when using the TensorBoard callback with custom summary ops
"""
def on_train_begin(self, *args, **kwargs):
super(TensorBoardFix, self).on_train_begin(*args, **kwargs)
tf.summary.experimental.set_step(self._train_step)
def on_test_begin(self, *args, **kwargs):
super(TensorBoardFix, self).on_test_begin(*args, **kwargs)
tf.summary.experimental.set_step(self._val_step)
И используйте этот фиксированный класс обратного вызова в model.fit():
tensorboard_callback = TensorBoardFix(log_dir=log_dir, histogram_freq=1, write_graph=True, update_freq=1)
model.fit(dataset, epochs=200, callbacks=[tensorboard_callback])
Это решило мою проблему, и теперь я могу получить правильный шаг внутри моей модели, вызвав tf.summary.experimental.get_step() .
(Эта проблема может быть исправлена в более поздней версии TensorFlow)