#pytorch
#pytorch
Вопрос:
У меня есть два тензора, и оба имеют одинаковую форму. Я хочу рассчитать попарное расстояние между воронками, используя GeomLoss
.
Что я пробовал:
import torch
import geomloss # pip install git https://github.com/jeanfeydy/geomloss
a = torch.rand((8,4))
b = torch.rand((8,4))
geomloss.SamplesLoss('sinkhorn')(a,b)
# ^ input shape [batch, feature_dim]
# will return a scalar value
geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(1),b.unsqueeze(1))
# ^ input shape [batch, n_points, feature_dim]
# will return a tensor of size [batch] of distances between a[i] and b[i] for each i
Однако я хотел бы вычислить попарное расстояние, где результирующий тензор должен иметь размер [batch, batch]
. Чтобы добиться этого, я попробовал использовать широковещательную передачу следующим образом:
geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(0), b.unsqueeze(1))
Но я получил это сообщение об ошибке:
ValueError: выборки
x
иy
должны иметь одинаковый размер пакета.
Комментарии:
1. Итак, вы хотели бы иметь 2D-тензор расстояний между
a[i]
иb[j]
в каждой паре индексов пакета(i, j)
, верно?2. @Ivan Да, точно
Ответ №1:
Поскольку в документации не приводятся примеры того, как использовать функцию прямой передачи расстояния. Вот способ сделать это, который потребует от вас многократного вызова функции расстояния batch
.
Мы построим матрицу расстояний построчно. Строка i
соответствует расстояниям a[i]<->b[0]
a[i]<->b[1]
от , до a[i]<->b[batch]
. Для этого нам нужно построить для каждой строки i
(8x4)
повторяющуюся версию тензора a[i]
.
Это будет сделано:
a_i = torch.stack(8*[a[i]], dim=0)
Затем мы вычисляем расстояние между a[i]
и каждой партией в b
:
dist(a_i.unsqueeze(1), b.unsqueeze(1))
Имея общее количество batch
строк, мы можем построить наш конечный тензор stack
.
Вот полный код:
batch = a.shape[0]
dist = geomloss.SamplesLoss('sinkhorn')
distances = [dist(torch.stack(batch*[a[i]]).unsqueeze(1), b.unsqueeze(1)) for i in range(batch)]
D = torch.stack(distances)
Комментарии:
1. Спасибо за ваше решение. Я также рассматривал решение на основе циклов в качестве опции. Ценю вашу помощь !.