#python #tensorflow #machine-learning
Вопрос:
Я пытаюсь использовать @tf.function (https://www.tensorflow.org/guide/function#loops) для повышения производительности, но я сталкиваюсь с некоторыми проблемами при преобразовании цикла обучения из нетерпеливого выполнения в автограф. Поскольку моя проблема заключается в том, чтобы найти общий синтаксис, и он не связан с конкретным блоком кода, я напишу общий пользовательский цикл обучения для эпохи в TensorFlow:
@tf.function
def train_epoch(dataset):
for b, (data, label) in enumerate(dataset):
with tf.GradientTape() as tape:
logits = model(data, training=True)
loss_value = loss_fn(label, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
print(f'step {b}/{train_steps} Loss value: {loss_value}')
tf.saved_model.save(model, save_path)
Мои проблемы заключаются в следующем:
при работе for b, (data, label) in enumerate(dataset)
в режиме @tf.function я не могу понять поведение цикла. Я попробовал также с for b, (data, label) in dataset.enumerate()
и. for data, label in dataset
Поскольку в режиме @tf.function последняя строка ( print(f'step {b}/{train_steps} Loss value: {loss_value}'
) выполняется только во время трассировки, я не могу следить за тем, правильно ли выполняется код.
Еще одна проблема, с которой я столкнулся, заключается в том, что сохранение модели с момента последней строки ( tf.saved_model.save(model, save_path
) выдает мне ошибку при выполнении в режиме @tf.function.
Вкратце мои вопросы по режиму @tf.function следующие:
- Каковы рекомендации по написанию пользовательского цикла обучения и мониторингу его состояния?
- Какова предпочтительная точка входа @tf.function (вне основного цикла поезда?, в
__call__
функциях подклассных моделей?) - Как сохранить и загрузить пользовательские модели подклассов?
Комментарии:
1. Для отладки вашего обучения предпочтительнее сделать это в режиме ожидания, прежде чем запускать режим графика в @tf.function . Для печати в режиме графика я думаю, что «tf.print» работает вместо «печать».