#python-3.x #matrix #pytorch #linear-algebra
#python-3.x #матрица #pytorch #линейная алгебра
Вопрос:
Я пытаюсь умножить следующее:
Пакет матриц N x M x D
Пакет векторов N x D x 1
Чтобы получить результат: N x M x 1
как если бы я делал N
точечные произведения M x D
D x 1
.
Кажется, я не могу найти правильную функцию в PyTorch.
torch.bmm
насколько я могу судить, работает только для пакета векторов и одной матрицы. Если мне нужно использовать torch.einsum
, пусть будет так, но id скорее нет!
Ответ №1:
Это довольно просто и интуитивно понятно с einsum
:
torch.einsum('ijk, ikl->ijl', mats, vecs)
Но ваша операция просто:
mats @ vecs