#r #machine-learning #decision-tree
#r #машинное обучение #дерево решений
Вопрос:
Я запустил этот пример из rpart-manpage
tree <- rpart(Species~., data = iris)
plot(tree,margin=0.1)
text(tree)
Теперь я хочу изменить это для другого набора данных
digitstrainURL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tra"
digitsTestURL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tes"
digitstrain <- read.table(digitstrainURL, sep=",",
col.names=c("i1","i2","i3","i4","i5","i6","i7","i8","i9","i10","i11","i12","i13","i14","i15","i16", "Class"))
digitstest <- read.table(digitsTestURL, sep=",",
col.names=c("i1","i2","i3","i4","i5","i6","i7","i8","i9","i10","i11","i12","i13","i14","i15","i16", "Class"))
tree <- rpart(Class~., data = digitstrain)
plot(tree,margin=0.1)
text(tree)
Набор данных содержит данные для написанных от руки цифр, а «Класс» содержит цифру 0-9
Но когда я строю дерево, в результате я получаю числа с плавающей запятой, есть идеи, что означают эти цифры? Я бы предпочел иметь 0-9 в качестве текста для листов.
Ответ №1:
Вы пытаетесь подогнать дерево классификации, но ваши данные представляют собой целые числа, а не факторы.
Функция rpart
попытается угадать, какой метод использовать, и в вашем случае делает неверное предположение. Таким образом, ваш код соответствует дереву, основанному на method="anova"
, в то время как вы хотите использовать method="class"
.
Попробуйте это:
tree <- rpart(Class~., data = digitstrain, method="class")
plot(tree,margin=0.1)
text(tree, cex=0.7)
Чтобы проверить точность вашей модели, вы можете использовать predict
для получения прогнозируемых значений, а затем создать матрицу путаницы:
confusion <- data.frame(
class=factor(digitstest$Class),
predict=predict(tree, digitstest, type="class")
)
with(confusion, table(class, predict))
predict
class 0 1 2 3 4 5 6 7 8 9
0 311 1 0 0 0 0 0 7 42 2
1 0 139 186 4 0 0 0 1 10 24
2 0 0 320 14 2 3 0 7 15 3
3 0 6 0 309 1 3 0 17 0 0
4 0 1 0 5 300 0 0 0 0 58
5 0 0 0 74 0 177 0 1 14 69
6 5 0 3 9 12 0 264 11 5 27
7 2 9 11 13 0 10 0 290 0 29
8 60 0 0 0 0 32 0 21 220 3
9 1 44 0 9 20 0 0 8 0 254
Обратите внимание, что прогнозирование с использованием одного дерева невелико. Очень простой способ улучшить прогнозирование — использовать случайный лес, состоящий из множества деревьев, снабженных случайными подмножествами ваших обучающих данных:
library(randomForest)
fst <- randomForest(factor(Class)~., data = digitstrain, method="class")
Обратите внимание, что лес дает гораздо лучшие результаты прогнозирования:
confusion <- data.frame(
class=factor(digitstest$Class),
predict=predict(fst, digitstest, type="class")
)
with(confusion, table(class, predict))
predict
class 0 1 2 3 4 5 6 7 8 9
0 347 0 0 0 0 0 0 0 16 0
1 0 333 28 1 1 0 0 1 0 0
2 0 5 359 0 0 0 0 0 0 0
3 0 4 0 331 0 0 0 0 0 1
4 0 0 0 0 362 1 0 0 0 1
5 0 0 0 8 0 316 0 0 0 11
6 1 0 0 0 0 0 335 0 0 0
7 0 26 2 0 0 0 0 328 0 8
8 0 0 0 0 0 0 0 0 336 0
9 0 2 0 0 0 0 0 2 1 331
Комментарии:
1. Большое спасибо! разве результат не должен соответствовать всем обучающим данным? Я выбрал некоторые данные и проверил их самостоятельно, но это не дает мне правильного результата. Есть ли возможность заполнить дерево данными и получить результат?
2. @user2071938 Одно дерево не гарантирует хорошую модель, но модель случайного леса в целом будет работать лучше. Я расширил свой ответ.
Ответ №2:
Это происходит потому, что ваш столбец класса является числовым. Преобразуйте его в фактор, затем попробуйте…
digitstrain$Class = as.factor(digitstrain$Class)
tree <- rpart(Class~., data = digitstrain)
plot(tree,margin=0.1)
text(tree)
Результатом будет