Взятие подмножества данных, принадлежащих определенным классам

#python #machine-learning #pytorch

Вопрос:

Я пытаюсь взять только два класса, например: кошки и собаки из набора данных CIFAR (который загружается как набор поездов). Я пытаюсь использовать для этого следующий код:

 def getIndices(d_targets, idx):
    lst=[]
    for j in idx:
        for (i,index) in enumerate(d_targets):
            if (index == j):
                lst.append(i)
    return lst    
labels_to_select = [3,5] #cat vs dog
trainset_subset_labels = getIndices(trainset.targets,labels_to_select)
trainset_2 = torch.utils.data.Subset(trainset,trainset_subset_labels)
trainloader = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size,shuffle=True, 
num_workers=2)
 

форма комплекта поездов -> (50000, 32,32,3)

желаемая форма trainset_2 -> (10000,32,32,3)

форма trainset_2, которую я получаю -> (50000,32,32,3)

Я должен был бы получить меньший набор данных в обоих trainset_2, но этого не происходит. Есть идеи, что я делаю не так.

Комментарии:

1. Этот код работает нормально. Проверьте еще раз. len(trainset), len(trainset_2) дает (50000, 10000)