#python #tensorflow2.0 #object-detection-api
#python #tensorflow2.0 #object-detection-api
Вопрос:
Я использую API обнаружения объектов Tensorflow. Недавно он был обновлен до Tensorflow2. И вместе с ним авторы выпустили отличный Colab https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tf2_colab.ipynb. Они настраивают RetinaNet на новый набор данных, однако я не понимаю, как я могу использовать это для точной настройки CenterNet (и EfficientDet).
У них есть следующий код для инициализации модели RetinaNet:
tf.keras.backend.clear_session()
print('Building model and restoring weights for fine-tuning...', flush=True)
num_classes = 1
pipeline_config = 'models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'
checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'
# Load pipeline config and build a detection model.
#
# Since we are working off of a COCO architecture which predicts 90
# class slots by default, we override the `num_classes` field here to be just
# one (for our new rubber ducky class).
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(
model_config=model_config, is_training=True)
# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression. We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
_base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
# _prediction_heads=detection_model._box_predictor._prediction_heads,
# (i.e., the classification head that we *will not* restore)
_box_prediction_head=detection_model._box_predictor._box_prediction_head,
)
fake_model = tf.compat.v2.train.Checkpoint(
_feature_extractor=detection_model._feature_extractor,
_box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()
# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')
Я попытался сделать то же самое с моделью CenterNet (она используется для вывода в этом руководстве по Colab https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/inference_tf2_colab.ipynb):
pipeline_config = 'models/research/object_detection/configs/tf2/centernet_hourglass104_512x512_coco17_tpu-8.config'
model_dir = 'models/research/object_detection/test_data/checkpoint/'
num_classes = 1
# Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.center_net.num_classes = num_classes
detection_model = model_builder.build(
model_config=model_config, is_training=True)
# Restore checkpoint
ckpt = tf.compat.v2.train.Checkpoint(
model=detection_model)
ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()
Однако выдается исключение, поскольку формы несовместимы (потому что я изменил количество классов). В примере с RetinaNet этот трюк использовался (насколько я понимаю) для создания тензоров правильной формы:
fake_box_predictor = tf.compat.v2.train.Checkpoint(
_base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
# _prediction_heads=detection_model._box_predictor._prediction_heads,
# (i.e., the classification head that we *will not* restore)
_box_prediction_head=detection_model._box_predictor._box_prediction_head,
)
fake_model = tf.compat.v2.train.Checkpoint(
_feature_extractor=detection_model._feature_extractor,
_box_predictor=fake_box_predictor)
Но как я могу узнать, что я должен написать внутри функции контрольной точки? (например, _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads
или _box_prediction_head=detection_model._box_predictor._box_prediction_head
)
Комментарии:
1. Не могли бы вы любезно сообщить мне, как вы решили проблему?