#r #r-caret
#r #r-каретка
Вопрос:
Я пытаюсь понять, как создать матрицу путаницы после того, как я использую функцию glm для логистической регрессии. Вот мой код на данный момент. Я использую пакет каретки и функцию confusionMatrix.
dput(head(wine_quality))
structure(list(fixed.acidity = c(7, 6.3, 8.1, 7.2, 7.2, 8.1),
volatile.acidity = c(0.27, 0.3, 0.28, 0.23, 0.23, 0.28),
citric.acid = c(0.36, 0.34, 0.4, 0.32, 0.32, 0.4), residual.sugar = c(20.7,
1.6, 6.9, 8.5, 8.5, 6.9), chlorides = c(0.045, 0.049, 0.05,
0.058, 0.058, 0.05), free.sulfur.dioxide = c(45, 14, 30,
47, 47, 30), total.sulfur.dioxide = c(170, 132, 97, 186,
186, 97), density = c(1.001, 0.994, 0.9951, 0.9956, 0.9956,
0.9951), pH = c(3, 3.3, 3.26, 3.19, 3.19, 3.26), sulphates = c(0.45,
0.49, 0.44, 0.4, 0.4, 0.44), alcohol = c(8.8, 9.5, 10.1,
9.9, 9.9, 10.1), quality = structure(c(4L, 4L, 4L, 4L, 4L,
4L), .Label = c("3", "4", "5", "6", "7", "8", "9", "white"
), class = "factor"), type = structure(c(3L, 3L, 3L, 3L,
3L, 3L), .Label = c("", "red", "white"), class = "factor"),
numeric_type = c(0, 0, 0, 0, 0, 0)), row.names = c(NA, 6L
), class = "data.frame")
library(tibble)
library(broom)
library(ggplot2)
library(caret)
any(is.na(wine_quality)) # this evaulates to FALSE
wine_model <- glm(type ~ fixed.acidity volatile.acidity citric.acid residual.sugar chlorides free.sulfur.dioxide total.sulfur.dioxide density pH sulphates alcohol, wine_quality, family = "binomial")
# split data into test and train
smp_size <- floor(0.75 * nrow(wine_quality))
set.seed(123)
train_ind <- sample(seq_len(nrow(wine_quality)), size = smp_size)
train <- wine_quality[train_ind, ]
test <- wine_quality[-train_ind, ]
# make prediction on train data
pred <- predict(wine_model)
train$fixed.acidity <- as.numeric(train$fixed.acidity)
round(train$fixed.acidity)
train$fixed.acidity <- as.factor(train$fixed.acidity)
pred <- as.numeric(pred)
round(pred)
pred <- as.factor(pred)
confusionMatrix(pred, wine_quality$fixed.acidity)
После этой последней строки кода я получаю эту ошибку:
Error: `data` and `reference` should be factors with the same levels.
Эта ошибка не имеет смысла для меня. Я проверил, что длина pred и длина fixed.acidity одинаковы (6497), а также они оба являются типом данных фактора.
length(pred)
length(wine_quality$fixed.acidity)
class(pred)
class(train$fixed.acidity)
Есть ли какая-либо очевидная причина, по которой эта матрица путаницы не работает? Я пытаюсь найти коэффициент соответствия для модели. Я был бы признателен за фиктивные объяснения, я действительно не знаю, что я здесь делаю.
Ответ №1:
Ошибка от confusionMatrix()
сообщает нам, что две переменные, переданные функции, должны быть факторами с одинаковыми значениями. Мы можем понять, почему мы получили ошибку при запуске str()
с обеими переменными.
> str(pred)
Factor w/ 5318 levels "-23.6495182533792",..: 310 339 419 1105 310 353 1062 942 594 1272 ...
> str(wine_quality$fixed.acidity)
num [1:6497] 7.4 7.8 7.8 11.2 7.4 7.4 7.9 7.3 7.8 7.5 ...
pred
является фактором, когда wine_quality$fixed_acidity
является числовым вектором. confusionMatrix()
Функция используется для сравнения прогнозируемых и фактических значений зависимой переменной. Он не предназначен для перекрестного табулирования прогнозируемой переменной и независимой переменной.
Код в вопросе используется fixed.acidity
в матрице путаницы, когда он должен сравнивать прогнозируемые значения type
с фактическими значениями type
из данных тестирования.
Кроме того, код в вопросе создает модель перед разделением данных на тестовые и обучающие данные. Правильная процедура заключается в разделении данных перед построением модели на основе обучающих данных, создании прогнозов с использованием данных тестирования (сдерживания) и сравнении фактических данных с прогнозами в данных тестирования.
Наконец, результатом predict()
функции, закодированной в исходном сообщении, являются линейные прогнозируемые значения из модели GLM (эквивалентные wine_model$linear.predictors
в объекте выходной модели). Эти значения должны быть дополнительно преобразованы, чтобы сделать их подходящими перед использованием confusionMatrix()
.
На практике его проще использовать caret::train()
с методом GLM и биномиальным семейством, где predict()
будут генерироваться результаты, которые можно использовать confusionMatrix()
. Мы проиллюстрируем это данными UCI о качестве вина.
Сначала мы загружаем данные из репозитория машинного обучения UCI, чтобы сделать пример воспроизводимым.
download.file("https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv",
"./data/wine_quality_red.csv")
download.file("https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv",
"./data/wine_quality_white.csv")
Во-вторых, мы загружаем данные, присваиваем type
им значение красного или белого цвета в зависимости от файла данных и связываем данные в единый фрейм данных.
red <- read.csv("./data/wine_quality_red.csv",header = TRUE,sep=";")
white <- read.csv("./data/wine_quality_white.csv",header = TRUE,sep=";")
red$type <- "red"
white$type <- "white"
wine_quality <- rbind(red,white)
wine_quality$type <- factor(wine_quality$type)
Затем мы разделяем данные на тестовые и обучающие на основе значений type
, поэтому каждый фрейм данных получает пропорциональное количество красных и белых вин, обучаем данные caret::train()
настройкам по умолчанию и методу GLM.
library(caret)
set.seed(123)
inTrain <- createDataPartition(wine_quality$type, p = 3/4)[[1]]
training <- wine_quality[ inTrain,]
testing <- wine_quality[-inTrain,]
aModel <- train(type ~ .,data = training, method="glm", familia's = "binomial")
Наконец, мы используем модель для прогнозирования фрейма данных с задержкой и запускаем матрицу путаницы.
testLM <- predict(aModel,testing)
confusionMatrix(data=testLM,reference=testing$type)
… и вывод:
> confusionMatrix(data=testLM,reference=testing$type)
Confusion Matrix and Statistics
Reference
Prediction red white
red 393 3
white 6 1221
Accuracy : 0.9945
95% CI : (0.9895, 0.9975)
No Information Rate : 0.7542
P-Value [Acc > NIR] : <2e-16
Kappa : 0.985
Mcnemar's Test P-Value : 0.505
Sensitivity : 0.9850
Specificity : 0.9975
Pos Pred Value : 0.9924
Neg Pred Value : 0.9951
Prevalence : 0.2458
Detection Rate : 0.2421
Detection Prevalence : 0.2440
Balanced Accuracy : 0.9913
'Positive' Class : red