Может ли функция pred() классификатора sklearn knn принимать разреженную матрицу scipy в качестве входных данных?

#python #scikit-learn #scipy #sparse-matrix #knn

Вопрос:

Я работаю над большим набором данных, поэтому я хранил данные в разреженной матрице из SciPy (https://docs.scipy.org/doc/scipy/reference/sparse.html). Это занимает меньше места в памяти, чем при использовании массива numpy. Когда я использовал его в классификаторе KNN ScikitLearn, на этапе, когда функция pred() принимала разреженную матрицу в качестве входных данных, я получил следующую ошибку:

Ошибка оси: ось 1 выходит за пределы массива измерения 1

(Обратите внимание, что вам необходимо установить метрику= «предварительно вычисленный» в классификаторе knn, чтобы использовать разреженную матрицу.)

Однако, когда я изменил разреженную матрицу на массив numpy, это просто сработало. (Предположим, что разреженная матрица-это sp_mat, я просто изменил ее на sp_mat.toarray().) Это нормально использовать массив numpy, когда я пытался использовать часть данных во время отладки. Но со всем набором данных, который я использую, мне нужно будет использовать разреженную матрицу. Просто интересно, есть ли у кого-нибудь идеи, как правильно использовать разреженную матрицу в классификаторе knn.

Код:

 sparse_train = sparse_mat.tocsr()[0:num_train,:].tocsc()[:,0:num_train]  
sparse_test = sparse_mat.tocsr()[num_train:(num_train num_val),:].tocsc()[:,0:num_train]  
neigh_dist = KNeighborsClassifier(n_neighbors=nn, weights='distance', metric='precomputed')  
neigh_dist.fit(sparse_train, y_train)  
y_pred = neigh_dist.predict(sparse_test)
 

Ошибка:

 ---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/global/software/sl-7.x86_64/modules/langs/python/3.6/lib/python3.6/site-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     55     try:
---> 56         return getattr(obj, method)(*args, **kwds)
     57 

/global/software/sl-7.x86_64/modules/langs/python/3.6/lib/python3.6/site-packages/scipy/sparse/base.py in __getattr__(self, attr)
    646         else:
--> 647             raise AttributeError(attr   " not found")
    648 

AttributeError: argpartition not found

During handling of the above exception, another exception occurred:

AxisError                                 Traceback (most recent call last)
<ipython-input-37-31f5bd405101> in <module>()
----> 1 y_pred = neigh_dist.predict(sparse_test)

/global/software/sl-7.x86_64/modules/langs/python/3.6/lib/python3.6/site-packages/sklearn/neighbors/classification.py in predict(self, X)
    143         X = check_array(X, accept_sparse='csr')
    144 
--> 145         neigh_dist, neigh_ind = self.kneighbors(X)
    146 
    147         classes_ = self.classes_

/global/software/sl-7.x86_64/modules/langs/python/3.6/lib/python3.6/site-packages/sklearn/neighbors/base.py in kneighbors(self, X, n_neighbors, return_distance)
    361                     **self.effective_metric_params_)
    362 
--> 363             neigh_ind = np.argpartition(dist, n_neighbors - 1, axis=1)
    364             neigh_ind = neigh_ind[:, :n_neighbors]
    365             # argpartition doesn't guarantee sorted order, so we sort again

/global/software/sl-7.x86_64/modules/langs/python/3.6/lib/python3.6/site-packages/numpy/core/fromnumeric.py in argpartition(a, kth, axis, kind, order)
    806 
    807     """
--> 808     return _wrapfunc(a, 'argpartition', kth, axis=axis, kind=kind, order=order)
    809 
    810 

/global/software/sl-7.x86_64/modules/langs/python/3.6/lib/python3.6/site-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     64     # a downstream library like 'pandas'.
     65     except (AttributeError, TypeError):
---> 66         return _wrapit(obj, method, *args, **kwds)
     67 
     68 

/global/software/sl-7.x86_64/modules/langs/python/3.6/lib/python3.6/site-packages/numpy/core/fromnumeric.py in _wrapit(obj, method, *args, **kwds)
     44     except AttributeError:
     45         wrap = None
---> 46     result = getattr(asarray(obj), method)(*args, **kwds)
     47     if wrap:
     48         if not isinstance(result, mu.ndarray):

AxisError: axis 1 is out of bounds for array of dimension 1
 

Комментарии:

1. Можете ли вы привести минимальный пример кода, который не работает?

2. Текущий код, по-видимому, принимает csr разреженные входные данные формата. Попробуйте преобразовать и/или предоставить полную обратную трассировку ошибок.

3. @CJR Я только что добавил код и ошибку в сообщение. Спасибо!

4. @BenReiniger Я только что добавил код и ошибку в сообщение. Спасибо!

5. @CJR Ты прав! Я только что обновил свой scikit-learn с 0.19.1 до 0.24.2, и теперь он работает! Спасибо!