#python #machine-learning #scikit-learn
#питон #машинное обучение #scikit-учись
Вопрос:
Я пытаюсь выполнить поиск по сетке с помощью scikit-learn и GridSearchCV. Для перекрестной проверки я разделяю свои данные с помощью StratifiedShuffleSplit.
Если я использую такой подход, как Случайный лес, все работает хорошо. Однако при распространении этикеток или распространении этикеток я сталкиваюсь с проблемой подсчета очков. Это связано с тем, что некоторые метки в моих обучающих данных необходимо преобразовать в класс без меток (-1).
Допустим, мои обучающие данные состоят из 6 выборок с метками 0 или 1.
labels = [0, 0, 0, 1, 1, 1]
Я притворяюсь, что некоторые из них не помечены во время обучения, преобразуя их в -1
labels_w_some_unlabelled = [0, 0, -1, 1, 1, -1]
Используя StratifiedShuffleSplit, предположим, что мои данные разделены в первом сгибе на следующие индексы:
train_indices = [0, 1, 3] test_indices = [2, 4, 5] test_labels = [-1, 1, -1]
Когда я запускаю поиск по сетке и мои данные оцениваются, основная истина для меток содержит классы без меток
cv = StratifiedShuffleSplit() grid_search = GridSearchCV( estimator='LabelPropagation', cv=cv, ... ) grid_search.fit(training_data, test_labels)
Любая из тестовых меток с классом без метки не может получить правильное предсказание, потому что их «основная истина» для функции оценки будет классом -1/без метки.
Я хочу создать пользовательскую функцию подсчета очков и передать ее GridSearchCV
, но, похоже, индексы, используемые в каждом сгибе, не передаются в эту функцию, поэтому я не могу получить правильные метки для чего-либо в наборе тестов.
def custom_scoring_function(estimator, X, y): ''' X has 3 samples for the test data. y is the ground truth, but contains -1 for unlabelled data and isn't the ground truth I want. ''' groud_truth = [] # extract correct labels using indices of test data somehow y_pred = estimator.predict(X) return accuracy_score(ground_truth, y_pred) cv = StratifiedShuffleSplit() grid_search = GridSearchCV( estimator='LabelPropagation', cv=cv, scoring=custom_scoring_function ... ) grid_search.fit(training_data, test_labels)
Я мог бы просто создать свою собственную функцию поиска по сетке, но подумал, что может быть другой способ.