#python #regression #cross-validation
#питон #регрессия #перекрестная проверка
Вопрос:
Как выполнить перекрестную проверку с помощью регрессионной модели GLM?
Я создал модель glm sm.GLM(endog, exog, family=sm.families.Gamma(link=sm.families.links.log())).fit()
, и мне нужно будет перепроверить результат, однако я не могу найти способ сделать это с sm.GLM
помощью модели. Найдено несколько примеров, где model = LogisticRegression()
используется, но это неприменимо к моим данным.
Вот код:
import pandas as pd import statsmodels.api as sm from sklearn.model_selection import train_test_split, cross_val_score from sklearn.model_selection import KFold Test = pd.read_csv(r'D:myfile.csv') endog = Test['Y'] exog = Test[['log_X1', 'log_A', 'log_B']] glm_model = sm.GLM(endog, exog, family=sm.families.Gaussian(link=sm.families.links.log())).fit() y_pred = glm_model.predict() scoring = "neg_root_mean_squared_error" X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=1) crossvalidation = KFold(n_splits=10) scores = cross_val_score(glm_model, X_train, y_train, scoring="neg mean_squared_error", cv=crossvalidation)
С конкретной строкой я получаю ошибку. Возможно, есть другие способы, как это сделать?
scores = cross_val_score(glm_model, X_train, y_train, scoring="neg mean_squared_error", cv=crossvalidation) TypeError: estimator should be an estimator implementing 'fit' method, lt;statsmodels.genmod.generalized_linear_model.GLMResultsWrapper object at 0x000002972A2181F0gt; was passed
Ответ №1:
Ответ-SMWrapper:
import statsmodels.api as sm from sklearn.base import BaseEstimator, RegressorMixin class SMWrapper(BaseEstimator, RegressorMixin): """ A universal sklearn-style wrapper for statsmodels regressors """ def __init__(self, model_class, fit_intercept=True): self.model_class = model_class self.fit_intercept = fit_intercept def fit(self, X, y): if self.fit_intercept: X = sm.add_constant(X) self.model_ = self.model_class(y, X) self.results_ = self.model_.fit() return self def predict(self, X): if self.fit_intercept: X = sm.add_constant(X) return self.results_.predict(X)