Я попытался разделить resnet на две части, используя pytorch children(), но это не работает

#pytorch

#pytorch

Вопрос:

Вот простой пример. Я попытался разделить сеть (Resnet50) на две части: head и tail с помощью children . Концептуально это должно работать, но это не так. Почему это?

 import torch
import torch.nn as nn
from torchvision.models import resnet50

head = nn.Sequential(*list(resnet.children())[:-2])
tail = nn.Sequential(*list(resnet.children())[-2:])
x = torch.zeros(1, 3, 160, 160)

resnet(x).shape      # torch.Size([1, 1000])
head(x).shape        # torch.Size([1, 2048, 5, 5])
tail(head(x)).shape  # Error: RuntimeError: size mismatch, m1: [2048 x 1], m2: [2048 x 1000] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:136
 

Для информации, хвост — это не что иное, как

 Sequential(
  (0): AdaptiveAvgPool2d(output_size=(1, 1))
  (1): Linear(in_features=2048, out_features=1000, bias=True)
)
 

Так что я действительно знаю, что если я смогу сделать так. Но тогда почему функция изменения формы ( view ) отсутствует в дочерних элементах?

 pool =resnet._modules['avgpool']
fc = resnet._modules['fc']
fc(pool(head(x)).view(1, -1))
 

Ответ №1:

То, что вы хотите сделать, это отделить средство извлечения объектов от классификатора.

  • Что я должен сразу отметить, так это то, что Resnet не является последовательной моделью (как следует из названия — остаточная сеть — это как остатки)!

    Поэтому компиляция его до a nn.Sequential не будет точной. Существует разница между определением модели, в которой упорядочены слои .children() , и фактической базовой реализацией forward функции этой модели.

  • Выравнивание, которое вы выполнили с помощью view(1, -1) , не регистрируется как слой во всех torchvision.models.resnet* моделях. Вместо этого это выполняется в этой строке в forward определении:
     x = torch.flatten(x, 1)
     

    Они могли бы зарегистрировать его как слой в __init__ as self.flatten = nn.Flatten() , который будет использоваться в forward реализации as x = self.flatten(x) .

    Тем не менее, это fc(pool(head(x)).view(1, -1)) полностью отличается от resnet(x) (см. Первый пункт).

Ответ №2:

Добавление nn.Flatten модуля в tail , похоже, решает вашу проблему:

 
import torch
import torch.nn as nn
from torchvision.models import resnet50
resnet = resnet50()
head = nn.Sequential(*list(resnet.children())[:-2])
tail = nn.Sequential(*[list(resnet.children())[-2], nn.Flatten(start_dim=1), list(resnet.children())[-1]])
x = torch.zeros(1, 3, 160, 160)

resnet(x).shape      # torch.Size([1, 1000])
head(x).shape        # torch.Size([1, 2048, 5, 5])
tail(head(x)).shape  # torch.Size([1, 1000])