Внешнее сложение по строкам

#python #numpy

Вопрос:

Даны два массива numpy: A формы (m, k) и B формы (m, n). Я хотел бы вычислить массив C формы (m, k, n), где каждая строка r из C содержит внешнее дополнение строки r из A и строки r из B. Я могу сделать это, используя цикл for следующим образом:

  import numpy as np

 A = np.array([[ 0,  1,  2,  3,  4],
               [ 5,  6,  7,  8,  9],
               [10, 11, 12, 13, 14]])

 B = np.array([[1., 1., 1., 1., 1., 1.],
               [2., 2., 2., 2., 2., 2.],
               [3., 3., 3., 3., 3., 3.]])
 A.shape
 Out[655]: (3, 5)

 B.shape
 Out[656]: (3, 6)

C = np.zeros((A.shape[0], A.shape[1],B.shape[1]))
for i in range(A.shape[0]):
    C[i] = A[i][:,None]  B[i]

C
Out[659]: 
array([[[ 1.,  1.,  1.,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  2.,  2.,  2.],
        [ 3.,  3.,  3.,  3.,  3.,  3.],
        [ 4.,  4.,  4.,  4.,  4.,  4.],
        [ 5.,  5.,  5.,  5.,  5.,  5.]],

       [[ 7.,  7.,  7.,  7.,  7.,  7.],
        [ 8.,  8.,  8.,  8.,  8.,  8.],
        [ 9.,  9.,  9.,  9.,  9.,  9.],
        [10., 10., 10., 10., 10., 10.],
        [11., 11., 11., 11., 11., 11.]],

       [[13., 13., 13., 13., 13., 13.],
        [14., 14., 14., 14., 14., 14.],
        [15., 15., 15., 15., 15., 15.],
        [16., 16., 16., 16., 16., 16.],
        [17., 17., 17., 17., 17., 17.]]])
 

Но есть ли способ векторизовать приведенный выше код, чтобы избавиться от цикла for?

Ответ №1:

Вы могли бы использовать трюки с трансляцией, A имеет форму (m, k) и B имеет форму (m, n) , вы хотите вставить измерение A и B противоположными способами, чтобы полученная форма была (m, k, 1) для одного и (m, 1, n) для другого. Затем применяющий оператор выполнит внешнюю операцию:

 >>> A[...,None]   B[:,None]
array([[[ 1.,  1.,  1.,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  2.,  2.,  2.],
        [ 3.,  3.,  3.,  3.,  3.,  3.],
        [ 4.,  4.,  4.,  4.,  4.,  4.],
        [ 5.,  5.,  5.,  5.,  5.,  5.]],

       [[ 7.,  7.,  7.,  7.,  7.,  7.],
        [ 8.,  8.,  8.,  8.,  8.,  8.],
        [ 9.,  9.,  9.,  9.,  9.,  9.],
        [10., 10., 10., 10., 10., 10.],
        [11., 11., 11., 11., 11., 11.]],

       [[13., 13., 13., 13., 13., 13.],
        [14., 14., 14., 14., 14., 14.],
        [15., 15., 15., 15., 15., 15.],
        [16., 16., 16., 16., 16., 16.],
        [17., 17., 17., 17., 17., 17.]]])