Деревья решений с R

#r #machine-learning #decision-tree

#r #машинное обучение #дерево решений

Вопрос:

Я запустил этот пример из rpart-manpage

 tree <- rpart(Species~., data = iris)
plot(tree,margin=0.1)
text(tree)
  

Теперь я хочу изменить это для другого набора данных

 digitstrainURL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tra"
digitsTestURL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tes"
digitstrain <- read.table(digitstrainURL, sep=",",
                          col.names=c("i1","i2","i3","i4","i5","i6","i7","i8","i9","i10","i11","i12","i13","i14","i15","i16", "Class"))
digitstest <- read.table(digitsTestURL, sep=",",
col.names=c("i1","i2","i3","i4","i5","i6","i7","i8","i9","i10","i11","i12","i13","i14","i15","i16", "Class"))

tree <- rpart(Class~., data = digitstrain)
plot(tree,margin=0.1)
text(tree)
  

Набор данных содержит данные для написанных от руки цифр, а «Класс» содержит цифру 0-9
Но когда я строю дерево, в результате я получаю числа с плавающей запятой, есть идеи, что означают эти цифры? Я бы предпочел иметь 0-9 в качестве текста для листов.

Ответ №1:

Вы пытаетесь подогнать дерево классификации, но ваши данные представляют собой целые числа, а не факторы.

Функция rpart попытается угадать, какой метод использовать, и в вашем случае делает неверное предположение. Таким образом, ваш код соответствует дереву, основанному на method="anova" , в то время как вы хотите использовать method="class" .

Попробуйте это:

 tree <- rpart(Class~., data = digitstrain, method="class")
plot(tree,margin=0.1)
text(tree, cex=0.7)
  

введите описание изображения здесь

Чтобы проверить точность вашей модели, вы можете использовать predict для получения прогнозируемых значений, а затем создать матрицу путаницы:

 confusion <- data.frame(
  class=factor(digitstest$Class), 
  predict=predict(tree, digitstest, type="class")
  )
with(confusion, table(class, predict))

     predict
class   0   1   2   3   4   5   6   7   8   9
    0 311   1   0   0   0   0   0   7  42   2
    1   0 139 186   4   0   0   0   1  10  24
    2   0   0 320  14   2   3   0   7  15   3
    3   0   6   0 309   1   3   0  17   0   0
    4   0   1   0   5 300   0   0   0   0  58
    5   0   0   0  74   0 177   0   1  14  69
    6   5   0   3   9  12   0 264  11   5  27
    7   2   9  11  13   0  10   0 290   0  29
    8  60   0   0   0   0  32   0  21 220   3
    9   1  44   0   9  20   0   0   8   0 254
  

Обратите внимание, что прогнозирование с использованием одного дерева невелико. Очень простой способ улучшить прогнозирование — использовать случайный лес, состоящий из множества деревьев, снабженных случайными подмножествами ваших обучающих данных:

 library(randomForest)

fst <- randomForest(factor(Class)~., data = digitstrain, method="class")
  

Обратите внимание, что лес дает гораздо лучшие результаты прогнозирования:

 confusion <- data.frame(
  class=factor(digitstest$Class), 
  predict=predict(fst, digitstest, type="class")
  )
with(confusion, table(class, predict))

     predict
class   0   1   2   3   4   5   6   7   8   9
    0 347   0   0   0   0   0   0   0  16   0
    1   0 333  28   1   1   0   0   1   0   0
    2   0   5 359   0   0   0   0   0   0   0
    3   0   4   0 331   0   0   0   0   0   1
    4   0   0   0   0 362   1   0   0   0   1
    5   0   0   0   8   0 316   0   0   0  11
    6   1   0   0   0   0   0 335   0   0   0
    7   0  26   2   0   0   0   0 328   0   8
    8   0   0   0   0   0   0   0   0 336   0
    9   0   2   0   0   0   0   0   2   1 331
  

Комментарии:

1. Большое спасибо! разве результат не должен соответствовать всем обучающим данным? Я выбрал некоторые данные и проверил их самостоятельно, но это не дает мне правильного результата. Есть ли возможность заполнить дерево данными и получить результат?

2. @user2071938 Одно дерево не гарантирует хорошую модель, но модель случайного леса в целом будет работать лучше. Я расширил свой ответ.

Ответ №2:

Это происходит потому, что ваш столбец класса является числовым. Преобразуйте его в фактор, затем попробуйте…

 digitstrain$Class = as.factor(digitstrain$Class)
tree <- rpart(Class~., data = digitstrain)
plot(tree,margin=0.1)
text(tree)
  

Результатом будет

введите описание изображения здесь