我們第一個介紹的是DecisionTree。先請同學安裝套件rpart。
check_then_install("rpart", "4.1.10")
接著,請載入套件rpart
library(rpart)
在摸索一個套件時,我們可以找找看套件作者有沒有撰寫vignette。請同學輸入:vignette(package="rpart")
vignette(package = "rpart")
由跳出的視窗,我們可以看到一個名為:"longintro"的文件名稱,是一份介紹rpart的文件。請輸入vignette("longintro",package="rpart")打開這份文件。
有興趣的同學可以閱讀這份文件的前半段。我們則直接用範例來解說rpart的功能。
請同學先輸入data(stagec)載入一個關於C期前列腺癌的研究數據。這比數據中,記錄著146位病患的資訊。
data(stagec)
請同學輸入:cfit<-rpart(pgstat~age+eet+g2,data=stagec,method="class")。
cfit <- rpart(pgstat ~ age + eet + g2, data = stagec, method = "class")
這裡的函數rpart就是用於建立decisiontree的函數。請同學打開rpart的說明頁面。
?rpart
根據rpart的說明文件,我們剛剛輸入的:cfit<-rpart(pgstat~age+eet+g2,data=stagec,method="class")中的pgstat~age+eet+g2是對應到rpart函數的哪一個參數呢?
formula
上述輸入的formula參數:pgstat~age+eet+g2,描述的是在建構decisiontree時,變數之間的關係。pgstat是要被預測、被分類的變數名稱,age、eet和g2則是用來對pgstat做預測的變數。
根據rpart的說明文件,我們剛剛輸入的:cfit<-rpart(pgstat~age+eet+g2,data=stagec,method="class")中的stagec是對應到rpart函數的哪一個參數呢?
data
接著,請列出stagec的欄位名稱。
colnames(stagec)
我們可以看到,剛剛formula中的變數名稱,都在stagec之中。
rpart這個函數有許多功能,使用者可以在method的參數指定要使用的功能。請同學參考rpart的說明文件中,關於method參數的說明。請問下列哪一個選項「不是」rpart的method參數的有效選項?
regression
在關於method的說明文件中,仔細地解釋了rpart是如何依照formula中選擇的變數形態來智慧的選擇預設的method。請同學查詢stagec的pgstat欄位的形態為何。
class(stagec$pgstat)
依照rpart的說明文件和stagec$pgstat的型態,請問如果我們沒有指定method的話,rpart會用哪一種method參數來運作?
anova
接著,請輸入cfit來看看rpart的結果。
cfit
R會把從資料中學到的decisiontree顯示到console中。前段的文字說明了每一行的資訊依序是:node),split,n,loss,yval,(yprob)而且最後標記有星號的就是decisiontree的leafnode。舉例來說,1)root146540(0.63013700.3698630)代表這是第一個node,他的切割規則是root,有146個點,loss是54,deviance是0。請問同學,第二個node的loss是什麼?
18
這裡的loss代表的是錯誤的label的個數,俗稱0/1loss。在第一個node,也就是root之中,cfit對stagec$pgstat的預測是0。請同學計算stagec$pgstat中非0的病患總數。看看是不是和第一行,1)root中顯示的loss相同。
sum(stagec$pgstat != 0)
另外同學應該有注意到,node的編號並不是連續的。這是因為,每個編號為x的node,他的分支一定是編號2x和2x+1。請問同學,編號7的node是編號多少的node的分支?
3
接著,讓我們畫出cfit。這需要兩個指令,所以請同學先輸入:plot(cfit)
plot(cfit)
再請同學輸入text(cfit)
text(cfit)
我們可以發現,圖的上下維有一點被切掉。這可以透過par函數的mar參數做調整。但是其實已經有人發現這件事情,並且寫了一個叫做rpart.plot的套件。請同學安裝這個套件
check_then_install("rpart.plot", "1.5.3")
接著,請載入rpart.plot套件。
library(rpart.plot)
我們直接輸入:rpart.plot(cfit)來看看畫圖的結果。
rpart.plot(cfit)
rpart.plot套件對於rpart的圖片輸出做過調整,所以就不會出現圖形被截掉的狀態。
接著讓我們來探索rpart是如何產生cfit這棵樹。
rpart其實有非常多的參數,並且各類參數的細節分佈在rpart的參數parms和control中。
在我們剛剛打開的vignette的Chapter3.1,作者說明了如何建構一個decisiontree。裡面解釋了何謂prior、loss和splittingindex。
rpart的參數parms裡面可以設定和method相關的參數。
請問同學,根據rpart的說明文件(請參閱Arguments底下的parms),當method為class時(classificationsplitting),預設的prior為何?1)每種類別都相等的機率;2)和資料中各類別出現的頻率成正比的機率
2
請問同學,根據rpart的說明文件(請參閱Arguments底下的parms),當method為class時(classificationsplitting),預設的splittingindex為何?
gini
rpart把和method無關的參數放到control底下,並且提供一個輔助函數rpart.control來協助使用者在實作時也限制了每個split時,該node的個數限制。請同學輸入?rpart.control來看一下這些控制有哪些參數。
?rpart.control
接著,我們來重現cfit的第一層結果。請同學閱讀檔案中的程式碼與註解後,輸入submit()。
# R 中的型態很重要。類別的數據,調整成factor之後做運算會方便很多
y <- factor(stagec$pgstat)
# n 是各種類別出現的次數
n <- table(y)
# 預設的prior 是各種類別出現的比率
prior <- n / length(y)
#'@title 這是Vignette中的P 函數的實作
#'
#'@param x factor vector.
#'@return numeric value. 資料點屬於x 的機率
P <- function(x) {
x.tb <- table(x)
sum(pi * (x.tb / n))
}
#'@title 這是gini index的計算
#'
#'@param p numeric value. 是某個類別的機率
#'@return numeric value. 該類別的gini index
gini <- function(p) p * (1 - p)
#'@title 這是使用gini index做切割準則時,I 函數的實作
#'@param x factor vector.
#'@return numeric value. x 的impurity
I <- function(x) {
x.tb <- table(x)
# 各種類別的機率
category.prob <- x.tb / length(x)
#' R 也是一種函數式語言,而sapply等函數能夠很方便的取代for 迴圈
#' 這個寫法等價於:
#' for(p in category.prob) gini(p)
#' 但是自動把輸出結果排列成一個向量
category.gini <- sapply(category.prob, gini)
sum(category.gini)
}
PI <- function(x) P(x) * I(x)
#'@title 給定一個切點之後,計算impurity降低的幅度
impurity_variation_after_cut <- function(cut) {
origin.impurity <- I(y) * P(y)
# split 會依照第二個參數的值,將第一個參數分成若干個向量。
# split 的結果是一個list,而且每一個list element對應到第二個參數的一種類別
group <- split(y, stagec$age < cut)
# group 是一個長度為二的list
# 第一個element是所有stage$age < cut為FALSE 的病患對應的pgstat
# 第二個element是所有stage$age < cut為TRUE 的病患對應的pgstat
# 對各種切割後的node計算PI後加總
splitted.impurity <- sum(sapply(group, PI))
origin.impurity - splitted.impurity
}
# 列舉所有可能的切點
eval.x <- seq(min(stagec$age) - 0.5, max(stagec$age) + 0.5, by = 1)
# 算出每個切點,對應的impurity 的改善量
index <- sapply(eval.x, impurity_variation_after_cut)
請問同學,讓impurity改善最大的切點,是第幾個呢?同學可以用which.max函數作答。
which.max(index)
對應的切點的值是多少呢?請利用上一題的答案。
eval.x[which.max(index)]
上一題的答案和cfit的結果不一致。從前面cfit的輸出可以看到,rpart的第一個切點是age>=58.5。這其實是受到control這個參數的影響,所以rpart不會切割出太小(包含太少資料點)的node。請同學輸入:rpart(pgstat~age,data=stagec,method="class",control=rpart.control(minsplit=1))
rpart(pgstat ~ age, data = stagec, method = "class", control = rpart.control(minsplit = 1))
同學是不是看到第一個切點變成我們之前算出來的50.5了?
rpart在做分類時,是利用公式去計算各種切點的impurity的改善。而這些切點的選擇也是有限制的(透過rpart.control)。使用者可以透過control=rpart.control(minsplit=1)來對這些限制條件做修正。
Impurity的計算則可以透過parms的設定來調整。請同學閱讀檔案中的程式碼與註解後,輸入submit()。
# R 中的型態很重要。類別的數據,調整成factor之後做運算會方便很多
y <- factor(stagec$pgstat)
# n 是各種類別出現的次數
n <- table(y)
# 預設的prior 是各種類別出現的比率
prior <- n / length(y)
#'@title 這是Vignette中的P 函數的實作
#'
#'@param x factor vector.
#'@return numeric value. 資料點屬於x 的機率
P <- function(x) {
x.tb <- table(x)
sum(pi * (x.tb / n))
}
#'@title 這是information index的計算
#'
#'@param p numeric value. 是某個類別的機率
#'@return numeric value. 該類別的information index
information <- function(p) {
if (p == 0) 0 else - p * log(p)
}
#'@title 這是使用information index做切割準則時,I 函數的實作
#'@param x factor vector.
#'@return numeric value. x 的impurity
I <- function(x) {
x.tb <- table(x)
# 各種類別的機率
category.prob <- x.tb / length(x)
#' R 也是一種函數式語言,而sapply等函數能夠很方便的取代for 迴圈
#' 這個寫法等價於:
#' for(p in category.prob) gini(p)
#' 但是自動把輸出結果排列成一個向量
category.information <- sapply(category.prob, information)
sum(category.information)
}
PI <- function(x) P(x) * I(x)
#'@title 給定一個切點之後,計算impurity降低的幅度
impurity_variation_after_cut <- function(cut) {
origin.impurity <- I(y) * P(y)
# split 會依照第二個參數的值,將第一個參數分成若干個向量。
# split 的結果是一個list,而且每一個list element對應到第二個參數的一種類別
group <- split(y, stagec$age < cut)
# group 是一個長度為二的list
# 第一個element是所有stage$age < cut為FALSE 的病患對應的pgstat
# 第二個element是所有stage$age < cut為TRUE 的病患對應的pgstat
# 對各種切割後的node計算PI後加總
splitted.impurity <- sum(sapply(group, PI))
origin.impurity - splitted.impurity
}
# 列舉所有可能的切點
eval.x <- seq(min(stagec$age) - 0.5, max(stagec$age) + 0.5, by = 1)
# 算出每個切點,對應的impurity 的改善量
index <- sapply(eval.x, impurity_variation_after_cut)
在改成用informationindex後,對應的切點的值是多少呢?
eval.x[which.max(index)]
請同學輸入:rpart(pgstat~age,data=stagec,method="class",parms=list(split="information"),control=rpart.control(minsplit=1))
rpart(pgstat ~ age, data = stagec, method = "class", parms = list(split = "information"), control = rpart.control(minsplit=1))
可以看到差不多的結果。
rpart也可以讓我們自己定義分割的邏輯。這題會打開rpart套件提供的範例給同學參考。有興趣的同學可以仔細研究。讀完之後請輸入submit()
# The following script is based on `mystate` dataset.
mystate <- data.frame(state.x77, region=state.region)
colnames(mystate) <- tolower(colnames(mystate))
# The 'evaluation' function. Called once per node.
# Produce a label (1 or more elements long) for labeling each node,
# and a deviance. The latter is
# - of length 1
# - equal to 0 if the node is "pure" in some sense (unsplittable)
# - does not need to be a deviance: any measure that gets larger
# as the node is less acceptable is fine.
# - the measure underlies cost-complexity pruning, however
temp1 <- function(y, wt, parms) {
wmean <- sum(y*wt)/sum(wt)
rss <- sum(wt*(y-wmean)^2)
list(label= wmean, deviance=rss)
}
# The split function, where most of the work occurs.
# Called once per split variable per node.
# If continuous=T
# The actual x variable is ordered
# y is supplied in the sort order of x, with no missings,
# return two vectors of length (n-1):
# goodness = goodness of the split, larger numbers are better.
# 0 = couldn't find any worthwhile split
# the ith value of goodness evaluates splitting obs 1:i vs (i+1):n
# direction= -1 = send "y< cutpoint" to the left side of the tree
# 1 = send "y< cutpoint" to the right
# this is not a big deal, but making larger "mean y's" move towards
# the right of the tree, as we do here, seems to make it easier to
# read
# If continuos=F, x is a set of integers defining the groups for an
# unordered predictor. In this case:
# direction = a vector of length m= "# groups". It asserts that the
# best split can be found by lining the groups up in this order
# and going from left to right, so that only m-1 splits need to
# be evaluated rather than 2^(m-1)
# goodness = m-1 values, as before.
#
# The reason for returning a vector of goodness is that the C routine
# enforces the "minbucket" constraint. It selects the best return value
# that is not too close to an edge.
temp2 <- function(y, wt, x, parms, continuous) {
# Center y
n <- length(y)
y <- y- sum(y*wt)/sum(wt)
if (continuous) {
# continuous x variable
temp <- cumsum(y*wt)[-n]
left.wt <- cumsum(wt)[-n]
right.wt <- sum(wt) - left.wt
lmean <- temp/left.wt
rmean <- -temp/right.wt
goodness <- (left.wt*lmean^2 + right.wt*rmean^2)/sum(wt*y^2)
list(goodness= goodness, direction=sign(lmean))
}
else {
# Categorical X variable
ux <- sort(unique(x))
wtsum <- tapply(wt, x, sum)
ysum <- tapply(y*wt, x, sum)
means <- ysum/wtsum
# For anova splits, we can order the categories by their means
# then use the same code as for a non-categorical
ord <- order(means)
n <- length(ord)
temp <- cumsum(ysum[ord])[-n]
left.wt <- cumsum(wtsum[ord])[-n]
right.wt <- sum(wt) - left.wt
lmean <- temp/left.wt
rmean <- -temp/right.wt
list(goodness= (left.wt*lmean^2 + right.wt*rmean^2)/sum(wt*y^2),
direction = ux[ord])
}
}
# The init function:
# fix up y to deal with offsets
# return a dummy parms list
# numresp is the number of values produced by the eval routine's "label"
# numy is the number of columns for y
# summary is a function used to print one line in summary.rpart
# In general, this function would also check for bad data, see rpart.poisson
# for instace.
temp3 <- function(y, offset, parms, wt) {
if (!is.null(offset)) y <- y-offset
list(y=y, parms=0, numresp=1, numy=1,
summary= function(yval, dev, wt, ylevel, digits ) {
paste(" mean=", format(signif(yval, digits)),
", MSE=" , format(signif(dev/wt, digits)),
sep='')
})
}
alist <- list(eval=temp1, split=temp2, init=temp3)
fit1 <- rpart(income ~population +illiteracy + murder + hs.grad + region,
mystate, control=rpart.control(minsplit=10, xval=0),
method=alist)
我們可以利用predict函數來使用學習好的cfit函數做預測。舉例來說,predict(cfit,stagec)就可以用我們學到的模型回頭預測stagec的pgstat的值。請同學試試看。
predict(cfit, stagec)
如果同學想要對rpart套件提供的預測函數predict有更多的了解可以輸入:?predict.rpart注意歐,這裡我們使用的並不是?predict,因為這裡rpart套件採用了R的S3物件導向方法。由於class(cfit)的輸出是rpart,所以predict函數最後會呼叫predict.rpart來對cfit做處理,相關的說明文件也會放在?predict.rpart之中。
?predict.rpart
以上就是對rpart這個套件的介紹。
最後我們再次使用rpart來挑戰mlbench的Ionosphere資料
check_then_install("mlbench", "2.1.1")
library(mlbench)
# 方便起見,同學可以使用這個函數計算 Logarithmic Loss
logloss <- function(y, p, tol = 1e-4) {
# tol 的用途是避免對0取log所導致的數值問題
p[p < tol] <- tol
p[p > 1 - tol] <- 1-tol
-sum(y * log(p) + (1 - y) * log(1-p))
}
data(Ionosphere)
test.i <- c(4L, 6L, 9L, 13L, 14L, 22L, 31L, 33L, 50L, 52L, 61L, 63L, 68L,
79L, 91L, 99L, 119L, 135L, 154L, 155L, 160L, 162L, 166L, 194L,
200L, 219L, 233L, 236L, 237L, 242L, 244L, 248L, 250L, 257L, 261L,
276L, 278L, 283L, 292L, 310L, 312L, 315L, 319L, 323L, 325L, 327L,
335L, 337L, 338L, 344L)
df.test <- Ionosphere[test.i,-2] # remove V2
train.i <- setdiff(seq_len(nrow(Ionosphere)), test.i)
df.train <- Ionosphere[train.i,-2]
# 請利用rpart,從df.train上學出一個模型
# 該模型在df.test上的logloss需要小於12
answer_05 <- local({
NULL
# 請調整以下的程式碼
rpart(Class ~ ., data = df.train, control = rpart.control(minsplit=50))
})
stopifnot(class(answer_05) == c("rpart"))
if (interactive()) {
stopifnot(local({
p <- predict(answer_05, df.test)[,"good"]
logloss(df.test$Class == "good", p) < 12
}))
}
# 完成後,請存檔後回到console輸入`submit()`