Почему `is_same` в C PyTorch API завершается ошибкой при сравнении с тем же тензором, который считывается из файла?

#pytorch #libtorch

#pytorch #libtorch

Вопрос:

Почему torch::Tensor::is_same не выполняется следующее утверждение? Тензор записывается в файл с использованием C PyTorch API, затем снова считывается в другой тензор и is_same сравнивает оба тензора:

 torch::Tensor x_sequence = torch::linspace(0, M_PI, 1000);    
torch::save(x_sequence, "x_sequence.dat");
torch::Tensor x_read;
torch::load(x_read, "x_sequence.dat");
assert(x_read.is_same(x_sequence));  
  

Это приводит к:

 int main(int, char**): Assertion `x_read.is_same(x_sequence)' failed.
  

используя

  • python-pytorch, версия 1.6.0-2 в Arch Linux
  • & (GCC) 10.1.0

Ответ №1:

torch::Tensor::is_same(const torch::Tensoramp; other) определяется здесь. Важно заметить, что a Tensor на самом деле является указателем на базовый TensorImpl класс (который фактически содержит данные).

Таким образом, при вызове is_same проверяется, действительно ли ваши указатели совпадают, т. е. указывают ли ваши 2 тензора на одну и ту же базовую память. Вот очень простой пример, чтобы хорошо понять это :

 auto x = torch::randn({4,4});
auto copy = x;
auto clone = x.clone();
std::cout << x.is_same(copy) << " " << x.is_same(clone) << std::endl;
&&t;&&t;&&t; 0 1
  

Здесь вызов clone заставляет pytorch скопировать данные в другую ячейку памяти. Следовательно, указатели разные и is_same возвращает false.

Если вы хотите на самом деле сравнить значения, у вас нет выбора, кроме как вычислить разницу между двумя тензорами и вычислить, насколько эта разница близка к 0.