Как использовать пользовательский цикл обучения и пользовательскую модель в TensorFlow при работе с @tf.функцией и автографом?

#python #tensorflow #machine-learning

Вопрос:

Я пытаюсь использовать @tf.function (https://www.tensorflow.org/guide/function#loops) для повышения производительности, но я сталкиваюсь с некоторыми проблемами при преобразовании цикла обучения из нетерпеливого выполнения в автограф. Поскольку моя проблема заключается в том, чтобы найти общий синтаксис, и он не связан с конкретным блоком кода, я напишу общий пользовательский цикл обучения для эпохи в TensorFlow:

      @tf.function
     def train_epoch(dataset):
        for b, (data, label) in enumerate(dataset):
            with tf.GradientTape() as tape:
                logits = model(data, training=True)
                loss_value = loss_fn(label, logits)
            grads = tape.gradient(loss_value, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
            print(f'step {b}/{train_steps} Loss value: {loss_value}')
            tf.saved_model.save(model, save_path)
 

Мои проблемы заключаются в следующем:

при работе for b, (data, label) in enumerate(dataset) в режиме @tf.function я не могу понять поведение цикла. Я попробовал также с for b, (data, label) in dataset.enumerate() и. for data, label in dataset Поскольку в режиме @tf.function последняя строка ( print(f'step {b}/{train_steps} Loss value: {loss_value}' ) выполняется только во время трассировки, я не могу следить за тем, правильно ли выполняется код.

Еще одна проблема, с которой я столкнулся, заключается в том, что сохранение модели с момента последней строки ( tf.saved_model.save(model, save_path ) выдает мне ошибку при выполнении в режиме @tf.function.

Вкратце мои вопросы по режиму @tf.function следующие:

  1. Каковы рекомендации по написанию пользовательского цикла обучения и мониторингу его состояния?
  2. Какова предпочтительная точка входа @tf.function (вне основного цикла поезда?, в __call__ функциях подклассных моделей?)
  3. Как сохранить и загрузить пользовательские модели подклассов?

Комментарии:

1. Для отладки вашего обучения предпочтительнее сделать это в режиме ожидания, прежде чем запускать режим графика в @tf.function . Для печати в режиме графика я думаю, что «tf.print» работает вместо «печать».