#python #matplotlib #machine-learning #scikit-learn #logistic-regression
Вопрос:
У меня есть блок кода, который пытается создать модель логистической регрессии и построить пути регуляризации модели. Он просто извлекает некоторые разреженные данные, преобразует их в наборы для обучения и тестирования, а затем пытается построить пути регуляризации.
Я думаю, что вся моя логика верна, однако я никогда не оставляю вызов matplotlib для сюжета plt.plot(np.log10(min_c), LogRegCoefs)
. Кажется, что это длится вечно и никогда не возвращается.
Есть ли где-нибудь ошибка в моей математике? Есть ли способ сделать это, который действительно сработал бы?
from sklearn.svm import l1_min_c
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.model_selection import train_test_split
x, y = fetch_20newsgroups_vectorized(subset='all', return_X_y=True)
#ignore sklearn's warning for a nicer output
def warn(*args, **kwargs):
pass
import warnings
warnings.warn = warn
#regrab our same dataset from above
x_train, x_test, y_train, y_test = train_test_split(x, y,test_size=.1, random_state=42,stratify=y)
#get lowest bound for c, returns a list
#default argument is L2, must specify log
min_c = l1_min_c(x_train, y_train, loss='log') * np.logspace(0, 3)
logRegression = LogisticRegression(max_iter=1,random_state=42, solver="saga",multi_class="multinomial",penalty='l1')
#loop through c values and retrieve coefficients for each
LogRegCoefs = []
for i in range(len(min_c)):
currentC = min_c[i]
# set this C value, then refit and copy the parameters
logRegression.set_params(C=currentC)
# fit to current data
logRegression.fit(x_train, y_train)
#after data is fit, get numpy data array with coefficients
coefficients = logRegression.coef_
#store in coefficients array
LogRegCoefs.append(coefficients.ravel().copy())
# generate a plot
LogRegCoefs = np.array(LogRegCoefs)
plt.plot(np.log10(min_c), LogRegCoefs)
ymin, ymax = plt.ylim()
plt.xlabel('log(C)')
plt.ylabel('Coefficients')
plt.title('Logistic Regression Path')
plt.axis('tight')
plt.show()