Попытка добавить предупреждение в модель обнаружения «ssd_mobilenet_v2» выдает ошибку

#python #tensorflow #machine-learning #deep-learning #computer-vision

#питон #tensorflow #машинное обучение #глубокое обучение #компьютерное зрение

Вопрос:

Используя «ssd_mobilenet_v2_fpn_keras», я пытаюсь добавить систему оповещения

Модель обнаружения загружается в нижеприведенную функцию

  def detect_fn(image):
    image, shapes = detection_model.preprocess(image)
    prediction_dict = detection_model.predict(image, shapes)
    detections = detection_model.postprocess(prediction_dict, shapes)
    return detections
 

Изображение преобразуется в тензор

 input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
 

Тензор привязан к модели обнаружения

detections = detect_fn(input_tensor)

Результатом модели обнаружения является словарь со следующими ключами:

 dict_keys(['detection_boxes', 'detection_scores', 'detection_classes', 'raw_detection_boxes', 'raw_detection_scores', 'detection_multiclass_scores', 'detection_anchor_indices', 'num_detections'])
 

detections[detection_classes] , выдает следующий вывод, т.е. 0 — ClassA, 1 — ClassB

 [0 1 1 0 0 1 0 0 1 0 1 1 0 0 1 0 1 1 0 1 0 1 1 0 0 1 0 0 1 0 1 0 0 1 1 1 1 0 0 0 1 1 1 0 0 1 1 1 0 1 0 1 0 0 0 0 1 0 0 1 0 0 1 0 1 0 0 1 0 0 0 0 1 0 1 1 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 1 0 1]
 

detections['detection_scores'] выдает оценку для каждого обнаруженного поля (некоторые показаны ниже)

 [0.988446 0.7998712 0.1579772 0.13801616 0.13227147 0.12731305 0.09515342 0.09203091 0.09191579 0.08860824 0.08313078 0.07684237
 

Я пытаюсь Print("Attention needed") , если наблюдается класс обнаружения B ie 1

 for key in detections['detection_classes']:
if key==1:
    print('Alert')
 

Когда я пытаюсь это сделать, я получаю сообщение об ошибке

 `ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
 

Как заставить это работать?

Я хочу, чтобы код для печати «Требуется внимание» был Class = 1 или A, а detection_sores > = 14

Код объяснен чуть дальше


ссылки для полного кода приведены ниже :

Ответ №1:

Как указано в сообщении об ошибке, вы должны использовать .any() . Нравится:

 if (key == 1).any():
  print('Alert')
 

Как key == 1 будет массив с [False, True, True, False, ...]

Вы также можете захотеть обнаружить те, которые превышают определенный балл, скажем, 0.7:

 for key, score in zip(
  detections['detection_classes'],
  detections['detection_scores']):
  if score > 0.7 and key == 1:
    print('Alert')
    break