Невозможно загрузить обученную контрольную точку

#python #tensorflow

#python #тензорный поток

Вопрос:

Я следую коду отсюда, чтобы изучить задачу обобщения текста с помощью модели transformer, ее можно найти здесь

Но код не предоставляет способа загрузки модели после обучения, так что это неудобство, и я решил написать эту функцию

Вот мой код:

 model = Transformer(
num_layers, 
d_model, 
num_heads, 
dff,
encoder_vocab_size, 
decoder_vocab_size, 
pe_input=max_len_news,
pe_target=max_len_summary,
)

model.load_weights('checkpoints/ckpt-5.data-00000-of-00001') 
  

Выдает ошибку:

 ValueError: Unable to load weights saved in HDF5 format into a subclassed Model which has not created its variables yet. Call the Model first, then load the weights.
  

Я совершенно новичок в машинном обучении и TensorFlow. Я знаю, что он пытается сказать, но я просто не знаю, как исправить эту проблему, пожалуйста, помогите.

Ответ №1:

Перед загрузкой весов необходимо вызвать модель с фиктивным вводом.

Попробуйте это:

 model = Transformer(
num_layers, 
d_model, 
num_heads, 
dff,
encoder_vocab_size, 
decoder_vocab_size, 
pe_input=max_len_news,
pe_target=max_len_summary,
)

input = tf.random.uniform([1, 12], 0, 100, dtype=tf.int32) #create dummy input
enc_padding_mask, look_ahead_mask, dec_padding_mask = create_masks(input, input) # create masks
a = model(input, input, enc_padding_mask, look_ahead_mask, dec_padding_mask) # call the model before loading weights

model.load_weights('checkpoints/ckpt-5.data-00000-of-00001')
  

Комментарии:

1. Не могли бы вы объяснить, что переменная «a» делает в этом коде?

2. Спасибо за ваше предложение, я попробую

3. теперь он возвращает новую ошибку: TypeError: в call() отсутствуют 4 обязательных позиционных аргумента: ‘tar’, ‘enc_padding_mask’, ‘look_ahead_mask’ и ‘dec_padding_mask’

4. Андрей, могу я спросить, что означает [8, 24] в этой ситуации?

5. создайте еще один вопрос. Эта ошибка не относится к вашему вопросу