Как построить матрицу путаницы при перекрестной проверке с k-NN в R

#r #r-caret #nearest-neighbor #rpart

#r #r-каретка #ближайший сосед #rpart

Вопрос:

Я пытался задать этот вопрос на другом форуме, но, не получив ответа, я переформулирую его здесь и немного конкретизирую свою проблему. У меня есть набор данных, который выглядит следующим образом:

 > head(knnresults)
   ACTIVITY_X ACTIVITY_Y ACTIVITY_Z classification
1:         40         47         62        Feeding
2:         60         74         95       Standing
3:         62         63         88       Standing
4:         60         56         82       Standing
5:         66         61         90       Standing
6:         60         53         80       Standing
  

Со столбцом classification , имеющим три разные категории Feeding , Standing и Foraging .

Сейчас я выбираю оптимальное k значение, поэтому я классифицирую 20% данных, используя остальные 80% в качестве обучения. Классификация основана на значениях в первых трех столбцах. k Значение, показывающее наивысшую точность, будет выбрано для будущего анализа классификации.

Вот сценарий, который я использовал для этого вопроса:

 library(ISLR)
library(caret)
library(lattice)
library(ggplot2)

# Split the data for cross validation:
indxTrain <- createDataPartition(y = knnresults$classification,p = 0.8,list = FALSE)
training <- knnresults[indxTrain,]
testing <- knnresults[-indxTrain,]

# Run k-NN:
set.seed(400)
ctrl <- trainControl(method="repeatedcv",repeats = 3)
knnFit <- train(classification ~ ., data = training, method = "knn", trControl = ctrl, preProcess = c("center","scale"),tuneLength = 20)
knnFit

#Plotting different k values against accuracy (based on repeated cross validation)
plot(knnFit)
  

Во-первых, позвольте мне извиниться, поскольку я новичок в R, и я не уверен в законности этого сценария. Я буду очень рад принять любое предложение об исправлении в случае обнаружения ошибок.

Во-вторых, как мне получить доступ к матрице путаницы классификации на основе этого кода? Это важно для вычисления показателей производительности, связанных с классификацией.

Я могу dput() использовать свой набор данных ниже, если это может помочь:

 > dput(knnresults)
structure(list(ACTIVITY_X = c(40L, 60L, 62L, 60L, 66L, 60L, 57L, 
54L, 52L, 93L, 80L, 14L, 61L, 51L, 40L, 20L, 21L, 5L, 53L, 48L, 
73L, 73L, 21L, 29L, 63L, 59L, 57L, 51L, 53L, 67L, 72L, 74L, 70L, 
60L, 74L, 85L, 77L, 68L, 58L, 80L, 34L, 45L, 34L, 60L, 75L, 62L, 
66L, 51L, 53L, 48L, 62L, 62L, 57L, 5L, 1L, 12L, 23L, 5L, 4L, 
0L, 13L, 45L, 44L, 31L, 68L, 88L, 43L, 70L, 18L, 83L, 71L, 67L, 
75L, 74L, 49L, 90L, 44L, 64L, 57L, 22L, 29L, 52L, 37L, 32L, 120L, 
45L, 22L, 54L, 30L, 9L, 27L, 14L, 3L, 29L, 12L, 61L, 60L, 29L, 
15L, 7L, 6L, 0L, 2L, 0L, 4L, 1L, 7L, 0L, 0L, 0L, 0L, 0L, 1L, 
23L, 49L, 46L, 8L, 31L, 45L, 60L, 37L, 61L, 52L, 51L, 38L, 86L, 
60L, 41L, 43L, 40L, 42L, 42L, 48L, 64L, 71L, 59L, 0L, 27L, 12L, 
3L, 0L, 0L, 8L, 21L, 6L, 2L, 7L, 4L, 3L, 3L, 46L, 46L, 59L, 53L, 
37L, 44L, 39L, 49L, 37L, 47L, 17L, 36L, 32L, 33L, 26L, 12L, 8L, 
31L, 35L, 27L, 27L, 24L, 17L, 35L, 39L, 28L, 54L, 5L, 0L, 0L, 
0L, 0L, 17L, 22L, 25L, 12L, 0L, 5L, 41L, 51L, 66L, 39L, 32L, 
53L, 43L, 40L, 44L, 45L, 48L, 51L, 41L, 45L, 39L, 46L, 59L, 31L, 
5L, 24L, 18L, 5L, 15L, 13L, 0L, 26L, 0L), ACTIVITY_Y = c(47L, 
74L, 63L, 56L, 61L, 53L, 40L, 41L, 49L, 32L, 54L, 13L, 99L, 130L, 
38L, 14L, 6L, 5L, 94L, 96L, 38L, 43L, 29L, 47L, 66L, 47L, 38L, 
31L, 36L, 35L, 38L, 72L, 54L, 44L, 45L, 51L, 80L, 48L, 39L, 85L, 
42L, 39L, 37L, 75L, 36L, 45L, 32L, 35L, 41L, 26L, 99L, 163L, 
124L, 0L, 0L, 24L, 37L, 0L, 6L, 0L, 29L, 29L, 26L, 27L, 54L, 
147L, 82L, 98L, 12L, 83L, 97L, 104L, 128L, 81L, 42L, 102L, 60L, 
79L, 58L, 15L, 14L, 75L, 75L, 40L, 130L, 40L, 13L, 54L, 42L, 
7L, 10L, 3L, 0L, 15L, 8L, 75L, 55L, 26L, 18L, 1L, 13L, 0L, 0L, 
0L, 1L, 0L, 4L, 0L, 0L, 0L, 0L, 0L, 0L, 17L, 45L, 38L, 10L, 31L, 
52L, 36L, 65L, 97L, 45L, 59L, 49L, 92L, 51L, 34L, 21L, 20L, 29L, 
28L, 22L, 32L, 30L, 86L, 0L, 15L, 7L, 4L, 0L, 0L, 0L, 11L, 3L, 
0L, 1L, 3L, 1L, 0L, 72L, 62L, 98L, 55L, 26L, 39L, 28L, 81L, 20L, 
52L, 12L, 48L, 24L, 40L, 30L, 5L, 6L, 40L, 37L, 33L, 26L, 17L, 
14L, 39L, 27L, 28L, 67L, 0L, 0L, 0L, 0L, 0L, 10L, 12L, 14L, 7L, 
0L, 2L, 39L, 67L, 74L, 28L, 23L, 57L, 34L, 36L, 36L, 37L, 46L, 
43L, 73L, 65L, 31L, 64L, 128L, 17L, 3L, 12L, 17L, 0L, 9L, 7L, 
0L, 17L, 0L), ACTIVITY_Z = c(62L, 95L, 88L, 82L, 90L, 80L, 70L, 
68L, 71L, 98L, 97L, 19L, 116L, 140L, 55L, 24L, 22L, 7L, 108L, 
107L, 82L, 85L, 36L, 55L, 91L, 75L, 69L, 60L, 64L, 76L, 81L, 
103L, 88L, 74L, 87L, 99L, 111L, 83L, 70L, 117L, 54L, 60L, 50L, 
96L, 83L, 77L, 73L, 62L, 67L, 55L, 117L, 174L, 136L, 5L, 1L, 
27L, 44L, 5L, 7L, 0L, 32L, 54L, 51L, 41L, 87L, 171L, 93L, 120L, 
22L, 117L, 120L, 124L, 148L, 110L, 65L, 136L, 74L, 102L, 81L, 
27L, 32L, 91L, 84L, 51L, 177L, 60L, 26L, 76L, 52L, 11L, 29L, 
14L, 3L, 33L, 14L, 97L, 81L, 39L, 23L, 7L, 14L, 0L, 2L, 0L, 4L, 
1L, 8L, 0L, 0L, 0L, 0L, 0L, 1L, 29L, 67L, 60L, 13L, 44L, 69L, 
70L, 75L, 115L, 69L, 78L, 62L, 126L, 79L, 53L, 48L, 45L, 51L, 
50L, 53L, 72L, 77L, 104L, 0L, 31L, 14L, 5L, 0L, 0L, 8L, 24L, 
7L, 2L, 7L, 5L, 3L, 3L, 85L, 77L, 114L, 76L, 45L, 59L, 48L, 95L, 
42L, 70L, 21L, 60L, 40L, 52L, 40L, 13L, 10L, 51L, 51L, 43L, 37L, 
29L, 22L, 52L, 47L, 40L, 86L, 5L, 0L, 0L, 0L, 0L, 20L, 25L, 29L, 
14L, 0L, 5L, 57L, 84L, 99L, 48L, 39L, 78L, 55L, 54L, 57L, 58L, 
66L, 67L, 84L, 79L, 50L, 79L, 141L, 35L, 6L, 27L, 25L, 5L, 17L, 
15L, 0L, 31L, 0L), classification = c("Feeding", "Standing", 
"Standing", "Standing", "Standing", "Standing", "Feeding", "Feeding", 
"Feeding", "Standing", "Standing", "Foraging", "Standing", "Standing", 
"Feeding", "Foraging", "Foraging", "Foraging", "Standing", "Standing", 
"Standing", "Standing", "Feeding", "Feeding", "Standing", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Standing", "Standing", 
"Standing", "Feeding", "Standing", "Standing", "Standing", "Standing", 
"Feeding", "Standing", "Feeding", "Feeding", "Feeding", "Standing", 
"Standing", "Feeding", "Feeding", "Feeding", "Feeding", "Feeding", 
"Standing", "Standing", "Standing", "Foraging", "Foraging", "Foraging", 
"Feeding", "Foraging", "Foraging", "Foraging", "Foraging", "Feeding", 
"Feeding", "Feeding", "Standing", "Standing", "Standing", "Standing", 
"Foraging", "Standing", "Standing", "Standing", "Standing", "Standing", 
"Feeding", "Standing", "Feeding", "Standing", "Standing", "Foraging", 
"Foraging", "Standing", "Feeding", "Feeding", "Standing", "Feeding", 
"Foraging", "Feeding", "Feeding", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Standing", "Standing", "Feeding", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Feeding", "Feeding", 
"Foraging", "Feeding", "Feeding", "Feeding", "Feeding", "Standing", 
"Feeding", "Feeding", "Feeding", "Standing", "Standing", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Feeding", "Feeding", 
"Standing", "Standing", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Standing", "Feeding", 
"Standing", "Feeding", "Feeding", "Feeding", "Feeding", "Standing", 
"Feeding", "Feeding", "Foraging", "Feeding", "Feeding", "Feeding", 
"Feeding", "Foraging", "Foraging", "Feeding", "Feeding", "Feeding", 
"Feeding", "Foraging", "Foraging", "Feeding", "Feeding", "Feeding", 
"Standing", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Feeding", "Standing", "Standing", "Feeding", "Feeding", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Feeding", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Standing", "Feeding", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging")), row.names = c(NA, -215L
), class = c("data.table", "data.frame"), .internal.selfref = <pointer: 0x0000000002531ef0>)
  

Любой ввод действительно приветствуется!

Комментарии:

1. Ну, разве вам не нужно predict сначала ?!

2. Вам нужно будет сравнить свои прогнозы с вашими фактическими классификациями, чтобы получить матрицу путаницы. Кажется, что ваши данные имеют только классификацию, а не прогноз.

3. использование savePredictions = "final" в trainControl this вернет прогнозы для прогнозов knnFit ожидания, чтобы вы могли рассчитать любую метрику по вашему желанию.

Ответ №1:

Вот воспроизводимый пример:

 library(caret)
train_set<-createDataPartition(iris$Species,p=0.8,list=FALSE)
valid_set<-iris[-train_set,]
train_set<-iris[train_set,]
ctrl<-trainControl(method="cv",number=5)
set.seed(233)
mk<-train(Species~.,data=train_set,method="knn",trControl = ctrl,metric="Accuracy")
  

Получение матрицы путаницы. В идеале, лучше сравнить ваше обучение со predict значениями ed из набора тестов или проверки.

Редактировать: чтобы получить таблицу, просто выполните:

 confusionMatrix(mk)["table"]
$table
            Reference
Prediction       setosa versicolor  virginica
  setosa     33.3333333  0.0000000  0.0000000
  versicolor  0.0000000 32.5000000  2.5000000
  virginica   0.0000000  0.8333333 30.8333333
  

Оригинал

  confusionMatrix(mk)
  

Результат:

 Cross-Validated (5 fold) Confusion Matrix 

(entries are percentual average cell counts across resamples)

            Reference
Prediction   setosa versicolor virginica
  setosa       33.3        0.0       0.0
  versicolor    0.0       31.7       1.7
  virginica     0.0        1.7      31.7

 Accuracy (average) : 0.9667
  

Комментарии:

1. Спасибо. Это то, что я искал. Однако, как упоминалось в моем первоначальном сообщении, я буду использовать матрицу путаницы для вычисления показателей производительности. Как я могу извлечь значения из предоставленной вами матрицы путаницы? Я пытаюсь использовать: > confMat<-confusionMatrix(knnFit) > confMat<-as.table(confMat) Error in as.table.default(confMat) : cannot coerce to a table но, как вы можете видеть, я получаю сообщение об ошибке.