#python #pytorch #resnet #torchvision #faster-rcnn
#python #pytorch #resnet #torchvision #быстрее -rcnn
Вопрос:
Я копаюсь в исходном коде более быстрой реализации R-CNN torchvision
и сталкиваюсь с некоторыми вещами, которые я не совсем понимаю. А именно, предполагая, что я хочу создать более быструю модель R-CNN, не предварительно обученную на COCO, с магистралью, предварительно обученной в ImageNet, а затем просто получить магистраль, я делаю следующее:
plain_backbone = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=True).backbone.body
Что согласуется с тем, как настроена магистраль, как указано здесь и здесь . Однако, когда я пропускаю изображение через модель, результаты не соответствуют тому, что я получил бы, если бы я просто настроил resnet50
напрямую. А именно:
# Regular resnet50, pretrained on ImageNet, without the classifier and the average pooling layer
resnet50_1 = torch.nn.Sequential(*(list(torchvision.models.resnet50(pretrained=True).children())[:-2]))
resnet50_1.eval()
# Resnet50, extract from the Faster R-CNN, also pre-trained on ImageNet
resnet50_2 = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=True).backbone.body
resnet50_2.eval()
# Loading a random image, converted to torch.Tensor, rescalled to [0, 1] (not that it matters)
image = transforms.ToTensor()(Image.open("random_images/random.jpg")).unsqueeze(0)
# Obtaining the model outputs
with torch.no_grad():
# Output from the regular resnet50
output_1 = resnet50_1(image)
# Output from the resnet50 extracted from the Faster R-CNN
output_2 = resnet50_2(image)["3"]
# Their outputs aren't the same, which I would assume they should be
np.testing.assert_almost_equal(output_1.numpy(), output_2.numpy())
С нетерпением ждем ваших мыслей!
Комментарии:
1. Я тоже это проверил! Кажется, что оба загружают веса с одной и той же контрольной точки, но отличаются по результату.
IntermediateLayerGetter
Класс, который оборачивает магистраль,resnet50_2
может быть ответственным за это, хотя мне еще предстоит исследовать больше.2. Да, это то, что меня смутило. Это
IntermediateLayerGetter
оболочка для простого получения выходных данных из слоев на основе того, что я понял. Тем не менее, дайте мне знать, что вы найдете 🙂
Ответ №1:
Это связано с тем, что fasterrcnn_resnet50_fpn
используется пользовательский уровень нормализации ( FrozenBatchNorm2d
) вместо значения по умолчанию BatchNorm2D
. Они очень похожи, но я подозреваю, что небольшие числовые различия вызывают проблемы.
Проверка пройдет, если вы укажете тот же уровень нормализации, который будет использоваться для стандартной повторной сети:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
import numpy as np
from torchvision.ops import misc as misc_nn_ops
# Regular resnet50, pretrained on ImageNet, without the classifier and the average pooling layer
resnet50_1 = torch.nn.Sequential(*(list(torchvision.models.resnet50(pretrained=True, norm_layer=misc_nn_ops.FrozenBatchNorm2d).children())[:-2]))
resnet50_1.eval()
# Resnet50, extract from the Faster R-CNN, also pre-trained on ImageNet
resnet50_2 = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=True).backbone.body
resnet50_2.eval()
# am too lazy to get a real image
image = torch.ones((1, 3, 224, 224))
# Obtaining the model outputs
with torch.no_grad():
# Output from the regular resnet50
output_1 = resnet50_1(image)
# Output from the resnet50 extracted from the Faster R-CNN
output_2 = resnet50_2(image)["3"]
# Passes
np.testing.assert_almost_equal(output_1.numpy(), output_2.numpy())
Комментарии:
1. Хороший улов. Для записи, можете ли вы подробнее рассказать о разнице между BatchNorm и FrozenBatchNorm? Кстати, я принимаю ваш ответ сейчас.
2. @gorjan FrozenBatchNorm реализован здесь с помощью чистого PyTorch, в то время как BatchNorm реализован на C . Я думаю, что единственная причина, по которой существует FrozenBatchNorm, заключается в том, что они хотят, чтобы BN оставался в
eval
режиме и не обновлял его параметры с минимальной работой, требуемой от пользователя. Любые различия в выходных данных должны быть только числовыми и (я полагаю) незначительными.3. Я только что нашел официальное объяснение здесь .