#tensorflow #machine-learning #keras #serialization #persistence
Вопрос:
Когда я создаю модель Keras с одним или несколькими пользовательскими слоями, я могу использовать model.save()
метод сохранения модели Keras с использованием формата TensorFlow SavedModel.
Я могу загрузить эту модель из файловой системы с помощью tf.keras.models.load_model()
функции и снова сохранить ее в файловой системе.
Но когда я загружаю сохраненную модель из файловой системы во второй раз, она завершается с этим исключением:
TypeError: f(inputs, training, training, training, training, *, training, training) missing 1 required argument: training
Вы можете попробовать воспроизвести эту проблему со следующим кодом:
import tensorflow as tf
class CustomLayer(tf.keras.layers.Layer):
def call(self, inputs, *args, **kwargs):
return inputs
model1 = tf.keras.Sequential([
CustomLayer()
])
model1.build((None, 1))
model1.compile()
model1.save("model1")
model2 = tf.keras.models.load_model("model1")
model2.save("model2")
# This line should raise a TypeError.
model3 = tf.keras.models.load_model("model2")
Ответ №1:
Почему существует эта проблема
Проблема в том, что формат TensorFlow SavedModel фактически не сериализует пользовательский код Python. Он сохраняет только график тензорного потока, сгенерированный пользовательскими слоями Keras и другими объектами Python.
tf.keras.models.load_model()
Функция-по умолчанию-не возвращает слой Python. Вместо этого он возвращает слой-заполнитель, содержащий ту же часть графика вычислений TensorFlow. Мы можем видеть это в примере в моем вопросе:
>>> model1.layers
[<__main__.CustomLayer at 0x7ff04c14ee20>]
>>> model2.layers
[<keras.saving.saved_model.load.CustomLayer at 0x7ff114fd7be0>]
При model2
сохранении и загрузке из файловой системы TensorFlow не может правильно проанализировать аргументы *args
и **kwargs
CustomLayer.call()
.
Я не знаю, находится ли фактическая ошибка в коде сохранения, коде загрузки или в обоих.
Реальное исправление должно произойти в TensorFlow/Keras, но в то же время существуют
Обходные пути
Вы можете выбрать любой из приведенных ниже обходных путей, чтобы избежать ошибок сериализации с помощью пользовательских слоев Keras.
Измените подпись на Layer.call()
В настоящее время официальная подпись метода на Layer.call()
def call(self, inputs, *args, **kwargs):
Но TensorFlow выдаст ошибку типа при попытке загрузить модель с пользовательским слоем с этой подписью. Чтобы исправить ошибку, напишите все ваши пользовательские слои с подписью def call(self, inputs):
. Если ваш слой ведет себя по-другому во время обучения или вывода, вы можете использовать сигнатуру метода def call(self, inputs, training=None):
Это облегчает TensorFlow создание слоев-заполнителей, созданных в keras.saving.saved_model.load
модуле. Но этот слой заполнителей все еще не совсем совпадает с исходным кодом Python.
Используйте custom_objects
параметр на tf.keras.models.load_model()
Можно загрузить модель с ее исходными слоями Python вместо слоев-заполнителей. Просто передайте словарь, сопоставляющий имена слоев объектам класса слоев Python. Для этого требуется, чтобы ваш код мог импортировать исходный слой Python. Пример в моем вопросе можно исправить следующим образом:
model3 = tf.keras.models.load_model(
"model2",
custom_objects=dict(
CustomLayer=CustomLayer,
),
)
Убедитесь, что ваш слой реализует Layer.get_config()
и возвращает словарь со всеми параметрами, необходимыми для воссоздания слоя с нуля. Слой должен быть в состоянии быть воссоздан с Layer.from_config()
помощью .
Импортируйте слой Python и добавьте его в глобальный реестр Keras
Keras поддерживает глобальный реестр пользовательских классов Python и других объектов, на которые можно ссылаться при загрузке сохраненных моделей. Вы можете зарегистрировать свой пользовательский слой Кераса у @tf.keras.utils.register_keras_serializable()
декоратора. Например:
@tf.keras.utils.register_keras_serializable(
package="my_python_package"
)
class CustomLayer(tf.keras.layers.Layer):
def call(self, inputs, *args, **kwargs):
return inputs
Этот метод также требует, чтобы ваш слой был правильно реализован Layer.get_config()
.
Установите объект слоя Python с помощью tf.keras.utils.custom_object_scope()
Как и в двух предыдущих решениях, tf.keras.utils.custom_object_scope()
контекстный менеджер может указать, какие пользовательские слои использовать при десериализации.