умножение многих матриц и многих векторов pytorch

#python-3.x #matrix #pytorch #linear-algebra

#python-3.x #матрица #pytorch #линейная алгебра

Вопрос:

Я пытаюсь умножить следующее:

Пакет матриц N x M x D
Пакет векторов N x D x 1
Чтобы получить результат: N x M x 1

как если бы я делал N точечные произведения M x D D x 1 .

Кажется, я не могу найти правильную функцию в PyTorch.

torch.bmm насколько я могу судить, работает только для пакета векторов и одной матрицы. Если мне нужно использовать torch.einsum , пусть будет так, но id скорее нет!

Ответ №1:

Это довольно просто и интуитивно понятно с einsum :

 torch.einsum('ijk, ikl->ijl', mats, vecs)
  

Но ваша операция просто:

 mats @ vecs