#c #pytorch #libtorch
#c #pytorch #libtorch
Вопрос:
Когда я пытаюсь сгенерировать список переставленных целочисленных индексов с randperm
помощью C PyTorch API, результирующий тензор имеет тип элемента CPUFloatType{10}
вместо целочисленного типа:
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES);
cout << shuffled_indices << endl;
ВОЗВРАТ
9
3
8
6
2
5
4
7
1
0
[ CPUFloatType{10} ]
Который нельзя использовать для индексации тензоров, потому что тип элемента — float, а не целочисленный тип. При попытке использовать my_tensor.index(shuffled_indices)
я получаю
terminate called after throwing an instance of 'c10::IndexError'
what(): tensors used as indices must be long, byte or bool tensors
Окружающая среда:
- python-pytorch, версия 1.6.0-2 в Arch Linux
- g (GCC) 10.1.0
Почему это происходит?
Ответ №1:
Это потому, что тип по умолчанию любого тензора, который вы создаете с помощью torch, всегда float
. Если вы хотите иначе, вы должны указать это с TensorOptions
параметром struct :
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::TensorOptions().dtype(at::kLong));
cout << shuffled_indices.dtype() << endl;
>>> long