AttributeError: ‘Функциональный’ объект не имеет атрибута ‘predict_segmentation’ при импорте Keras модели тензорного потока

#python #tensorflow #keras #deep-learning

#python #тензорный поток #keras #глубокое обучение

Вопрос:

Я успешно обучил модель Keras, такую как:

 import tensorflow as tf
from keras_segmentation.models.unet import vgg_unet

# initaite the model
model = vgg_unet(n_classes=50, input_height=512, input_width=608)

# Train
model.train(
    train_images=train_images,
    train_annotations=train_annotations,
    checkpoints_path="/tmp/vgg_unet_1", epochs=5
)
 

И сохранил его в формате hdf5 с:

 tf.keras.models.save_model(model,'my_model.hdf5')
 

Затем я загружаю свою модель с

 model=tf.keras.models.load_model('my_model.hdf5')
 

Наконец, я хочу сделать прогноз сегментации для нового изображения с помощью

 out = model.predict_segmentation(
    inp=image_to_test,
    out_fname="/tmp/out.png"
)
 

Я получаю следующую ошибку:

 AttributeError: 'Functional' object has no attribute 'predict_segmentation'
 

Что я делаю не так ?
Это происходит, когда я сохраняю свою модель или когда я ее загружаю?

Спасибо!

Комментарии:

1. Откуда вы знаете, что этот метод predict_segmentation действительно существует?

2. Ну, он существует до того, как я сохраняю модель, поскольку я могу сделать хороший прогноз сегментации с тем же кодом.

3. Конечно, но это не стандартный метод моделей Keras.

Ответ №1:

predict_segmentation эта функция недоступна в обычных моделях Keras. Похоже, что он был добавлен после того, как модель была создана в keras_segmentation библиотеке, и, возможно, именно поэтому Keras не смог загрузить ее снова.

Я думаю, у вас есть 2 варианта для этого.

  1. Вы можете использовать строку из кода, который я связал, чтобы вручную добавить функцию обратно в модель.
 model.predict_segmentation = MethodType(keras_segmentation.predict.predict, model)
 
  1. Вы можете создать новый vgg_unet с теми же аргументами при перезагрузке модели и перенести веса из вашего hdf5 файла в эту модель, как предложено в документации Keras.
 model = vgg_unet(n_classes=50, input_height=512, input_width=608)
model.load_weights('my_model.hdf5')