#python #matrix #confusion-matrix
Вопрос:
Я пытаюсь сохранить несколько экземпляров матрицы путаницы для каждой складки в 10-кратной перекрестной проверке. мой код выглядит так :
kf = KFold(n_splits=10, shuffle=True)
conf_matrix = np.zeros([3,3])
for train_index, test_index in kf.split(X):
X_train, X_test = X.iloc[train_index], X.iloc[test_index]
y_train, y_test = y.iloc[train_index], y.iloc[test_index]
model.fit(X_train, y_train)
pred = model.predict(X_test)
conf_matrix[train_index] = confusion_matrix(y_test, pred)
conf_matrix
но он возвращает эту ошибку
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-39-8a1e0e2223f7> in <module>()
9 model.fit(X_train, y_train)
10 pred = model.predict(X_test)
---> 11 conf_matrix[train_index] = confusion_matrix(y_test, pred)
12 # print(confusion_matrix(y_test, pred))
13
ValueError: shape mismatch: value array of shape (3,3) could not be broadcast to indexing result of shape (270,3)
что странно, потому что если я запущу print(confusion_matrix(y_test, pred))
выходную матрицу 3×3, которая выглядит так :
[[4 2 3]
[4 8 1]
[1 3 4]]
[[3 6 1]
[4 5 1]
[5 1 4]]
[[7 3 4]
[4 4 1]
[0 1 6]]
[[3 1 2]
[2 9 1]
[3 2 7]]
[[4 4 2]
[2 8 2]
[0 2 6]]
[[7 6 1]
[4 4 2]
[0 3 3]]
[[4 1 3]
[2 5 3]
[3 1 8]]
[[1 4 3]
[2 5 1]
[2 3 9]]
[[7 5 2]
[2 3 3]
[2 3 3]]
[[4 2 1]
[2 6 0]
[3 5 7]]
где я ошибся?
Правка 2 : Я попробовал предложение @user1424589 и изменил свой код вот так
conf_matrix = np.zeros((10,len(y_test),len(y_test))) #10 because I want it to caontain 10 confusion matrix
но он все равно возвращает ту же ошибку, теперь он возвращает эту ошибку
ValueError: shape mismatch: value array of shape (3,3) could not be broadcast to indexing result of shape (270,30,30)
Ответ №1:
Проблема в том, что то, что вы сказали ему сделать, — это назначить этот выходной массив в срез conf_matrix, который имеет форму (3). Вам нужно сформировать conf_matrix таким образом, чтобы он мог принимать результат confusion_matrix(y_test, pred).
Комментарии:
1. и как я могу это сделать? так как я получил значение только для (y_test, pred) внутри моего цикла for