#gridsearchcv #svc
Вопрос:
Я пытаюсь настроить гиперпараметры для классификатора SVM и использую набор данных о раке молочной железы в Висконсине, доступный по адресу: https://archive.ics.uci.edu/ml/datasets/breast рак висконсин (оригинальный)
Мой код следующий:
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.svm import SVC
import pandas as pd
from sklearn.model_selection import GridSearchCV
def load():
data2=pd.read_csv("breast-cancer-wisconsin.data",names=['id', 'clump_thickness','unif_cell_size','unif_cell_shape', 'marg_adhesion', 'single_epith_cell_size','bare_nuclei', 'bland_chromatin', 'normal_nucleoli','mitoses','class'])
y=data2["class"]
listDrop=['id','class']
data2=data2.drop(listDrop,axis="columns")
data2.replace('?',np.nan,inplace=True)
imp=SimpleImputer(missing_values=np.NaN,strategy='mean')
data2=pd.DataFrame(imp.fit_transform(data2))
X=data2
svm(X,y)
def svm(X,y):
X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2)
kernels=["rbf"]
c=[-5,1,5]
gamma=[1, 0.1, 0.01, 0.001]
model=SVC()
grid=GridSearchCV(model,param_grid={'kernel':kernels,'C':c,'gamma':gamma},cv=5)
grid.fit(X_train,y_train)
print("best params",grid.best_params_)
Код выполняется, и я получаю наилучшие гиперпараметры; однако у меня есть следующие предупреждения:
FitFailedWarning: Estimator fit failed. The score on this train-test partition for these parameters will be set to nan.
warnings.warn("Estimator fit failed. The score on this train-test"
Я проверил в других сообщениях здесь, и там говорится, что это может произойти, например, при использовании текстовых данных, когда могут быть нечисловые данные. Я проверил это с помощью: print(X_train.info())
и я вижу, что все мои данные являются числовыми:
Int64Index: 559 entries, 253 to 446
Data columns (total 9 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 0 559 non-null float64
1 1 559 non-null float64
2 2 559 non-null float64
3 3 559 non-null float64
4 4 559 non-null float64
5 5 559 non-null float64
6 6 559 non-null float64
7 7 559 non-null float64
8 8 559 non-null float64
dtypes: float64(9)
Что я делаю не так?