#python #scikit-learn
#python #scikit-learn
Вопрос:
Я пытался создать составную оценку sklearn
; Теперь я обнаружил, что sklearn.compose.TransformedTargetRegressor
это делает именно то, чего я пытался достичь, но я все еще не могу его воспроизвести, и мне любопытно, почему.
Ошибка, которую я получаю :
AssertionError: Estimator TransformedSkModel should not change or mutate the parameter model from LinearRegression() to LinearRegression() during fit.
Мой код :
import numpy as np
from sklearn.base import BaseEstimator
class TransformedSkModel(BaseEstimator):
def __init__(self, model, transform_function, reverse_transform_function):
# TODO: ideally, verify, based on specified ranges, that the transform
# function and reverse transform function are compatible
self.model = model
self.transform_function = transform_function
self.reverse_transform_function = reverse_transform_function
def fit(self, X, y):
# Trying to reproduce as sklearn compatible estimator, we must use the
# conventions:
# The constructor in an estimator should only set attributes to the
# values the user passes as arguments. All computation should occur in
# fit, and if fit needs to store the result of a computation, it should
# do so in an attribute with a trailing underscore (_). This convention
# is what makes clone and meta-estimators such as GridSearchCV work.
self.vectorized_transform_function_ =
np.vectorize(self.transform_function)
self.vectorized_reverse_transform_function_ =
np.vectorize(self.reverse_transform_function)
y_transformed = self.vectorized_transform_function_(y)
self.model.fit(X, y_transformed)
return self
def predict(self, X):
y_transformed = self.model.predict(X)
y = self.vectorized_reverse_transform_function_(y_transformed)
return y
# def get_params(self, ):
# return self.model.get_params()
if __name__ == "__main__":
from sklearn.utils.estimator_checks import check_estimator
from sklearn.linear_model import LinearRegression
lm = LinearRegression()
id_func = lambda x:x
test = TransformedSkModel(lm, id_func, id_func)
check_estimator(test)
РЕДАКТИРОВАТЬ: я использую версию sklearn 0.24.0 и версию python 3.6.8
Комментарии:
1. какую версию sklearn вы используете?
2. Привет, @ctlr; Я использую версию sklearn 0.24.0 и версию python 3.6.8
3. Здесь это похоже на проблему, я опубликую ответ, если найду решение. по сути, сравниваются ваши установленные и исходные параметры модели, и объекты с установленной моделью изучили параметры, что вызывает другой хэш и выдает ошибку
4. Да, конечно, я тоже так понимаю. Я понимаю, что параметры не должны изменяться методом подгонки; но мой субоценщик должен быть. Должно ли оно быть помечено по-другому? Пока я думаю об этом, просмотр кода
sklearn.compose.TransformedTargetRegressor
может помочь, поскольку он делает то, что я хочу, без каких-либо проблем5. Здесь клонируется исходная оценка, а клонированная подгоняется и используется для прогнозирования. После этого я получаю некоторые ошибки из-за id_func
Ответ №1:
Итак, глядя на sklearn.compose.TransformedTargetRegressor
(который делает то, что я хочу), ключ, похоже, копирует мое model
использование sklearn.base.clone
и соответствует новому self.model_
(с подчеркиванием в trail в соответствии с соглашением). Таким образом, новый код для моего метода fit становится:
def fit(self, X, y):
# Trying to reproduce as sklearn compatible estimator, we must use the
# conventions:
# The constructor in an estimator should only set attributes to the
# values the user passes as arguments. All computation should occur in
# fit, and if fit needs to store the result of a computation, it should
# do so in an attribute with a trailing underscore (_). This convention
# is what makes clone and meta-estimators such as GridSearchCV work.
self.vectorized_transform_function_ =
np.vectorize(self.transform_function)
self.vectorized_reverse_transform_function_ =
np.vectorize(self.reverse_transform_function)
y_transformed = self.vectorized_transform_function_(y)
self.model_ = clone(self.model)
self.model_.fit(X, y_transformed)
Теперь я получаю другую ошибку, связанную с использованием np.vectorize
, но это еще одна проблема, которую, я думаю, можно решить в другом вопросе.