Простой способ сортировки матрицы на основе другой аналогичной матрицы

#python #arrays #numpy #sorting #matrix

Вопрос:

Допустим, у меня есть матрица Y случайных чисел с плавающей запятой от 0 до 10 с формой (10, 3) :

 import numpy as np
np.random.seed(99)
Y = np.random.uniform(0, 10, (10, 3))
print(Y)
 

Выход:

 [[6.72278559 4.88078399 8.25495174]
 [0.31446388 8.08049963 5.6561742 ]
 [2.97622499 0.46695721 9.90627399]
 [0.06825733 7.69793028 7.46767101]
 [3.77438936 4.94147452 9.28948392]
 [3.95454044 9.73956297 5.24414715]
 [0.93613093 8.13308413 2.11686786]
 [5.54345785 2.92269116 8.1614236 ]
 [8.28042566 2.21577372 6.44834702]
 [0.95181622 4.11663239 0.96865261]]
 

Теперь мне предоставлена матрица X той же формы, которую можно увидеть, добавив небольшие шумы в Y строки, а затем перетасовав их:

 X = np.random.normal(Y, scale=0.1)
np.random.shuffle(X)
print(X)
 

Выход:

 [[ 4.04067271  9.90959141  5.19126867]
 [ 5.59873104  2.84109306  8.11175891]
 [ 0.10743952  7.74620162  7.51100441]
 [ 3.60396019  4.91708372  9.07551354]
 [ 0.9400948   4.15448712  1.04187208]
 [ 2.91884302  0.47222752 10.12700505]
 [ 0.30995155  8.09263241  5.74876947]
 [ 1.11247872  8.02092335  1.99767444]
 [ 6.68543696  4.8345869   8.17330513]
 [ 8.38904822  2.11830619  6.42013343]]
 

Теперь я хочу отсортировать матрицу X Y по строкам. Я уже знаю, что каждая пара значений столбцов в каждой совпадающей паре строк не отличается друг от друга более чем на 0,5 допуска. Мне удалось написать следующий код, и он работает нормально.

 def sort_X_by_Y(X, Y, tol):
    idxs = [next(i for i in range(len(X)) if all(abs(X[i] - row) <= tol)) for row in Y]
    return X[idxs]

print(sort_X_by_Y(X, Y, tol=0.5))
 

Выход:

 [[ 6.68543696  4.8345869   8.17330513]
 [ 0.30995155  8.09263241  5.74876947]
 [ 2.91884302  0.47222752 10.12700505]
 [ 0.10743952  7.74620162  7.51100441]
 [ 3.60396019  4.91708372  9.07551354]
 [ 4.04067271  9.90959141  5.19126867]
 [ 1.11247872  8.02092335  1.99767444]
 [ 5.59873104  2.84109306  8.11175891]
 [ 8.38904822  2.11830619  6.42013343]
 [ 0.9400948   4.15448712  1.04187208]]
 

Однако на самом деле я сортирую (1000, 3) матрицы, и мой код слишком медленный. Я чувствую, что должен быть более дурацкий способ кодирования этого. Есть какие-нибудь предложения?

Комментарии:

1. Если вам не повезет, вы рискуете вычислить то же ix самое для двух разных строк!

2. @Stef Да, я знаю об этом. К счастью, в моем случае все строки сильно отличаются друг от друга.

Ответ №1:

Это векторизованная версия вашего алгоритма. Он работает примерно в 26,5 раза быстрее, чем ваша реализация для 1000 образцов. Но создается дополнительный логический массив с формой (1000,1000,3) . Существует вероятность того, что строки будут иметь одинаковые значения в пределах допуска, и будет выбрана неправильная строка.

 tol = .5
X[(np.abs(Y[:, np.newaxis] - X) <= tol).all(2).argmax(1)]
 

Выход

 array([[ 6.68543696,  4.8345869 ,  8.17330513],
       [ 0.30995155,  8.09263241,  5.74876947],
       [ 2.91884302,  0.47222752, 10.12700505],
       [ 0.10743952,  7.74620162,  7.51100441],
       [ 3.60396019,  4.91708372,  9.07551354],
       [ 4.04067271,  9.90959141,  5.19126867],
       [ 1.11247872,  8.02092335,  1.99767444],
       [ 5.59873104,  2.84109306,  8.11175891],
       [ 8.38904822,  2.11830619,  6.42013343],
       [ 0.9400948 ,  4.15448712,  1.04187208]])
 

Более надежные решения с L1-нормой

 X[np.abs(Y[:, np.newaxis] - X).sum(2).argmin(1)]
 

Или L2-норма

 X[((Y[:, np.newaxis] - X)**2).sum(2).argmin(1)]
 

Комментарии:

1. Риск заключается не только в том, что может быть выбрана «неправильная» строка, но и в том, что одна и та же строка может быть выбрана более одного раза.

2. Да, в этой реализации это будет строка с наименьшим индексом. То же ограничение, что и алгоритм операции.

3. Возможно, будет немного надежнее удалить параметр допуска и учитывать только сумму разницы для каждой строки: X[np.abs(Y[:, np.newaxis] - X).sum(2).argmin(1)]

4. @obchardon Это действительно отличное предложение

5. Нет, нет, это всего лишь небольшое улучшение, мы можем сохранить все в вашем ответе. L2-norm Решение, вероятно, является лучшим, так как оно сводит к минимуму количество ошибок. Также в данном конкретном случае np.sqrt() это довольно бесполезно.