Получение меток классов моделей из предварительно подготовленных моделей torchvision

#python #pytorch #torchvision

#python #pytorch #torchvision

Вопрос:

Я использую предварительно обученную модель Alexnet (без точной настройки) из torchvision. Проблема в том, что, хотя я могу запустить модель на некоторых данных и получить распределение вероятностей на выходе, я не могу найти метки классов для ее сопоставления.

Следуя этой официальной документации

 import torch
model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)
model.eval()
  
 AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
  

Выполнив несколько шагов по обработке изображения, я могу использовать его для получения выходных данных для одного изображения в виде (1,1000) вектора dim, для которого я буду использовать softmax для получения распределения вероятностей —

 #Output - 

tensor([-1.6531e 00, -4.3505e 00, -1.8172e 00, -4.2143e 00, -3.1914e 00,
         3.4163e-01,  1.0877e 00,  5.9350e 00,  8.0425e 00, -7.0242e-01,
        -9.4130e-01, -6.0822e-01, -2.4097e-01, -1.9946e 00, -1.5288e 00,
        -3.2656e 00, -5.5800e-01,  1.0524e 00,  1.9211e-01, -4.7202e 00,
        -3.3880e 00,  4.3048e 00, -1.0997e 00,  4.6132e 00, -5.7404e-03,
        -5.3437e 00, -4.7378e 00, -3.3974e 00, -4.1287e 00,  2.9064e-01,
        -3.2955e 00, -6.7051e 00, -4.7232e 00, -4.1778e 00, -2.1859e 00,
        -2.9469e 00,  3.0465e 00, -3.5882e 00, -6.3890e 00, -4.4203e 00,
        -3.3685e 00, -5.0983e 00, -4.9006e 00, -5.5235e 00, -3.7233e 00,
        -4.0204e 00,  2.6998e-01, -4.4702e 00, -5.6617e 00, -5.4880e 00,
        -2.6801e 00, -3.2129e 00, -1.6294e 00, -5.2289e 00, -2.7495e 00,
        -2.6286e 00, -1.8206e 00, -2.3196e 00, -5.2806e 00, -3.7652e 00,
        -3.0987e 00, -4.1421e 00, -5.2531e 00, -4.6505e 00, -3.5815e 00,
        -4.0189e 00, -4.0008e 00, -4.5512e 00, -3.2248e 00, -7.7903e 00,
        -1.4484e 00, -3.8347e 00, -4.5611e 00, -4.3681e 00,  2.7234e-01,
        -4.0162e 00, -4.2136e 00, -5.4524e 00,  1.1744e 00, -4.7785e 00,
        -1.8335e 00,  4.1288e-01,  2.2239e 00, -9.9919e-02,  4.8216e 00,
        -8.4304e-01,  5.6911e-01, -4.0484e 00, -3.3013e 00,  2.8698e 00,
        -1.1419e 00, -9.1690e-01, -2.9284e 00, -2.6097e 00, -1.8213e-01,
        -2.5429e 00, -2.1095e 00,  2.2419e 00, -1.6280e 00,  7.4458e 00,
         2.3184e 00, -5.7408e 00, -7.4332e-01, -5.4066e 00,  1.5177e 01,
        -4.4737e-02,  1.8237e 00, -3.7741e 00,  9.2271e-01, -4.3687e-01,
        -1.4003e 00, -4.3026e 00,  6.3782e-01, -1.0808e 00, -1.4173e 00,
         2.6194e 00, -3.8418e 00,  1.1598e 00, -2.6876e 00, -3.6103e 00,
        -4.9281e 00, -4.1411e 00, -3.3603e 00, -3.4296e 00, -1.4997e 00,
        -2.8381e 00, -1.2843e 00,  1.5745e 00, -1.7449e 00,  4.2903e-01,
         3.1234e-01, -2.8206e 00,  3.6688e-01, -2.1033e 00,  1.6481e 00,
         1.4222e 00, -2.7303e 00, -3.6292e 00,  1.2864e 00, -2.5541e 00,
        -2.9663e 00, -4.1575e 00, -3.1954e 00, -4.6487e-01,  1.8916e 00,
        -7.4721e-01,  4.5986e 00, -2.5443e 00, -6.2003e 00, -1.3215e 00,
        -2.6225e 00,  9.9639e 00,  9.7772e 00,  9.6715e 00,  9.0857e 00,...
  

Откуда мне взять метки классов? Я не смог найти ни одного метода, который позволил бы мне получить это из объекта модели.

Комментарии:

1. Модель не содержит меток классов, последние слои просто выводят логиты, равные no of classes. Вместо этого загрузчик данных содержит метки классов.

2. Но, очевидно, проблема не в этом, я всего лишь извлекаю предварительно обученную модель, а не какой-либо загрузчик данных, поскольку я не переподготовляю исходные данные или не настраиваю их с помощью своих собственных меток данных и классов. Пожалуйста, проверьте официальную ссылку на документацию выше.

3. Например, Sklearn хранит метки классов в объекте model model.classes_ , чтобы их можно было получить, просто загрузив обученную модель, не беспокоясь о загрузчике данных.

Ответ №1:

К сожалению, вы не можете получить имена меток классов непосредственно из моделей torchvision. Однако эти модели обучаются на наборе данных ImageNet (отсюда и 1000 классов).

Вы должны получить сопоставление имен классов из Интернета, насколько я знаю; нет способа получить его из torch. Ранее вы могли загружать ImageNet напрямую с помощью torchvision.datasets.ImageNet, в котором был встроенный конвертер меток в имена классов. Теперь ссылка для скачивания недоступна для общественности и требует ручной загрузки, прежде чем ее можно будет использовать datasets.ImageNet.

Таким образом, вы можете просто выполнить поиск класса для сопоставления меток в ImageNet в Интернете, вместо того, чтобы загружать данные или пытаться использовать torch. Попробуйте здесь, например.