#pytorch #libtorch
#pytorch #libtorch
Вопрос:
Почему torch::Tensor::is_same
не выполняется следующее утверждение? Тензор записывается в файл с использованием C PyTorch API, затем снова считывается в другой тензор и is_same
сравнивает оба тензора:
torch::Tensor x_sequence = torch::linspace(0, M_PI, 1000);
torch::save(x_sequence, "x_sequence.dat");
torch::Tensor x_read;
torch::load(x_read, "x_sequence.dat");
assert(x_read.is_same(x_sequence));
Это приводит к:
int main(int, char**): Assertion `x_read.is_same(x_sequence)' failed.
используя
- python-pytorch, версия 1.6.0-2 в Arch Linux
- & (GCC) 10.1.0
Ответ №1:
torch::Tensor::is_same(const torch::Tensoramp; other)
определяется здесь. Важно заметить, что a Tensor
на самом деле является указателем на базовый TensorImpl
класс (который фактически содержит данные).
Таким образом, при вызове is_same
проверяется, действительно ли ваши указатели совпадают, т. е. указывают ли ваши 2 тензора на одну и ту же базовую память. Вот очень простой пример, чтобы хорошо понять это :
auto x = torch::randn({4,4});
auto copy = x;
auto clone = x.clone();
std::cout << x.is_same(copy) << " " << x.is_same(clone) << std::endl;
&&t;&&t;&&t; 0 1
Здесь вызов clone
заставляет pytorch скопировать данные в другую ячейку памяти. Следовательно, указатели разные и is_same
возвращает false.
Если вы хотите на самом деле сравнить значения, у вас нет выбора, кроме как вычислить разницу между двумя тензорами и вычислить, насколько эта разница близка к 0.