#python #arrays #matlab #broadcasting
#python #массивы #matlab #вещание
Вопрос:
Я пытаюсь перенести некоторый код из MATLAB в python, и мне трудно понять, почему следующий код не работает.
import numpy as np
ngrid = 56
A = np.random.randint(10, size =(ngrid*ngrid,2))
A_tmp = A
B = np.random.randint(10,size =(ngrid*ngrid,2,2) )
for jj in range(ngrid*ngrid):
A[jj,:] = A_tmp[jj,:]*B[jj,:,:].conj()
Когда я выполняю этот код, я получаю сообщение об ошибке.
ValueError: could not broadcast input array from shape (2,2) into shape (2,)
Я не понимаю, почему это дает мне одномерный массив по сравнению с (ngrid*ngrid,2)
массивом.
Код MATLAB, который я пытаюсь воссоздать, является
for jj = 1:ngrid^2
Psi0(jj, :) = Psi0_tmp(jj, :)*dia2adi(:,:, jj)';
end
Любые рекомендации по теории и способам исправления моего кода были бы очень полезны.
Спасибо
Комментарии:
1. Это дает вам 1-мерный массив, потому что нарезка 1-мерного фрагмента массива numpy представляет собой 1-мерный массив. Сравните это с MATLAB, где одномерных массивов не существует: все, по крайней мере, 2d.
2. Мне уже слишком поздно пытаться придумать элегантное (читай: векторизованное) решение, но в худшем случае вы можете использовать что-то вроде
A = np.einsum('ja,jab -> jb', A_tmp, B.conj())
вместо всего цикла. Дайте или примите транспонирование, которое подразумевается простым оператором MATLAB.3. Я слышал об этом «волшебном» einsum. Спасибо за помощь, проблема решена.
4. Я думаю, вы не сможете выполнить это с помощью matmul, поэтому единственный другой вариант
A = (A_tmp[..., None] * B.conj()).sum(1)
. Вероятно, это будет быстрее, чем einsum, за счет увеличения объема памяти.5. Спасибо, я приму это во внимание, когда буду оптимизировать код
Ответ №1:
В numpy символ умножения всегда является поэлементным (как .*
в Matlab). Для умножения матриц используйте @
.
Кроме того, вы можете выполнять возведение в степень в Python с **
помощью .
Поэтому, если вы измените цикл на приведенный ниже код, это сработает.
for jj in range(ngrid**2):
A[jj,:] = A_tmp[jj,:] @ B[jj,:,:].conj()