#pytorch
#pytorch
Вопрос:
Вот простой пример. Я попытался разделить сеть (Resnet50) на две части: head
и tail
с помощью children
. Концептуально это должно работать, но это не так. Почему это?
import torch
import torch.nn as nn
from torchvision.models import resnet50
head = nn.Sequential(*list(resnet.children())[:-2])
tail = nn.Sequential(*list(resnet.children())[-2:])
x = torch.zeros(1, 3, 160, 160)
resnet(x).shape # torch.Size([1, 1000])
head(x).shape # torch.Size([1, 2048, 5, 5])
tail(head(x)).shape # Error: RuntimeError: size mismatch, m1: [2048 x 1], m2: [2048 x 1000] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:136
Для информации, хвост — это не что иное, как
Sequential(
(0): AdaptiveAvgPool2d(output_size=(1, 1))
(1): Linear(in_features=2048, out_features=1000, bias=True)
)
Так что я действительно знаю, что если я смогу сделать так. Но тогда почему функция изменения формы ( view
) отсутствует в дочерних элементах?
pool =resnet._modules['avgpool']
fc = resnet._modules['fc']
fc(pool(head(x)).view(1, -1))
Ответ №1:
То, что вы хотите сделать, это отделить средство извлечения объектов от классификатора.
- Что я должен сразу отметить, так это то, что Resnet не является последовательной моделью (как следует из названия — остаточная сеть — это как остатки)!
Поэтому компиляция его до a
nn.Sequential
не будет точной. Существует разница между определением модели, в которой упорядочены слои.children()
, и фактической базовой реализациейforward
функции этой модели.
- Выравнивание, которое вы выполнили с помощью
view(1, -1)
, не регистрируется как слой во всехtorchvision.models.resnet*
моделях. Вместо этого это выполняется в этой строке вforward
определении:x = torch.flatten(x, 1)
Они могли бы зарегистрировать его как слой в
__init__
asself.flatten = nn.Flatten()
, который будет использоваться вforward
реализации asx = self.flatten(x)
.Тем не менее, это
fc(pool(head(x)).view(1, -1))
полностью отличается отresnet(x)
(см. Первый пункт).
Ответ №2:
Добавление nn.Flatten
модуля в tail
, похоже, решает вашу проблему:
import torch
import torch.nn as nn
from torchvision.models import resnet50
resnet = resnet50()
head = nn.Sequential(*list(resnet.children())[:-2])
tail = nn.Sequential(*[list(resnet.children())[-2], nn.Flatten(start_dim=1), list(resnet.children())[-1]])
x = torch.zeros(1, 3, 160, 160)
resnet(x).shape # torch.Size([1, 1000])
head(x).shape # torch.Size([1, 2048, 5, 5])
tail(head(x)).shape # torch.Size([1, 1000])