Попытка получить доступ к подмножеству набора данных mnist в pytorch [равные выборки из каждого класса]

#python #arrays #numpy #pytorch #mnist

#python #массивы #numpy #pytorch #mnist

Вопрос:

Пытаюсь получить доступ к подмножеству набора данных mnist в pytorch [равные выборки из каждого класса], но получаю эту ошибку

 prng = RandomState(42)
random_permute = prng.permutation(np.arange(0, 6000))[0:3000]
indx = np.concatenate([np.where(np.array(mnist_data.targets) == classe)[0][random_permute] for classe in range(0,10)])
  
 ---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-178-038015f76b77> in <module>
----> 1 indx = np.concatenate([np.where(np.array(mnist_data.targets) == classe)[0][random_permute] for classe in range(0,10)])

<ipython-input-178-038015f76b77> in <listcomp>(.0)
----> 1 indx = np.concatenate([np.where(np.array(mnist_data.targets) == classe)[0][random_permute] for classe in range(0,10)])

IndexError: index 5992 is out of bounds for axis 0 with size 5923
  

Ответ №1:

Набор данных MNIST не имеет равномерного распределения целей. Вы получаете эту ошибку, потому что класс 0 в MNIST содержит 5923 выборки.

 nums = [0]*10
for i in range(60000):
  nums[(int(mnist_data.targets[i]))]  = 1
print(nums)
  

Это будет напечатано [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949] .