#python #tensorflow #keras #deep-learning
Вопрос:
У меня есть пользовательский цикл обучения в Tensorflow (версия 2.3), и я отслеживаю SSIM, используя пользовательский tf.keras.Metric
. Мой код выглядит так:
@tf.function
def train():
# Build model
net = FullModel()
# Build metric
train_ssim = MS_SSIM()
# Build dataset (this comes out as a tf.data.Dataset of images)
train_data = build_dataset()
for epoch in range(100):
for i, batch in train_data.enumerate():
# Use autoencoder
reconstructed = net(batch)
# Compute metrics
train_ssim.update_state(batch, reconstructed)
# Write metrics to CSV logfile
with open('training_log.csv', 'a') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([epoch, float(train_ssim.result())])
Когда я запускаю этот код, результирующий CSV-файл говорит что-то вроде: <tf.Tensor 'Identity:0' shape=() dtype=float32>
для результата SSIM. Мой вопрос: как мне извлечь фактическое значение с плавающей точкой этого тензора, чтобы я мог сохранить его в формате CSV?
Этот код выполняется в графическом режиме, поэтому я НЕ пытаюсь просто включить нетерпеливое выполнение, если только я не могу каким-то образом просто включить его, чтобы получить этот результат, а затем снова отключить, но я не думаю, что это возможно.
Я пробовал использовать tf.keras.backend.get_value()
и tf.get_static_value()
, но первый выдает мне ошибку «Объект тензора не имеет атрибута _numpy ()», а второй просто не возвращает его.
Как вы можете извлечь эти значения в Tensorflow 2.3 в режиме графика, чтобы сохранить их в CSV-файле в пользовательском цикле обучения?