Пытаюсь напечатать имена классов для породы собак, но он продолжает говорить, что индекс списка находится вне диапазона

#python #deep-learning #pytorch

#python #глубокое обучение #pytorch

Вопрос:

Я использую модель resnet для классификации пород собак, но когда я пытаюсь распечатать изображение с меткой породы собаки, оно говорит, что индекс списка находится вне диапазона. Вот мой код:

 import torchvision.models as models
import torch.nn as nn


model_transfer = models.resnet18(pretrained=True)

if use_cuda:
    model_transfer = model_transfer.cuda()

model_transfer.fc.out_features = 133
  

Затем я тренирую модель и получаю более 70% точности для пород собак.

Тогда вот мой код для классификации dog и печати породы собак:

 data_transfer = {'train': 
 datasets.ImageFolder('/data/dog_images/train',transform=transforms.Compose([transforms.RandomResizedCrop(224),transforms.ToTensor()]))}
class_names[0]
class_names = [item[4:].replace("_", " ") for item in data_transfer['train'].classes]

def predict_breed_transfer(img_path):

    image = Image.open(img_path)

    # large images will slow down processing


    in_transform = transforms.Compose([
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])])

    # discard the transparent, alpha channel (that's the :3) and add the batch dimension
    image = in_transform(image)[:3,:,:].unsqueeze(0)

    image = image

    output = model_transfer(image)
    pred = torch.argmax(output)

    return class_names[pred]
    predict_breed_transfer('images/Labrador_retriever_06455.jpg')
  

По какой-то причине код всегда неправильно предсказывает собаку
Затем, когда я пытаюсь распечатать изображение и метку:

 import matplotlib.pyplot as plt
def run_app(img_path):
    img = Image.open(img_path)
    dog = dog_detector(img_path)
    if not dog: 
        print('hello, human!')
        plt.imshow(img)
        print('You look like a ... ')
        print(predict_breed_transfer(img_path))
    if dog: 
        print('hello, dog!')
        print('Your predicted breed is ....')
        print(predict_breed_transfer(img_path))
        plt.imshow(img)
    else: 
        print('Niether human nor dog')
  

И запустите цикл for, который вызывает его для некоторых изображений собак, он выведет некоторые породы, затем скажет, что индекс списка находится вне диапазона, и не покажет ни одно из изображений.

Длина class_names равна 133, и когда я распечатываю модель resnet, на выходе получается всего 133 узла. кто-нибудь знает, почему он говорит, что индекс списка находится вне диапазона или почему это так неточно.

 `IndexError                                Traceback (most recent 
call last)
<ipython-input-26-473a9ba884b5> in <module>()
      5 ## suggested code, below
      6 for file in np.hstack((human_files[:3], dog_files[:3])):
----> 7     run_app(file)
      8 
 <ipython-input-25-1d44200e44cc> in run_app(img_path)
      10         plt.show(img)
      11         print('You look like a ... ')
 ---> 12         print(predict_breed_transfer(img_path))
      13     if dog:
      14         print('hello, dog!')

 <ipython-input-20-a51fb205659e> in predict_breed_transfer(img_path)
      26     pred = torch.argmax(output)
      27 
 ---> 28     return class_names[pred]
      29 
predict_breed_transfer('images/Labrador_retriever_06455.jpg')
      30 

IndexError: list index out of range`
  

Вот полная ошибка

Комментарии:

1. вы пробовали отлаживать код? Я бы посоветовал вам добавить инструкции отладки и посмотреть, где происходит сбой кода. в вашем коде есть несколько мест, где вы обращаетесь к элементам через индексы массива, это первое место, где я бы написал свое заявление об отладке

2. В какой строке возникает ваша ошибка? Требуется полная трассировка ошибок.

3. Я поместил несколько инструкций print в цикл for. Для первых четырех изображений он выдает число меньше 133, но затем случайным образом просто выдает число, превышающее.

4. Я тоже каждый раз распечатывал функции out, и каждый раз там написано 133, но почему-то все равно выдает мне число поверх этого.

5. Я думаю, вы почти на месте. Проверьте форму вашего output и значение pred в pred = torch.argmax(output) .

Ответ №1:

Я полагаю, у вас есть несколько проблем, которые можно исправить с помощью 13 символов.

Во-первых, я предлагаю то, что предложил @Alekhya Vemavarapu — запустите свой код с помощью отладчика, чтобы изолировать каждую строку и проверить вывод. Это одно из величайших преимуществ динамических графиков с pytorch.

Во-вторых, наиболее вероятной причиной вашей проблемы является argmax оператор, который вы используете неправильно. Вы не указываете размер, для которого вы выполняете argmax , и поэтому PyTorch автоматически выравнивает изображение и выполняет операцию с вектором полной длины. Таким образом, вы получаете число между 0 и MB_Size x num_classes -1 . Смотрите Официальный документ об этом методе.

Итак, из-за вашего полностью подключенного слоя я предполагаю, что ваш вывод имеет форму (MB_Size, num_classes) . Если это так, вам нужно изменить свой код на следующую строку:

 pred = torch.argmax(output,dim=1)
  

и это все. В противном случае просто выберите размерность логитов.

Третья вещь, которую вы хотите учитывать, — это отсев и другие влияния, которые конфигурация обучения может вызвать на вывод. Например, выпадение в некоторых фреймворках может потребовать умножения вывода на 1/(1-p) в выводе (или нет, поскольку это может быть сделано во время обучения), нормализация пакета может быть отменена, поскольку размер пакета отличается, и так далее. Кроме того, чтобы уменьшить потребление памяти, не следует вычислять градиенты. К счастью, разработчики PyTorch очень внимательны и предоставили нам torch.no_grad() и model.eval() для этого.

Я настоятельно рекомендую попрактиковаться в этом, возможно, изменив ваш код несколькими буквами:

 output = model_transfer.eval()(image)
  

и готово!

Редактировать:
Это простой пример неправильного использования фреймворка PyTorch, не читающего документы и не отлаживающего ваш код. Следующий код абсолютно неверен:

 model_transfer.fc.out_features = 133
  

Эта строка фактически не создает новый полностью подключенный слой. Это просто изменяет свойство этого тензора. Попробуйте в своей консоли:

 import torch
a = torch.nn.Linear(1,2)
a.out_features = 3
print(a.bias.data.shape, a.weight.data.shape)
  

Вывод:

 torch.Size([2]) torch.Size([2, 1])
  

что указывает на то, что фактическая матрица весов и вектор смещений остаются в их исходном измерении.
Правильный способ выполнить обучение передаче — сохранить магистраль (обычно сверточные слои до тех пор, пока в моделях такого типа не появятся полностью подключенные) и перезаписать head (в данном случае слой FC) вашим. Если в исходной модели существует только один полностью связанный слой, вам не нужно изменять прямой проход вашей модели, и все готово.
Поскольку этот ответ уже достаточно длинный, просто посетите учебное пособие по переносу в PyTorch docs, чтобы увидеть, как это можно сделать.

Удачи вам.

Комментарии:

1. Я пробовал это, но он по-прежнему говорит, что индекс списка находится вне диапазона.

2. Можете ли вы предоставить более подробную информацию о том, что у вас есть? Например, что такое output.shape ? каков прогнозируемый номер класса?

3. Прогнозируемый номер класса равен 178, а форма вывода равна 1000, хотя выходных узлов должно быть 133

4. Итак, вот ваша проблема. Ваш model_transfer.fc.out_features = 133 фактически не создает новый полностью подключенный слой. Это просто изменяет свойство этого тензора. Попробуйте в своей консоли: import torch; a = torch.nn.Linear(1,2); a.out_features = 3; print(a.bias.data.shape, a.weights.data.shape) и вы увидите, что, хотя вы «изменили» out_features , фактические значения остаются. Это просто неправильный способ управления этими объектами, и мои глаза пропустили эту строку. Я добавлю это к ответу. Пожалуйста, одобрите это, поскольку я вижу, что вы были достаточно невежественны, чтобы каждый раз пинг-понговать меня подсказками.

5. Спасибо, mr_mo за ответ, мне жаль, что это была такая серьезная ошибка, я вроде как новичок в этом.