Колаборатория обнаружения объектов с несколькими выстрелами для CenterNet

#python #tensorflow2.0 #object-detection-api

#python #tensorflow2.0 #object-detection-api

Вопрос:

Я использую API обнаружения объектов Tensorflow. Недавно он был обновлен до Tensorflow2. И вместе с ним авторы выпустили отличный Colab https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tf2_colab.ipynb. Они настраивают RetinaNet на новый набор данных, однако я не понимаю, как я могу использовать это для точной настройки CenterNet (и EfficientDet).

У них есть следующий код для инициализации модели RetinaNet:

 tf.keras.backend.clear_session()

print('Building model and restoring weights for fine-tuning...', flush=True)
num_classes = 1
pipeline_config = 'models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'
checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'

# Load pipeline config and build a detection model.
#
# Since we are working off of a COCO architecture which predicts 90
# class slots by default, we override the `num_classes` field here to be just
# one (for our new rubber ducky class).
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(
      model_config=model_config, is_training=True)

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')
  

Я попытался сделать то же самое с моделью CenterNet (она используется для вывода в этом руководстве по Colab https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/inference_tf2_colab.ipynb):

 pipeline_config =  'models/research/object_detection/configs/tf2/centernet_hourglass104_512x512_coco17_tpu-8.config'
model_dir = 'models/research/object_detection/test_data/checkpoint/'
num_classes = 1
# Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']

model_config.center_net.num_classes = num_classes
detection_model = model_builder.build(
      model_config=model_config, is_training=True)

# Restore checkpoint
ckpt = tf.compat.v2.train.Checkpoint(
      model=detection_model)
ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()
  

Однако выдается исключение, поскольку формы несовместимы (потому что я изменил количество классов). В примере с RetinaNet этот трюк использовался (насколько я понимаю) для создания тензоров правильной формы:

 fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
  

Но как я могу узнать, что я должен написать внутри функции контрольной точки? (например, _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads или _box_prediction_head=detection_model._box_predictor._box_prediction_head )

Комментарии:

1. Не могли бы вы любезно сообщить мне, как вы решили проблему?