#matlab #matrix #dot-product
#matlab #матрица #скалярное произведение
Вопрос:
У меня есть 2 матрицы = X in R^(n*m)
и W in R^(k*m)
где k<<n
. Пусть x_i
— i-я строка X и w_j
— j-я строка W. Мне нужно найти для каждого x_i, какое j максимизирует <w_j,x_i>
Я не вижу способа обойти итерацию по всем строкам в X, но есть ли способ найти максимальное скалярное произведение без повторения каждый раз по всему W?
Наивная реализация была бы:
n = 100;
m = 50;
k = 10;
X = rand(n,m);
W = rand(k,m);
Y = zeros(n, 1);
for i = 1 : n
max_ind = 1;
max_val = dot(W(1,:), X(i,:));
for j = 2 : k
cur_val = dot(W(j,:),X(i,:));
if cur_val > max_val
max_val = cur_val;
max_ind = j;
end
end
Y(i,:) = max_ind;
end
Комментарии:
1. Поделитесь с нами повторяющимся кодом? Может быть, также добавить примерный пример?
2. Я добавил наивную реализацию, которую я могу придумать
3. 1 для воспроизводимого примера
Ответ №1:
Скалярное произведение — это, по сути, умножение матрицы:
[~, Y] = max(W*X');
Комментарии:
1. @Divakar Спасибо 🙂 Ваш
bsxfun
тоже хорош (уже 1) и занимает примерно столько же времени2. О, я сомневаюсь, что это, особенно с большими наборами данных,
bsxfun
должно замедляться.
Ответ №2:
bsxfun
основанный на этом подход к ускорению работы для вас —
[~,Y] = max(sum(bsxfun(@times,X,permute(W,[3 2 1])),2),[],3)
В моей системе, используя ваш набор данных, я получаю 100x
ускорение с этим.
Можно придумать еще два «близких» подхода, но они, похоже, не дают каких-либо значительных улучшений по сравнению с предыдущим —
[~,Y] = max(squeeze(sum(bsxfun(@times,X,permute(W,[3 2 1])),2)),[],2)
и
[~,Y] = max(squeeze(sum(bsxfun(@times,X',permute(W,[2 3 1]))))')