#python #tensorflow #tf.keras #gradienttape
#python #тензорный поток #tf.keras #градиентная лента
Вопрос:
Я работал над моделью, цикл обучения которой использует оболочку tf.function (я получаю ошибки ООМ при быстром запуске), и обучение, похоже, проходит нормально. Однако я не могу получить доступ к значениям тензора, возвращаемым моей пользовательской обучающей функцией (ниже)
def train_step(inputs, target):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
curr_loss = lovasz_softmax_flat(predictions, target)
gradients = tape.gradient(curr_loss, model.trainable_variables)
opt.apply_gradients(zip(gradients, model.trainable_variables))
# Need to access this value
return curr_loss
Упрощенная версия моего «зонтичного» цикла обучения выглядит следующим образом:
@tf.function
def train_loop():
for epoch in range(EPOCHS):
for tr_file in train_files:
tr_inputs = preprocess(tr_file)
tr_loss = train_step(tr_inputs, target)
print(tr_loss.numpy())
Когда я пытаюсь распечатать значение потери, я получаю следующую ошибку:
AttributeError: объект ‘Tensor’ не имеет атрибута ‘numpy’
Я также попытался использовать tf.print() следующим образом:
tf.print("Loss: ", tr_loss, output_stream=sys.stdout)
Но, похоже, на терминале ничего не отображается. Есть предложения?
Ответ №1:
Вы не можете преобразовать в массив Numpy в графическом режиме. Просто создайте tf.metrics
объект вне функции и обновите его в функции.
mean_loss_values = tf.metrics.Mean()
def train_step(inputs, target):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
curr_loss = lovasz_softmax_flat(predictions, target)
gradients = tape.gradient(curr_loss, model.trainable_variables)
opt.apply_gradients(zip(gradients, model.trainable_variables))
# look below
mean_loss_values(curr_loss)
# or mean_loss_values.update_state(curr_loss)
# Need to access this value
return curr_loss
Затем позже в вашем коде:
mean_loss_values.result()
Комментарии:
1. Спасибо за помощь, Николас. Но для моего случая необходимо, чтобы метрики оценивались в функции train_loop() (tf.function), поскольку значения должны регулярно обновляться с помощью tf.summary writer для оценки Tensorboard. Кроме того, из-за сложности моей модели я не могу быстро запускать функции train_loop() или train_step().