Существуют ли предварительные условия для сохранения модели в keras?

#keras #model #save

#keras #Модель #Сохранить

Вопрос:

Я хочу сохранить свою модель keras после обучения. Модель.функция подгонки работает, но, к сожалению, команды model.save(‘path’) или model.save_weights(‘path’) не сработали.

Я также пытался сохранить модель с помощью pickle или np.save, но это тоже не сработало.

Моя модель построена следующим образом:

 model_resnet = Model(inputs=RESNET.input, outputs=RESNET.output)

model = Sequential()
model.add(model_resnet)
model.add(BatchNormalization())
model.add(Reshape((1,256)))

model.add(Bidirectional(GRU(512,return_sequences=True)))
model.add(Bidirectional(GRU(512)))

model.add(Dense(11,activation='softmax'))
  

где RESNET — это 3D-модель resnet32, определенная с помощью keras functional API.
Тот же код может быть написан подобным образом:

 model_ = Sequential()
model_.add(BatchNormalization())
model_.add(Reshape((1,256)))

model_.add(Bidirectional(GRU(512,return_sequences=True)))
model_.add(Bidirectional(GRU(512)))

model_.add(Dense(11,activation='softmax'))

model = Model(input = RESNET.input, outputs = model_(RESNET.output))
  

я пытаюсь сохранить с помощью следующего кода:

 model.save(root_dir '\models\model.h5')
  

и я также пытался:

 x = model.get_weights()
with open(root_dir '\models\model.pickle', 'wb') as f:
    pickle.dump(x, f)
  

Ни один из этих методов не работает.

при использовании функции сохранения keras у меня возникла следующая ошибка: (не обращайте внимания на название модели в ошибке)

   File ".../train.py", line 110, in <module>
    model_video.save(root_dir '\models\model_video.h5')
  File "...anaconda3envstensorflow_envlibsite-packageskerasenginenetwork.py", line 1090, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "...anaconda3envstensorflow_envlibsite-packageskerasenginesaving.py", line 382, in save_model
    _serialize_model(model, f, include_optimizer)
  File "...anaconda3envstensorflow_envlibsite-packageskerasenginesaving.py", line 114, in _serialize_model
    layer_group[name] = val
  File "...anaconda3envstensorflow_envlibsite-packageskerasutilsio_utils.py", line 218, in __setitem__
    dataset = self.data.create_dataset(attr, val.shape, dtype=val.dtype)
  File "...anaconda3envstensorflow_envlibsite-packagesh5py_hlgroup.py", line 136, in create_dataset
    dsid = dataset.make_new_dset(self, shape, dtype, data, **kwds)
  File "...anaconda3envstensorflow_envlibsite-packagesh5py_hldataset.py", line 117, in make_new_dset
    dtype = numpy.dtype(dtype)
TypeError: data type not understood
  

при использовании pickle у меня возникает следующая ошибка:

 Traceback (most recent call last):
  File ".../train.py", line 113, in <module>
    pickle.dump(x, f)
_pickle.PicklingError: Can't pickle <class 'numpy.ndarray'>: it's not the same object as numpy.ndarray
  

Комментарии:

1. Я думаю, что обе ошибки показывают, что ваша установка numpy каким-то образом нарушена

2. действительно, «conda install numpy» решила проблему.. Спасибо!

Ответ №1:

«conda install numpy» решила проблему.