TensorFlow Keras SavedModel выдает ошибку типа после сохранения и загрузки дважды

#tensorflow #machine-learning #keras #serialization #persistence

Вопрос:

Когда я создаю модель Keras с одним или несколькими пользовательскими слоями, я могу использовать model.save() метод сохранения модели Keras с использованием формата TensorFlow SavedModel.


Я могу загрузить эту модель из файловой системы с помощью tf.keras.models.load_model() функции и снова сохранить ее в файловой системе.

Но когда я загружаю сохраненную модель из файловой системы во второй раз, она завершается с этим исключением:

 TypeError: f(inputs, training, training, training, training, *, training, training) missing 1 required argument: training
 

Вы можете попробовать воспроизвести эту проблему со следующим кодом:

 import tensorflow as tf

class CustomLayer(tf.keras.layers.Layer):
    def call(self, inputs, *args, **kwargs):
        return inputs

model1 = tf.keras.Sequential([
    CustomLayer()
])
model1.build((None, 1))
model1.compile()
model1.save("model1")

model2 = tf.keras.models.load_model("model1")
model2.save("model2")

# This line should raise a TypeError.
model3 = tf.keras.models.load_model("model2")
 

Ответ №1:

Почему существует эта проблема

Проблема в том, что формат TensorFlow SavedModel фактически не сериализует пользовательский код Python. Он сохраняет только график тензорного потока, сгенерированный пользовательскими слоями Keras и другими объектами Python.

tf.keras.models.load_model() Функция-по умолчанию-не возвращает слой Python. Вместо этого он возвращает слой-заполнитель, содержащий ту же часть графика вычислений TensorFlow. Мы можем видеть это в примере в моем вопросе:

 >>> model1.layers
[<__main__.CustomLayer at 0x7ff04c14ee20>]

>>> model2.layers
[<keras.saving.saved_model.load.CustomLayer at 0x7ff114fd7be0>]
 

При model2 сохранении и загрузке из файловой системы TensorFlow не может правильно проанализировать аргументы *args и **kwargs CustomLayer.call() .

Я не знаю, находится ли фактическая ошибка в коде сохранения, коде загрузки или в обоих.

Реальное исправление должно произойти в TensorFlow/Keras, но в то же время существуют

Обходные пути

Вы можете выбрать любой из приведенных ниже обходных путей, чтобы избежать ошибок сериализации с помощью пользовательских слоев Keras.

Измените подпись на Layer.call()

В настоящее время официальная подпись метода на Layer.call() def call(self, inputs, *args, **kwargs):

Но TensorFlow выдаст ошибку типа при попытке загрузить модель с пользовательским слоем с этой подписью. Чтобы исправить ошибку, напишите все ваши пользовательские слои с подписью def call(self, inputs): . Если ваш слой ведет себя по-другому во время обучения или вывода, вы можете использовать сигнатуру метода def call(self, inputs, training=None):

Это облегчает TensorFlow создание слоев-заполнителей, созданных в keras.saving.saved_model.load модуле. Но этот слой заполнителей все еще не совсем совпадает с исходным кодом Python.

Используйте custom_objects параметр на tf.keras.models.load_model()

Можно загрузить модель с ее исходными слоями Python вместо слоев-заполнителей. Просто передайте словарь, сопоставляющий имена слоев объектам класса слоев Python. Для этого требуется, чтобы ваш код мог импортировать исходный слой Python. Пример в моем вопросе можно исправить следующим образом:

 model3 = tf.keras.models.load_model(
    "model2",
    custom_objects=dict(
        CustomLayer=CustomLayer,
    ),
)
 

Убедитесь, что ваш слой реализует Layer.get_config() и возвращает словарь со всеми параметрами, необходимыми для воссоздания слоя с нуля. Слой должен быть в состоянии быть воссоздан с Layer.from_config() помощью .

Импортируйте слой Python и добавьте его в глобальный реестр Keras

Keras поддерживает глобальный реестр пользовательских классов Python и других объектов, на которые можно ссылаться при загрузке сохраненных моделей. Вы можете зарегистрировать свой пользовательский слой Кераса у @tf.keras.utils.register_keras_serializable() декоратора. Например:

 @tf.keras.utils.register_keras_serializable(
   package="my_python_package"
)
class CustomLayer(tf.keras.layers.Layer):
    def call(self, inputs, *args, **kwargs):
        return inputs
 

Этот метод также требует, чтобы ваш слой был правильно реализован Layer.get_config() .

Установите объект слоя Python с помощью tf.keras.utils.custom_object_scope()

Как и в двух предыдущих решениях, tf.keras.utils.custom_object_scope() контекстный менеджер может указать, какие пользовательские слои использовать при десериализации.