Невозможно загрузить данные контрольной точки

#python #tensorflow

#python #тензорный поток

Вопрос:

Я следую коду отсюда, чтобы изучить задачу обобщения текста с помощью модели transformer, ее можно найти здесь

Но код не предоставляет способа загрузки модели после обучения, так что это неудобство, и я решил написать эту функцию

Вот что я сделал:

 model = Transformer(
num_layers, 
d_model, 
num_heads, 
dff,
encoder_vocab_size, 
decoder_vocab_size, 
pe_input=max_len_news,
pe_target=max_len_summary,
)

input = tf.random.uniform([1, 12], 0, 100, dtype=tf.int32) #create dummy input
enc_padding_mask, look_ahead_mask, dec_padding_mask = create_masks(input, input) # create masks
a = model(input, input, False, enc_padding_mask, look_ahead_mask, dec_padding_mask) # call the model before loading weights

model.load_weights('checkpoints/ckpt-5.data-00000-of-00001')
  

Теперь он выдает ошибку:

 /usr/local/lib/python3.6/dist-packages/h5py/_hl/files.py in make_fid(name, mode, userblock_size, fapl, fcpl, swmr)
171         if swmr and swmr_support:
172             flags |= h5f.ACC_SWMR_READ
--> 173         fid = h5f.open(name, flags, fapl=fapl)
174     elif mode == 'r ':
175         fid = h5f.open(name, h5f.ACC_RDWR, fapl=fapl)

h5py/_objects.pyx in h5py._objects.with_phil.wrapper()

h5py/_objects.pyx in h5py._objects.with_phil.wrapper()

h5py/h5f.pyx in h5py.h5f.open()

OSError: Unable to open file (file signature not found)
  

Я совершенно новичок в машинном обучении и TensorFlow. Пожалуйста, помогите.

Ответ №1:

Просматриваю документацию tf.keras.Model.load_weights (курсив мой):

Аргументы

Строка пути к файлу, путь к файлу весов для загрузки. Для файлов с весом в формате TensorFlow это префикс файла (тот же, который был передан в save_weights).

Вам нужно только передать префикс, поэтому

 model.load_weights("checkpoints/ckpt-5")
  

должно сработать.