#python #tensorflow
#python #тензорный поток
Вопрос:
Я следую коду отсюда, чтобы изучить задачу обобщения текста с помощью модели transformer, ее можно найти здесь
Но код не предоставляет способа загрузки модели после обучения, так что это неудобство, и я решил написать эту функцию
Вот мой код:
model = Transformer(
num_layers,
d_model,
num_heads,
dff,
encoder_vocab_size,
decoder_vocab_size,
pe_input=max_len_news,
pe_target=max_len_summary,
)
model.load_weights('checkpoints/ckpt-5.data-00000-of-00001')
Выдает ошибку:
ValueError: Unable to load weights saved in HDF5 format into a subclassed Model which has not created its variables yet. Call the Model first, then load the weights.
Я совершенно новичок в машинном обучении и TensorFlow. Я знаю, что он пытается сказать, но я просто не знаю, как исправить эту проблему, пожалуйста, помогите.
Ответ №1:
Перед загрузкой весов необходимо вызвать модель с фиктивным вводом.
Попробуйте это:
model = Transformer(
num_layers,
d_model,
num_heads,
dff,
encoder_vocab_size,
decoder_vocab_size,
pe_input=max_len_news,
pe_target=max_len_summary,
)
input = tf.random.uniform([1, 12], 0, 100, dtype=tf.int32) #create dummy input
enc_padding_mask, look_ahead_mask, dec_padding_mask = create_masks(input, input) # create masks
a = model(input, input, enc_padding_mask, look_ahead_mask, dec_padding_mask) # call the model before loading weights
model.load_weights('checkpoints/ckpt-5.data-00000-of-00001')
Комментарии:
1. Не могли бы вы объяснить, что переменная «a» делает в этом коде?
2. Спасибо за ваше предложение, я попробую
3. теперь он возвращает новую ошибку: TypeError: в call() отсутствуют 4 обязательных позиционных аргумента: ‘tar’, ‘enc_padding_mask’, ‘look_ahead_mask’ и ‘dec_padding_mask’
4. Андрей, могу я спросить, что означает [8, 24] в этой ситуации?
5. создайте еще один вопрос. Эта ошибка не относится к вашему вопросу