Matlab: Argmax и скалярное произведение для каждой строки в матрице

#matlab #matrix #dot-product

#matlab #матрица #скалярное произведение

Вопрос:

У меня есть 2 матрицы = X in R^(n*m) и W in R^(k*m) где k<<n . Пусть x_i — i-я строка X и w_j — j-я строка W. Мне нужно найти для каждого x_i, какое j максимизирует <w_j,x_i>

Я не вижу способа обойти итерацию по всем строкам в X, но есть ли способ найти максимальное скалярное произведение без повторения каждый раз по всему W?

Наивная реализация была бы:

 n = 100;
m = 50;
k = 10;
X = rand(n,m);
W = rand(k,m);
Y = zeros(n, 1);

for i = 1 : n
  max_ind = 1;
  max_val = dot(W(1,:), X(i,:));
  for j = 2 : k
       cur_val = dot(W(j,:),X(i,:));

       if cur_val > max_val
          max_val = cur_val;
          max_ind = j;
       end

   end

   Y(i,:) = max_ind;
end
  

Комментарии:

1. Поделитесь с нами повторяющимся кодом? Может быть, также добавить примерный пример?

2. Я добавил наивную реализацию, которую я могу придумать

3. 1 для воспроизводимого примера

Ответ №1:

Скалярное произведение — это, по сути, умножение матрицы:

 [~, Y] = max(W*X');
  

Комментарии:

1. @Divakar Спасибо 🙂 Ваш bsxfun тоже хорош (уже 1) и занимает примерно столько же времени

2. О, я сомневаюсь, что это, особенно с большими наборами данных, bsxfun должно замедляться.

Ответ №2:

bsxfun основанный на этом подход к ускорению работы для вас —

 [~,Y] = max(sum(bsxfun(@times,X,permute(W,[3 2 1])),2),[],3)
  

В моей системе, используя ваш набор данных, я получаю 100x ускорение с этим.


Можно придумать еще два «близких» подхода, но они, похоже, не дают каких-либо значительных улучшений по сравнению с предыдущим —

 [~,Y] = max(squeeze(sum(bsxfun(@times,X,permute(W,[3 2 1])),2)),[],2)
  

и

 [~,Y] = max(squeeze(sum(bsxfun(@times,X',permute(W,[2 3 1]))))')