#pytorch #matrix-multiplication #complex-numbers
#pytorch #матрица-умножение #комплексные числа
Вопрос:
Я пытаюсь умножить две сложные матрицы в PyTorch, и, похоже, функции torch.matmul еще не добавлены в библиотеку PyTorch для комплексных чисел.
Есть ли у вас какие-либо рекомендации или есть другой метод умножения сложных матриц в PyTorch?
Ответ №1:
В настоящее время torch.matmul
не поддерживается для сложных тензоров, таких как ComplexFloatTensor
но вы могли бы сделать что-то столь же компактное, как следующий код:
def matmul_complex(t1,t2):
return torch.view_as_complex(torch.stack((t1.real @ t2.real - t1.imag @ t2.imag, t1.real @ t2.imag t1.imag @ t2.real),dim=2))
По возможности избегайте использования циклов for, поскольку это приведет к гораздо более медленным реализациям.
Векторизация достигается с помощью встроенных методов, как показано в коде, который я приложил.
Например, ваш код занимает примерно 6,1 с на процессоре, в то время как векторизованная версия занимает всего 101 мс (~ в 60 раз быстрее) для 2 случайных комплексных матриц размером 1000 X 1000.
Обновить:
Начиная с PyTorch 1.7.0 (как упоминал @EduardoReis), вы можете выполнять матричное умножение между комплексными матрицами аналогично вещественным матрицам следующим образом:
t1 @ t2
(для t1
, t2
комплексных матриц).
Комментарии:
1. Недавно, используя torch
1.8.1 cu101
, я смог просто умножить два тензора наx*h
, и это дает их комплексное произведение.2. @EduardoReis Вы правы. Начиная с PyTorch 1.7.0, приведенный выше код можно сократить. Но обратите внимание, что
t1 * t2
это поточечное умножение между тензорамиt1
amp;t2
. Вы можете использоватьt1 @ t2
для получения матричного умножения, эквивалентногоmatmul_complex
. Я обновил сообщение.
Ответ №2:
Я реализовал эту функцию для pytorch.matmul для комплексных чисел, используя torch.mv и пока это работает нормально:
def matmul_complex(t1, t2):
m = list(t1.size())[0]
n = list(t2.size())[1]
t = torch.empty((1,n), dtype=torch.cfloat)
t_total = torch.empty((m,n), dtype=torch.cfloat)
for i in range(0,n):
if i == 0:
t_total = torch.mv(t1,t2[:,i])
else:
t_total = torch.cat((t_total, torch.mv(t1,t2[:,i])), 0)
t_final = torch.reshape(t_total, (m,n))
return t_final
Я новичок в PyTorch, поэтому, пожалуйста, поправьте меня, если я ошибаюсь.