Keras load_model вызывает ‘TypeError: аргумент ключевого слова не понят:’ при использовании пользовательского слоя в модели

#python #tensorflow #keras #deep-learning #keras-layer

#python #тензорный поток #keras #глубокое обучение #keras-layer

Вопрос:

Я создаю модель с пользовательским слоем внимания, как это реализовано в учебном пособии Tensorflow по nmt. Я использовал тот же код слоя с несколькими изменениями, которые я нашел в качестве предложений, чтобы решить мою проблему.

Проблема в том, что я не могу загрузить модель из файла после ее сохранения, когда у меня есть этот пользовательский слой. Это класс layer:

 class BahdanauAttention(layers.Layer):
    def __init__(self, output_dim=30, **kwargs):
        super(BahdanauAttention, self).__init__(**kwargs)
        self.W1 = tf.keras.layers.Dense(output_dim)
        self.W2 = tf.keras.layers.Dense(output_dim)
        self.V = tf.keras.layers.Dense(1)

    def call(self, inputs, **kwargs):
        query = inputs[0]
        values = inputs[1]
        query_with_time_axis = tf.expand_dims(query, 1)

        score = self.V(tf.nn.tanh(
            self.W1(query_with_time_axis)   self.W2(values)))

        attention_weights = tf.nn.softmax(score, axis=1)

        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

    def get_config(self):
        config = super(BahdanauAttention, self).get_config()
        config.update({
            'W1': self.W1,
            'W2': self.W2,
            'V': self.V,
        })
        return config
  

Я сохраняю модель с ModelCheckpoint помощью обратного вызова keras:

 path = os.path.join(self.dir, 'model_{}'.format(self.timestamp))
callbacks.append(ModelCheckpoint(path, save_best_only=True, monitor='val_loss', mode='min'))
  

Позже я загружаю модель следующим образом:

  self.model = load_model(path, custom_objects={'BahdanauAttention': BahdanauAttention, 'custom_loss': self.custom_loss})
  

Это сообщение об ошибке, которое я получаю:

 raise TypeError(error_message, kwarg)
    TypeError: ('Keyword argument not understood:', 'W1')
  

и полная трассировка:

 Traceback (most recent call last):
  File "models/lstm.py", line 49, in load_model
    'dollar_mape_loss': self.dollar_mape_loss})
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py", line 187, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 121, in load
    path, options=options, loader_cls=KerasObjectLoader)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 633, in load_internal
    ckpt_options)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 194, in __init__
    super(KerasObjectLoader, self).__init__(*args, **kwargs)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 130, in __init__
    self._load_all()
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 215, in _load_all
    self._layer_nodes = self._load_layers()
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 315, in _load_layers
    layers[node_id] = self._load_layer(proto.user_object, node_id)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 341, in _load_layer
    obj, setter = self._revive_from_config(proto.identifier, metadata, node_id)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 359, in _revive_from_config
    self._revive_layer_from_config(metadata, node_id))
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 417, in _revive_layer_from_config
    generic_utils.serialize_keras_class_and_config(class_name, config))
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/layers/serialization.py", line 175, in deserialize
    printable_module_name='layer')
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 360, in deserialize_keras_object
    return cls.from_config(cls_config)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 697, in from_config
    return cls(**config)
  File "models/lstm.py", line 310, in __init__
    super(BahdanauAttention, self).__init__(**kwargs)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 318, in __init__
    generic_utils.validate_kwargs(kwargs, allowed_kwargs)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 778, in validate_kwargs
    raise TypeError(error_message, kwarg)
TypeError: ('Keyword argument not understood:', 'W1')
  

Подобные вопросы предполагают, что в коде используются разные версии Keras и TensorFlow. Я использую только Keras от TensorFlow. Это импорт

 from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
from tensorflow.keras import layers
  

Ответ №1:

Следуя документации keras по пользовательским слоям, они рекомендуют, чтобы любые веса инициализировались не in __init__() , а in build() . Таким образом, веса не нужно добавлять в конфигурацию, и ошибка будет устранена.

Это обновленный класс пользовательского слоя:

 class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units=30, **kwargs):
        super(BahdanauAttention, self).__init__(**kwargs)
        self.units = units
      

    def build(self, input_shape):
        self.W1 = tf.keras.layers.Dense(self.units)
        self.W2 = tf.keras.layers.Dense(self.units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, inputs, **kwargs):
        query = inputs[0]
        values = inputs[1]
        query_with_time_axis = tf.expand_dims(query, 1)

       
        score = self.V(tf.nn.tanh(
            self.W1(query_with_time_axis)   self.W2(values)))

        attention_weights = tf.nn.softmax(score, axis=1)

        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

    def get_config(self):
        config = super(BahdanauAttention, self).get_config()
        config.update({
            'units': self.units,
        })
        return config
  

Ответ №2:

У меня тоже есть эта проблема. Я перепробовал много методов и обнаружил, что этот метод можно использовать. сначала создайте модель

 model = TextAttBiRNN(maxlen, max_features, embedding_dims).get_model()
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
  

во-вторых, веса нагрузки:
Я решил проблему с помощью этого:

 model_file = "/content/drive/My Drive/dga/output_data/model_lstm_att_test_v6.h5"
model.load_weights(model_file)
  

затем мы обнаружим, что модель может быть использована.

таким образом, я избежал предыдущих вопросов.