# 1、创建任务 (Create the task) ----
# Attach mlr3 with library(), which stops with an error if the package is not
# installed. pacman::p_load() would instead install missing packages at run
# time, and attaching pacman only to load a single package is unnecessary.
library(mlr3)

# Inspect the iris data: 150 rows, four numeric measurements and the factor
# Species with three levels.
str(iris)
## 'data.frame': 150 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
# Create a classification task named "iris" with Species as the target; the
# remaining columns are used as features.
task <- TaskClassif$new(id = "iris", backend = iris, target = "Species")
# 2、选择学习器 (Choose a learner) ----
# Learner: an rpart decision tree with complexity parameter cp = 0.1 and
# minsplit = 10 (minimum node size for which a split is attempted).
lrner <- lrn(
  "classif.rpart",
  cp = 0.1,
  minsplit = 10
)
# 3、拆分训练集和测试集 (Split into training and test sets) ----
set.seed(123)
# Split the task's rows 8:2 into a training set (dtrain) and a test set
# (dtest), both expressed as mlr3 row ids.
train_ratio <- 0.8
dtrain <- local({
  # Sample from the task's actual row ids instead of assuming they are
  # 1..nrow. Sampling positions with sample.int() also avoids sample()'s
  # special case, which treats a length-1 numeric vector x as 1:x.
  ids <- task$row_ids
  # round() makes the integer sample size explicit rather than relying on
  # sample.int() silently truncating a fractional size.
  ids[sample.int(length(ids), size = round(train_ratio * length(ids)))]
})
# Every row id not drawn for training goes to the test set
dtest <- setdiff(task$row_ids, dtrain)
# 4、训练模型 (Train the model) ----
# Fit the learner on the training rows only
lrner$train(task = task, row_ids = dtrain)
# Print the fitted rpart tree stored in the learner
print(lrner$model)
## n= 120
##
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
##
## 1) root 120 75 virginica (0.33333333 0.29166667 0.37500000)
##   2) Petal.Length< 2.45 40 0 setosa (1.00000000 0.00000000 0.00000000) *
##   3) Petal.Length>=2.45 80 35 virginica (0.00000000 0.43750000 0.56250000)
##     6) Petal.Length< 4.75 32 1 versicolor (0.00000000 0.96875000 0.03125000) *
##     7) Petal.Length>=4.75 48 4 virginica (0.00000000 0.08333333 0.91666667) *
# 5、预测 (Predict) ----
# Predict the held-out test rows with the trained learner
pred <- lrner$predict(task = task, row_ids = dtest)
# Predicted class labels (a factor with the three Species levels)
print(pred$response)
## [1] setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa
## [11] versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
## [21] virginica versicolor virginica versicolor versicolor virginica virginica virginica virginica virginica
## Levels: setosa versicolor virginica
# 6、模型评估 (Evaluate the model) ----
# Confusion matrix: rows are predicted classes (response), columns are the
# true classes (truth)
print(pred$confusion)
##             truth
## response     setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0         13         0
##   virginica       0          2         5
# Classification accuracy on the test set
print(pred$score(measures = msr("classif.acc")))
## classif.acc
## 0.9333333
# 7、交叉验证 (Cross-validation) ----
# 10-fold cross-validation of the same learner configuration on the full task
resampling <- rsmp("cv", folds = 10L)
rr <- resample(task, lrner, resampling)
# Accuracy aggregated over the 10 folds
print(rr$aggregate(measures = msr("classif.acc")))
## classif.acc
## 0.9266667