Сохранение выходных данных метрик тензорного потока во время обучения в режиме графика

#python #tensorflow #keras #deep-learning

Вопрос:

У меня есть пользовательский цикл обучения в Tensorflow (версия 2.3), и я отслеживаю SSIM, используя пользовательский tf.keras.Metric . Мой код выглядит так:

 @tf.function
def train():
    # Build model
    net = FullModel()

    # Build metric
    train_ssim = MS_SSIM()

    # Build dataset (this comes out as a tf.data.Dataset of images)
    train_data = build_dataset()

    for epoch in range(100):
        for i, batch in train_data.enumerate():
            # Use autoencoder
            reconstructed = net(batch)
            # Compute metrics
            train_ssim.update_state(batch, reconstructed)
        
        # Write metrics to CSV logfile
        with open('training_log.csv', 'a') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow([epoch, float(train_ssim.result())])
 

Когда я запускаю этот код, результирующий CSV-файл говорит что-то вроде: <tf.Tensor 'Identity:0' shape=() dtype=float32> для результата SSIM. Мой вопрос: как мне извлечь фактическое значение с плавающей точкой этого тензора, чтобы я мог сохранить его в формате CSV?

Этот код выполняется в графическом режиме, поэтому я НЕ пытаюсь просто включить нетерпеливое выполнение, если только я не могу каким-то образом просто включить его, чтобы получить этот результат, а затем снова отключить, но я не думаю, что это возможно.

Я пробовал использовать tf.keras.backend.get_value() и tf.get_static_value() , но первый выдает мне ошибку «Объект тензора не имеет атрибута _numpy ()», а второй просто не возвращает его.

Как вы можете извлечь эти значения в Tensorflow 2.3 в режиме графика, чтобы сохранить их в CSV-файле в пользовательском цикле обучения?