матричное умножение для комплексных чисел в PyTorch

#pytorch #matrix-multiplication #complex-numbers

#pytorch #матрица-умножение #комплексные числа

Вопрос:

Я пытаюсь умножить две сложные матрицы в PyTorch, и, похоже, функции torch.matmul еще не добавлены в библиотеку PyTorch для комплексных чисел.

Есть ли у вас какие-либо рекомендации или есть другой метод умножения сложных матриц в PyTorch?

Ответ №1:

В настоящее время torch.matmul не поддерживается для сложных тензоров, таких как ComplexFloatTensor но вы могли бы сделать что-то столь же компактное, как следующий код:

 def matmul_complex(t1,t2):
    return torch.view_as_complex(torch.stack((t1.real @ t2.real - t1.imag @ t2.imag, t1.real @ t2.imag   t1.imag @ t2.real),dim=2))
  

По возможности избегайте использования циклов for, поскольку это приведет к гораздо более медленным реализациям.
Векторизация достигается с помощью встроенных методов, как показано в коде, который я приложил.
Например, ваш код занимает примерно 6,1 с на процессоре, в то время как векторизованная версия занимает всего 101 мс (~ в 60 раз быстрее) для 2 случайных комплексных матриц размером 1000 X 1000.

Обновить:

Начиная с PyTorch 1.7.0 (как упоминал @EduardoReis), вы можете выполнять матричное умножение между комплексными матрицами аналогично вещественным матрицам следующим образом:

t1 @ t2 (для t1 , t2 комплексных матриц).

Комментарии:

1. Недавно, используя torch 1.8.1 cu101 , я смог просто умножить два тензора на x*h , и это дает их комплексное произведение.

2. @EduardoReis Вы правы. Начиная с PyTorch 1.7.0, приведенный выше код можно сократить. Но обратите внимание, что t1 * t2 это поточечное умножение между тензорами t1 amp; t2 . Вы можете использовать t1 @ t2 для получения матричного умножения, эквивалентного matmul_complex . Я обновил сообщение.

Ответ №2:

Я реализовал эту функцию для pytorch.matmul для комплексных чисел, используя torch.mv и пока это работает нормально:

 def matmul_complex(t1, t2):
  m = list(t1.size())[0]
  n = list(t2.size())[1]
  t = torch.empty((1,n), dtype=torch.cfloat)
  t_total = torch.empty((m,n), dtype=torch.cfloat)
  for i in range(0,n):
    if i == 0:
      t_total = torch.mv(t1,t2[:,i])
    else:
      t_total = torch.cat((t_total, torch.mv(t1,t2[:,i])), 0)
  t_final = torch.reshape(t_total, (m,n))
  return t_final
  

Я новичок в PyTorch, поэтому, пожалуйста, поправьте меня, если я ошибаюсь.