Вычислить парное расстояние между карнизами пакета в PyTorch

#pytorch

#pytorch

Вопрос:

У меня есть два тензора, и оба имеют одинаковую форму. Я хочу рассчитать попарное расстояние между воронками, используя GeomLoss .

Что я пробовал:

 import torch
import geomloss  # pip install git https://github.com/jeanfeydy/geomloss

a = torch.rand((8,4))
b = torch.rand((8,4))

geomloss.SamplesLoss('sinkhorn')(a,b)
# ^ input shape [batch, feature_dim]
# will return a scalar value

geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(1),b.unsqueeze(1))  
# ^ input shape [batch, n_points, feature_dim]
# will return a tensor of size [batch] of distances between a[i] and b[i] for each i
 

Однако я хотел бы вычислить попарное расстояние, где результирующий тензор должен иметь размер [batch, batch] . Чтобы добиться этого, я попробовал использовать широковещательную передачу следующим образом:

 geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(0), b.unsqueeze(1))
 

Но я получил это сообщение об ошибке:

ValueError: выборки x и y должны иметь одинаковый размер пакета.

Комментарии:

1. Итак, вы хотели бы иметь 2D-тензор расстояний между a[i] и b[j] в каждой паре индексов пакета (i, j) , верно?

2. @Ivan Да, точно

Ответ №1:

Поскольку в документации не приводятся примеры того, как использовать функцию прямой передачи расстояния. Вот способ сделать это, который потребует от вас многократного вызова функции расстояния batch .

Мы построим матрицу расстояний построчно. Строка i соответствует расстояниям a[i]<->b[0] a[i]<->b[1] от , до a[i]<->b[batch] . Для этого нам нужно построить для каждой строки i (8x4) повторяющуюся версию тензора a[i] .

Это будет сделано:

 a_i = torch.stack(8*[a[i]], dim=0)
 

Затем мы вычисляем расстояние между a[i] и каждой партией в b :

 dist(a_i.unsqueeze(1), b.unsqueeze(1))
 

Имея общее количество batch строк, мы можем построить наш конечный тензор stack .


Вот полный код:

 batch = a.shape[0]
dist = geomloss.SamplesLoss('sinkhorn')
distances = [dist(torch.stack(batch*[a[i]]).unsqueeze(1), b.unsqueeze(1)) for i in range(batch)]
D = torch.stack(distances)
 

Комментарии:

1. Спасибо за ваше решение. Я также рассматривал решение на основе циклов в качестве опции. Ценю вашу помощь !.