#django #tensorflow #flask
Вопрос:
Мне было поручено интегрировать чью-то модель в веб-приложение. Они использовали библиотеку под названием Mask-RCNN. Я получаю некоторые ошибки (ниже) при попадании в конечную точку для этой модели. Я видел различные решения для более старых версий TensorFlow, но эти решения не существуют в 2.x или 2.5.
Минимальное повторение с колбой ниже. Та же проблема в Джанго.
# app.py
# run with: flask run
from flask import Flask
app = Flask(__name__)
import tensorflow as tf
import numpy as np
from mrcnn import model as modellib
from mrcnn.config import Config
import cv2
class BaseConfig(Config):
# give the configuration a recognizable name
NAME = "IMGs"
# set the number of GPUs to use training along with the number of
# images per GPU (which may have to be tuned depending on how
# much memory your GPU has)
GPU_COUNT = 1
IMAGES_PER_GPU = 2
# number of classes ( 1 for the background)
NUM_CLASSES = 1 1
# Most objects possible in an image
TRAIN_ROIS_PER_IMAGE = 100
class InferenceConfig(BaseConfig):
GPU_COUNT = 1
IMAGES_PER_GPU = 1
# Most objects possible in an image
DETECTION_MAX_INSTANCES = 200
DETECTION_MIN_CONFIDENCE = 0.7
def load_model_for_inference(weights_path):
"""Initialize a Mask R-CNN model with our InferenceConfig and the specified weights"""
inference_config = InferenceConfig()
model = modellib.MaskRCNN(mode="inference", config=inference_config, model_dir=".")
model.load_weights(weights_path, by_name=True)
# model.keras_model._make_predict_function()
return model
detection_model = load_model_for_inference("mask_rcnn_0427.h5")
@app.route('/detect')
def ic_model_endpoint():
image = cv2.imread('img.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
results = detection_model.detect([image], verbose=1)
print(results)
@app.route("/")
def hello_world():
return "<p>Hello, World!</p>"
Используя вилку на https://github.com/sabderra/Mask_RCNN, но вы не можете публиковать там вопросы, поэтому я решил, что могу спросить здесь. Это вилка Mask RCNN, которая была обновлена для TF 2, но я думаю, может быть, есть еще одно обновление, необходимое для 2.5? Например, нужно ли для этого создавать функцию?
Попадание в /detect
конечную точку дает следующее.
ValueError: Tensor Tensor("mrcnn_detection/Reshape_1:0", shape=(1, 200, 6), dtype=float32) is not an element of this graph.
Если мы раскомментируем model.keras_model._make_predict_function()
, я получу следующее.
tensorflow.python.framework.errors_impl.InvalidArgumentError: Tensor input_image:0, specified in either feed_devices or fetch_devices was not found in the Graph
Теперь я видел много решений, связанных с сеансом, но в этой версии больше нет сеансов. Кроме того, многое говорит об использовании tf.get_default_graph()
, но этого вызова также не существует. Здесь я кое-чего не понимаю в TensorFlow 2 и в том, как взаимодействуют асинхронные функции.
Любое руководство приветствуется!