Как эффективно рассчитать парное расстояние между пакетами в PyTorch

#deep-learning #pytorch #tensor

#глубокое обучение #pytorch #тензор

Вопрос:

У меня есть тензоры X формы BxNxD и Y формы BxNxD .

Я хочу вычислить попарные расстояния для каждого элемента в пакете, т.е. я BxMxN тензор.

Как мне это сделать?

Здесь есть некоторое обсуждение на эту тему: https://github.com/pytorch/pytorch/issues/9406 , но я этого не понимаю, поскольку существует много деталей реализации, в то время как фактическое решение не выделено.

Наивным подходом было бы использовать ответ для парных расстояний без пакетирования, как обсуждалось здесь:https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065, то есть

 import torch
import numpy as np

B = 32
N = 128
M = 256
D = 3

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))


def pairwise_distances(x, y=None):
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)

    dist = x_norm   y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, 0.0, np.inf)


out = []
for b in range(B):
    out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)
  

Как я могу сделать это без зацикливания на B?
Спасибо

Ответ №1:

У меня была похожая проблема, и я потратил некоторое время, чтобы найти самое простое и быстрое решение. Теперь вы можете рассчитать пакетное расстояние с помощью PyTorch cdist, который выдаст вам BxMxN тензор:

 torch.cdist(Y, X)
  

Кроме того, это хорошо работает, если вы просто хотите вычислить расстояния между каждой парой строк из двух матриц.